drawbridge_type/digest/
writer.rs1use super::{Algorithm, ContentDigest};
4
5use std::io;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use futures::AsyncWrite;
10use sha2::digest::DynDigest;
11
12#[allow(missing_debug_implementations)] pub struct Writer<T> {
17 writer: T,
18 digests: Vec<(Algorithm, Box<dyn DynDigest>)>,
19}
20
21#[allow(unsafe_code)]
22unsafe impl<T> Sync for Writer<T> where T: Sync {}
23
24#[allow(unsafe_code)]
25unsafe impl<T> Send for Writer<T> where T: Send {}
26
27impl<T> Writer<T> {
28 pub(crate) fn new(writer: T, digests: impl IntoIterator<Item = Algorithm>) -> Self {
29 let digests = digests.into_iter().map(|a| (a, a.hasher())).collect();
30 Writer { writer, digests }
31 }
32
33 fn update(&mut self, buf: &[u8]) {
34 for digest in &mut self.digests {
35 digest.1.update(buf);
36 }
37 }
38
39 pub fn digests(&self) -> ContentDigest<Box<[u8]>> {
41 let mut set = ContentDigest::default();
42
43 for digest in &self.digests {
44 _ = set.insert(digest.0, digest.1.clone().finalize().into());
45 }
46
47 set
48 }
49}
50
51impl<T: AsyncWrite + Unpin> AsyncWrite for Writer<T> {
52 fn poll_write(
53 mut self: Pin<&mut Self>,
54 cx: &mut Context<'_>,
55 buf: &[u8],
56 ) -> Poll<io::Result<usize>> {
57 Pin::new(&mut self.writer).poll_write(cx, buf).map_ok(|n| {
58 self.update(&buf[..n]);
59 n
60 })
61 }
62
63 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
64 Pin::new(&mut self.writer).poll_flush(cx)
65 }
66
67 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
68 Pin::new(&mut self.writer).poll_close(cx)
69 }
70}
71
72impl<T: io::Write> io::Write for Writer<T> {
73 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
74 let n = self.writer.write(buf)?;
75 self.update(&buf[..n]);
76 Ok(n)
77 }
78
79 fn flush(&mut self) -> io::Result<()> {
80 self.writer.flush()
81 }
82}
83
#[cfg(test)]
mod tests {
    use futures::io::{copy, sink};

    use super::*;

    /// The digest both tests compare against.
    const HASH: &str = "sha-256=:LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564=:";

    /// Writing `foo` must reproduce the expected digest set.
    #[async_std::test]
    async fn success() {
        let expected = HASH.parse::<ContentDigest>().unwrap();
        let mut input: &[u8] = b"foo";

        let mut writer = expected.clone().writer(sink());
        let copied = copy(&mut input, &mut writer).await.unwrap();
        assert_eq!(copied, 3);
        assert_eq!(writer.digests(), expected);
    }

    /// Writing `bar` must produce a digest set different from the expected one.
    #[async_std::test]
    async fn failure() {
        let expected = HASH.parse::<ContentDigest>().unwrap();
        let mut input: &[u8] = b"bar";

        let mut writer = expected.clone().writer(sink());
        let copied = copy(&mut input, &mut writer).await.unwrap();
        assert_eq!(copied, 3);
        assert_ne!(writer.digests(), expected);
    }
}