wasm_pkg_common/
digest.rs1use std::str::FromStr;
2
3use bytes::Bytes;
4use futures_util::{future::ready, stream::once, Stream, StreamExt, TryStream, TryStreamExt};
5use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7
8use crate::Error;
9
/// A digest identifying a piece of content, tagged with the hash algorithm that produced it.
///
/// The canonical text form is `<algorithm>:<hex>` (currently only `sha256:<hex>`); that is the
/// form written by `Display`/`Serialize` and accepted by `FromStr`/`Deserialize` in this module.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ContentDigest {
    /// A SHA-256 digest.
    Sha256 {
        /// The digest encoded as hex. Values produced by this module's parsing and hashing
        /// code are lowercase.
        hex: String,
    },
}
15
16impl ContentDigest {
17 pub fn validating_stream(
18 &self,
19 stream: impl TryStream<Ok = Bytes, Error = Error>,
20 ) -> impl Stream<Item = Result<Bytes, Error>> {
21 let want = self.clone();
22 stream.map_ok(Some).chain(once(async { Ok(None) })).scan(
23 Sha256::new(),
24 move |hasher, res| {
25 ready(match res {
26 Ok(Some(bytes)) => {
27 hasher.update(&bytes);
28 Some(Ok(bytes))
29 }
30 Ok(None) => {
31 let got: Self = std::mem::take(hasher).into();
32 if got == want {
33 None
34 } else {
35 Some(Err(Error::InvalidContent(format!(
36 "expected digest {want}, got {got}"
37 ))))
38 }
39 }
40 Err(err) => Some(Err(err)),
41 })
42 },
43 )
44 }
45}
46
47impl std::fmt::Display for ContentDigest {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 match self {
50 ContentDigest::Sha256 { hex } => write!(f, "sha256:{hex}"),
51 }
52 }
53}
54
55impl From<Sha256> for ContentDigest {
56 fn from(hasher: Sha256) -> Self {
57 Self::Sha256 {
58 hex: format!("{:x}", hasher.finalize()),
59 }
60 }
61}
62
63impl<'a> TryFrom<&'a str> for ContentDigest {
64 type Error = Error;
65
66 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
67 let Some(hex) = value.strip_prefix("sha256:") else {
68 return Err(Error::InvalidContentDigest(
69 "must start with 'sha256:'".into(),
70 ));
71 };
72 let hex = hex.to_lowercase();
73 if hex.len() != 64 {
74 return Err(Error::InvalidContentDigest(format!(
75 "must be 64 hex digits; got {} chars",
76 hex.len()
77 )));
78 }
79 if let Some(invalid) = hex.chars().find(|c| !c.is_ascii_hexdigit()) {
80 return Err(Error::InvalidContentDigest(format!(
81 "must be hex; got {invalid:?}"
82 )));
83 }
84 Ok(Self::Sha256 { hex })
85 }
86}
87
88impl FromStr for ContentDigest {
89 type Err = Error;
90
91 fn from_str(s: &str) -> Result<Self, Self::Err> {
92 s.try_into()
93 }
94}
95
96impl Serialize for ContentDigest {
97 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
98 serializer.serialize_str(&self.to_string())
99 }
100}
101
102impl<'de> Deserialize<'de> for ContentDigest {
103 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
104 where
105 D: serde::Deserializer<'de>,
106 {
107 Self::from_str(&String::deserialize(deserializer)?).map_err(serde::de::Error::custom)
108 }
109}
110
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use futures_util::stream;

    use super::*;

    #[tokio::test]
    async fn test_validating_stream() {
        let input = b"input";
        let digest = ContentDigest::from(Sha256::new_with_prefix(input));
        let stream = stream::iter(input.chunks(2));
        let validating = digest.validating_stream(stream.map(|bytes| Ok(bytes.into())));
        assert_eq!(
            validating.try_collect::<BytesMut>().await.unwrap(),
            &input[..]
        );
    }

    #[tokio::test]
    async fn test_validating_empty_stream() {
        // The digest of zero bytes must validate a stream that yields nothing.
        let digest = ContentDigest::from(Sha256::new());
        let validating = digest.validating_stream(stream::empty::<Result<Bytes, Error>>());
        assert!(validating
            .try_collect::<BytesMut>()
            .await
            .unwrap()
            .is_empty());
    }

    #[tokio::test]
    async fn test_invalidating_stream() {
        let input = b"input";
        let digest = ContentDigest::Sha256 {
            hex: "doesn't match anything!".to_string(),
        };
        let stream = stream::iter(input.chunks(2));
        let validating = digest.validating_stream(stream.map(|bytes| Ok(bytes.into())));
        assert!(matches!(
            validating.try_collect::<BytesMut>().await,
            Err(Error::InvalidContent(_)),
        ));
    }

    #[tokio::test]
    async fn test_validating_stream_forwards_upstream_error() {
        let input = b"input";
        let digest = ContentDigest::from(Sha256::new_with_prefix(input));
        let failing = stream::iter([Err::<Bytes, Error>(Error::InvalidContent(
            "upstream".to_string(),
        ))]);
        let result = digest
            .validating_stream(failing)
            .try_collect::<BytesMut>()
            .await;
        // `try_collect` stops at the first error, which must be the forwarded upstream one
        // rather than the digest-mismatch error emitted at end of stream.
        assert!(matches!(result, Err(Error::InvalidContent(msg)) if msg == "upstream"));
    }

    #[test]
    fn test_parse_normalizes_case_and_round_trips() {
        let upper = format!("sha256:{}", "AB".repeat(32));
        let digest: ContentDigest = upper.parse().unwrap();
        assert_eq!(
            digest,
            ContentDigest::Sha256 {
                hex: "ab".repeat(32)
            }
        );
        let reparsed: ContentDigest = digest.to_string().parse().unwrap();
        assert_eq!(reparsed, digest);
    }

    #[test]
    fn test_parse_rejects_invalid() {
        let cases = [
            ("missing prefix", "ab".repeat(32)),
            ("too short", format!("sha256:{}", "ab".repeat(31))),
            ("not hex", format!("sha256:{}", "zz".repeat(32))),
        ];
        for (why, input) in cases {
            assert!(
                matches!(
                    input.parse::<ContentDigest>(),
                    Err(Error::InvalidContentDigest(_))
                ),
                "{why}: {input:?} should be rejected"
            );
        }
    }
}