wasm_pkg_common/
digest.rs

1#[cfg(feature = "tokio")]
2use std::path::Path;
3use std::str::FromStr;
4
5use bytes::Bytes;
6use futures_util::{future::ready, stream::once, Stream, StreamExt, TryStream, TryStreamExt};
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9
10use crate::Error;
11
/// A cryptographic digest (hash) identifying some content.
///
/// Values are produced by hashing (`From<Sha256>`) or by parsing text
/// (`FromStr` / `TryFrom<&str>`), and are rendered by `Display` as
/// `sha256:<hex>`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ContentDigest {
    /// A SHA-256 digest.
    Sha256 {
        /// The digest encoded as hex digits.
        ///
        /// This module's hashing and parsing conversions always produce
        /// lowercase hex. Direct construction through this public field is
        /// not validated.
        hex: String,
    },
}
17
18impl ContentDigest {
19    #[cfg(feature = "tokio")]
20    pub async fn sha256_from_file(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
21        use tokio::io::AsyncReadExt;
22        let mut file = tokio::fs::File::open(path).await?;
23        let mut hasher = Sha256::new();
24        let mut buf = [0; 4096];
25        loop {
26            let n = file.read(&mut buf).await?;
27            if n == 0 {
28                break;
29            }
30            hasher.update(&buf[..n]);
31        }
32        Ok(hasher.into())
33    }
34
35    pub fn validating_stream(
36        &self,
37        stream: impl TryStream<Ok = Bytes, Error = Error>,
38    ) -> impl Stream<Item = Result<Bytes, Error>> {
39        let want = self.clone();
40        stream.map_ok(Some).chain(once(async { Ok(None) })).scan(
41            Sha256::new(),
42            move |hasher, res| {
43                ready(match res {
44                    Ok(Some(bytes)) => {
45                        hasher.update(&bytes);
46                        Some(Ok(bytes))
47                    }
48                    Ok(None) => {
49                        let got: Self = std::mem::take(hasher).into();
50                        if got == want {
51                            None
52                        } else {
53                            Some(Err(Error::InvalidContent(format!(
54                                "expected digest {want}, got {got}"
55                            ))))
56                        }
57                    }
58                    Err(err) => Some(Err(err)),
59                })
60            },
61        )
62    }
63}
64
65impl std::fmt::Display for ContentDigest {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            ContentDigest::Sha256 { hex } => write!(f, "sha256:{hex}"),
69        }
70    }
71}
72
73impl From<Sha256> for ContentDigest {
74    fn from(hasher: Sha256) -> Self {
75        Self::Sha256 {
76            hex: format!("{:x}", hasher.finalize()),
77        }
78    }
79}
80
81impl<'a> TryFrom<&'a str> for ContentDigest {
82    type Error = Error;
83
84    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
85        let Some(hex) = value.strip_prefix("sha256:") else {
86            return Err(Error::InvalidContentDigest(
87                "must start with 'sha256:'".into(),
88            ));
89        };
90        let hex = hex.to_lowercase();
91        if hex.len() != 64 {
92            return Err(Error::InvalidContentDigest(format!(
93                "must be 64 hex digits; got {} chars",
94                hex.len()
95            )));
96        }
97        if let Some(invalid) = hex.chars().find(|c| !c.is_ascii_hexdigit()) {
98            return Err(Error::InvalidContentDigest(format!(
99                "must be hex; got {invalid:?}"
100            )));
101        }
102        Ok(Self::Sha256 { hex })
103    }
104}
105
106impl FromStr for ContentDigest {
107    type Err = Error;
108
109    fn from_str(s: &str) -> Result<Self, Self::Err> {
110        s.try_into()
111    }
112}
113
114impl Serialize for ContentDigest {
115    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
116        serializer.serialize_str(&self.to_string())
117    }
118}
119
120impl<'de> Deserialize<'de> for ContentDigest {
121    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
122    where
123        D: serde::Deserializer<'de>,
124    {
125        Self::from_str(&String::deserialize(deserializer)?).map_err(serde::de::Error::custom)
126    }
127}
128
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use futures_util::stream;

    use super::*;

    /// SHA-256 of the empty input (a standard test vector).
    const EMPTY_SHA256: &str =
        "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

    #[test]
    fn test_from_hasher_matches_known_vector() {
        let digest = ContentDigest::from(Sha256::new());
        assert_eq!(digest.to_string(), format!("sha256:{EMPTY_SHA256}"));
    }

    #[test]
    fn test_parse_round_trip() {
        let text = format!("sha256:{EMPTY_SHA256}");
        let parsed: ContentDigest = text.parse().unwrap();
        assert_eq!(parsed.to_string(), text);
        assert_eq!(parsed, ContentDigest::from(Sha256::new()));

        // Uppercase hex is accepted and normalized to lowercase.
        let upper: ContentDigest = format!("sha256:{}", EMPTY_SHA256.to_uppercase())
            .parse()
            .unwrap();
        assert_eq!(upper, parsed);
    }

    #[test]
    fn test_parse_rejects_invalid() {
        for bad in [
            // Missing the `sha256:` prefix.
            EMPTY_SHA256.to_string(),
            // Too short.
            "sha256:abc".to_string(),
            // Right length, but not hex.
            format!("sha256:{}", "g".repeat(64)),
        ] {
            assert!(
                matches!(
                    bad.parse::<ContentDigest>(),
                    Err(Error::InvalidContentDigest(_))
                ),
                "accepted {bad:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_validating_stream() {
        let input = b"input";
        let digest = ContentDigest::from(Sha256::new_with_prefix(input));
        let stream = stream::iter(input.chunks(2));
        let validating = digest.validating_stream(stream.map(|bytes| Ok(bytes.into())));
        assert_eq!(
            validating.try_collect::<BytesMut>().await.unwrap(),
            &input[..]
        );
    }

    #[tokio::test]
    async fn test_invalidating_stream() {
        let input = b"input";
        let digest = ContentDigest::Sha256 {
            hex: "doesn't match anything!".to_string(),
        };
        let stream = stream::iter(input.chunks(2));
        let validating = digest.validating_stream(stream.map(|bytes| Ok(bytes.into())));
        assert!(matches!(
            validating.try_collect::<BytesMut>().await,
            Err(Error::InvalidContent(_)),
        ));
    }
}