drawbridge_type/digest/
reader.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use super::{Algorithm, ContentDigest};
4
5use std::io;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use futures::AsyncRead;
10use sha2::digest::DynDigest;
11
12/// A hashing reader
13///
14/// This type wraps another reader and hashes the bytes as they are read.
15#[allow(missing_debug_implementations)] // DynDigest does not implement Debug
16pub struct Reader<T> {
17    reader: T,
18    digests: Vec<(Algorithm, Box<dyn DynDigest>)>,
19}
20
21impl<T> Reader<T> {
22    pub(crate) fn new(reader: T, digests: impl IntoIterator<Item = Algorithm>) -> Self {
23        let digests = digests.into_iter().map(|a| (a, a.hasher())).collect();
24        Reader { reader, digests }
25    }
26
27    fn update(&mut self, buf: &[u8]) {
28        for digest in &mut self.digests {
29            digest.1.update(buf);
30        }
31    }
32
33    /// Calculates the digests for all the bytes written so far.
34    pub fn digests(&self) -> ContentDigest<Box<[u8]>> {
35        let mut set = ContentDigest::default();
36
37        for digest in &self.digests {
38            let _ = set.insert(digest.0, digest.1.clone().finalize().into());
39        }
40
41        set
42    }
43}
44
45impl<T: AsyncRead + Unpin> AsyncRead for Reader<T> {
46    fn poll_read(
47        mut self: Pin<&mut Self>,
48        cx: &mut Context<'_>,
49        buf: &mut [u8],
50    ) -> Poll<io::Result<usize>> {
51        Pin::new(&mut self.reader).poll_read(cx, buf).map_ok(|n| {
52            self.update(&buf[..n]);
53            n
54        })
55    }
56}
57
58impl<T: io::Read> io::Read for Reader<T> {
59    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
60        let n = self.reader.read(buf)?;
61        self.update(&buf[..n]);
62        Ok(n)
63    }
64}
65
#[cfg(test)]
mod tests {
    use futures::io::{copy, sink};

    use super::*;

    /// SHA-256 of `b"foo"`, in structured-field byte-sequence syntax.
    const HASH: &str = "sha-256=:LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564=:";

    #[async_std::test]
    async fn success() {
        let hash: ContentDigest = HASH.parse().unwrap();

        let mut reader = hash.reader(&b"foo"[..]);
        assert_eq!(copy(&mut reader, &mut sink()).await.unwrap(), 3);
        assert_eq!(reader.digests(), hash);
    }

    #[async_std::test]
    async fn failure() {
        let hash: ContentDigest = HASH.parse().unwrap();

        let mut reader = hash.reader(&b"bar"[..]);
        assert_eq!(copy(&mut reader, &mut sink()).await.unwrap(), 3);
        assert_ne!(reader.digests(), hash);
    }

    // The blocking `std::io::Read` path was previously untested.
    #[test]
    fn sync_success() {
        let hash: ContentDigest = HASH.parse().unwrap();

        let mut reader = hash.reader(&b"foo"[..]);
        assert_eq!(std::io::copy(&mut reader, &mut std::io::sink()).unwrap(), 3);
        assert_eq!(reader.digests(), hash);
    }

    #[test]
    fn sync_failure() {
        let hash: ContentDigest = HASH.parse().unwrap();

        let mut reader = hash.reader(&b"bar"[..]);
        assert_eq!(std::io::copy(&mut reader, &mut std::io::sink()).unwrap(), 3);
        assert_ne!(reader.digests(), hash);
    }
}