wasm_pkg_common/
digest.rs

1use std::str::FromStr;
2
3use bytes::Bytes;
4use futures_util::{future::ready, stream::once, Stream, StreamExt, TryStream, TryStreamExt};
5use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7
8use crate::Error;
9
/// A cryptographic digest (hash) of some content.
///
/// The textual form (produced by `Display`, accepted by `FromStr`/`TryFrom<&str>`)
/// is `<algorithm>:<hex>`, e.g. `sha256:` followed by 64 hex digits.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ContentDigest {
    /// A SHA-256 digest.
    Sha256 {
        /// The digest encoded as hex.
        ///
        /// Digests produced by parsing or by finishing a `Sha256` hasher are
        /// lowercase. Values built directly from this variant are stored as
        /// given. Equality and hashing compare this as a plain string, so build
        /// it in lowercase if it must match parsed or computed digests.
        hex: String,
    },
}
15
16impl ContentDigest {
17    pub fn validating_stream(
18        &self,
19        stream: impl TryStream<Ok = Bytes, Error = Error>,
20    ) -> impl Stream<Item = Result<Bytes, Error>> {
21        let want = self.clone();
22        stream.map_ok(Some).chain(once(async { Ok(None) })).scan(
23            Sha256::new(),
24            move |hasher, res| {
25                ready(match res {
26                    Ok(Some(bytes)) => {
27                        hasher.update(&bytes);
28                        Some(Ok(bytes))
29                    }
30                    Ok(None) => {
31                        let got: Self = std::mem::take(hasher).into();
32                        if got == want {
33                            None
34                        } else {
35                            Some(Err(Error::InvalidContent(format!(
36                                "expected digest {want}, got {got}"
37                            ))))
38                        }
39                    }
40                    Err(err) => Some(Err(err)),
41                })
42            },
43        )
44    }
45}
46
47impl std::fmt::Display for ContentDigest {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        match self {
50            ContentDigest::Sha256 { hex } => write!(f, "sha256:{hex}"),
51        }
52    }
53}
54
55impl From<Sha256> for ContentDigest {
56    fn from(hasher: Sha256) -> Self {
57        Self::Sha256 {
58            hex: format!("{:x}", hasher.finalize()),
59        }
60    }
61}
62
63impl<'a> TryFrom<&'a str> for ContentDigest {
64    type Error = Error;
65
66    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
67        let Some(hex) = value.strip_prefix("sha256:") else {
68            return Err(Error::InvalidContentDigest(
69                "must start with 'sha256:'".into(),
70            ));
71        };
72        let hex = hex.to_lowercase();
73        if hex.len() != 64 {
74            return Err(Error::InvalidContentDigest(format!(
75                "must be 64 hex digits; got {} chars",
76                hex.len()
77            )));
78        }
79        if let Some(invalid) = hex.chars().find(|c| !c.is_ascii_hexdigit()) {
80            return Err(Error::InvalidContentDigest(format!(
81                "must be hex; got {invalid:?}"
82            )));
83        }
84        Ok(Self::Sha256 { hex })
85    }
86}
87
88impl FromStr for ContentDigest {
89    type Err = Error;
90
91    fn from_str(s: &str) -> Result<Self, Self::Err> {
92        s.try_into()
93    }
94}
95
96impl Serialize for ContentDigest {
97    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
98        serializer.serialize_str(&self.to_string())
99    }
100}
101
102impl<'de> Deserialize<'de> for ContentDigest {
103    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
104    where
105        D: serde::Deserializer<'de>,
106    {
107        Self::from_str(&String::deserialize(deserializer)?).map_err(serde::de::Error::custom)
108    }
109}
110
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use futures_util::stream;

    use super::*;

    /// Content pushed through the validating stream in both tests. It is
    /// split into two-byte chunks so that hashing spans several items.
    const INPUT: &[u8] = b"input";

    #[tokio::test]
    async fn test_validating_stream() {
        let digest = ContentDigest::from(Sha256::new_with_prefix(INPUT));
        let chunks = stream::iter(INPUT.chunks(2)).map(|chunk| Ok(Bytes::from(chunk)));

        let collected = digest
            .validating_stream(chunks)
            .try_collect::<BytesMut>()
            .await
            .unwrap();

        // A matching digest passes every byte through untouched.
        assert_eq!(collected, INPUT);
    }

    #[tokio::test]
    async fn test_invalidating_stream() {
        let digest = ContentDigest::Sha256 {
            hex: "doesn't match anything!".to_string(),
        };
        let chunks = stream::iter(INPUT.chunks(2)).map(|chunk| Ok(Bytes::from(chunk)));

        let outcome = digest
            .validating_stream(chunks)
            .try_collect::<BytesMut>()
            .await;

        // The mismatch surfaces as an error once the input is exhausted.
        assert!(matches!(outcome, Err(Error::InvalidContent(_))));
    }
}