drawbridge_type/digest/writer.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use super::{Algorithm, ContentDigest};
4
5use std::io;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use futures::AsyncWrite;
10use sha2::digest::DynDigest;
11
12/// A hashing writer
13///
14/// This type wraps another writer and hashes the bytes as they are written.
15#[allow(missing_debug_implementations)] // DynDigest does not implement Debug
16pub struct Writer<T> {
17    writer: T,
18    digests: Vec<(Algorithm, Box<dyn DynDigest>)>,
19}
20
21#[allow(unsafe_code)]
22unsafe impl<T> Sync for Writer<T> where T: Sync {}
23
24#[allow(unsafe_code)]
25unsafe impl<T> Send for Writer<T> where T: Send {}
26
27impl<T> Writer<T> {
28    pub(crate) fn new(writer: T, digests: impl IntoIterator<Item = Algorithm>) -> Self {
29        let digests = digests.into_iter().map(|a| (a, a.hasher())).collect();
30        Writer { writer, digests }
31    }
32
33    fn update(&mut self, buf: &[u8]) {
34        for digest in &mut self.digests {
35            digest.1.update(buf);
36        }
37    }
38
39    /// Calculates the digests for all the bytes written so far.
40    pub fn digests(&self) -> ContentDigest<Box<[u8]>> {
41        let mut set = ContentDigest::default();
42
43        for digest in &self.digests {
44            _ = set.insert(digest.0, digest.1.clone().finalize().into());
45        }
46
47        set
48    }
49}
50
51impl<T: AsyncWrite + Unpin> AsyncWrite for Writer<T> {
52    fn poll_write(
53        mut self: Pin<&mut Self>,
54        cx: &mut Context<'_>,
55        buf: &[u8],
56    ) -> Poll<io::Result<usize>> {
57        Pin::new(&mut self.writer).poll_write(cx, buf).map_ok(|n| {
58            self.update(&buf[..n]);
59            n
60        })
61    }
62
63    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
64        Pin::new(&mut self.writer).poll_flush(cx)
65    }
66
67    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
68        Pin::new(&mut self.writer).poll_close(cx)
69    }
70}
71
72impl<T: io::Write> io::Write for Writer<T> {
73    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
74        let n = self.writer.write(buf)?;
75        self.update(&buf[..n]);
76        Ok(n)
77    }
78
79    fn flush(&mut self) -> io::Result<()> {
80        self.writer.flush()
81    }
82}
83
#[cfg(test)]
mod tests {
    use futures::io::{copy, sink};

    use super::*;

    /// The SHA-256 digest of `b"foo"`, in the textual form that `ContentDigest` parses.
    const FOO_SHA256: &str = "sha-256=:LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564=:";

    #[async_std::test]
    async fn success() {
        let expected = FOO_SHA256.parse::<ContentDigest>().unwrap();

        let mut hashing = expected.clone().writer(sink());
        let copied = copy(&mut &b"foo"[..], &mut hashing).await.unwrap();
        assert_eq!(copied, 3);
        assert_eq!(hashing.digests(), expected);
    }

    #[async_std::test]
    async fn failure() {
        let expected = FOO_SHA256.parse::<ContentDigest>().unwrap();

        // Different input of the same length must yield a different digest.
        let mut hashing = expected.clone().writer(sink());
        let copied = copy(&mut &b"bar"[..], &mut hashing).await.unwrap();
        assert_eq!(copied, 3);
        assert_ne!(hashing.digests(), expected);
    }
}
109}