drawbridge_hash/
reader.rs

// SPDX-FileCopyrightText: 2022 Profian Inc. <opensource@profian.com>
// SPDX-License-Identifier: AGPL-3.0-only
3
4use super::Hash;
5
6use std::io::{self, Error, ErrorKind};
7use std::pin::Pin;
8use std::task::Context;
9
10use futures::AsyncRead;
11use sha2::digest::Digest;
12use sha2::{Sha224, Sha256, Sha384, Sha512};
13
14pub(super) enum Inner {
15    Sha224(Sha224),
16    Sha256(Sha256),
17    Sha384(Sha384),
18    Sha512(Sha512),
19}
20
21pub struct Reader<T> {
22    pub(super) reader: T,
23    pub(super) inner: Inner,
24    pub(super) hash: Hash,
25}
26
27impl<T: AsyncRead + Unpin> AsyncRead for Reader<T> {
28    fn poll_read(
29        mut self: Pin<&mut Self>,
30        cx: &mut Context<'_>,
31        buf: &mut [u8],
32    ) -> std::task::Poll<io::Result<usize>> {
33        Pin::new(&mut self.reader).poll_read(cx, buf).map(|r| {
34            let n = r?;
35
36            match &mut self.inner {
37                Inner::Sha224(h) => h.update(&buf[..n]),
38                Inner::Sha256(h) => h.update(&buf[..n]),
39                Inner::Sha384(h) => h.update(&buf[..n]),
40                Inner::Sha512(h) => h.update(&buf[..n]),
41            };
42
43            // On EOF, validate the hash.
44            if !buf.is_empty() && n == 0 && self.hash() != self.hash {
45                Err(Error::new(ErrorKind::InvalidData, "hash mismatch"))
46            } else {
47                Ok(n)
48            }
49        })
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    use futures::io::{copy, sink, AsyncReadExt};

    /// SHA-256 of `b"foo"`, in the string form that `Hash` parses.
    const FOO_SHA256: &str = "sha256:LCa0a2j_xo_5m0U8HTBBNBNCLXBkg7-g-YpeiGJm564";

    #[async_std::test]
    async fn read_success() {
        let hash: Hash = FOO_SHA256.parse().unwrap();
        let mut read = hash.reader(&b"foo"[..]);
        copy(&mut read, &mut sink()).await.unwrap();
    }

    #[async_std::test]
    async fn read_failure() {
        let hash: Hash = FOO_SHA256.parse().unwrap();
        let mut read = hash.reader(&b"bar"[..]);
        match copy(&mut read, &mut sink()).await {
            Err(e) => assert_eq!(e.kind(), ErrorKind::InvalidData),
            Ok(..) => panic!("unexpected success"),
        }
    }

    /// A mismatch must surface only at EOF: reads that return data succeed,
    /// and a zero-length buffer (which also yields `Ok(0)`) is not mistaken
    /// for EOF.
    #[async_std::test]
    async fn mismatch_reported_only_at_eof() {
        let hash: Hash = FOO_SHA256.parse().unwrap();
        let mut read = hash.reader(&b"bar"[..]);
        let mut buf = [0u8; 2];

        // Data-bearing reads never validate, even though the content is wrong.
        assert_eq!(read.read(&mut buf).await.unwrap(), 2);
        assert_eq!(read.read(&mut buf).await.unwrap(), 1);

        // Zero-length buffer: `Ok(0)` but not EOF, so validation is deferred.
        assert_eq!(read.read(&mut [0u8; 0]).await.unwrap(), 0);

        // The first genuine EOF reports the mismatch.
        let err = read.read(&mut buf).await.unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidData);
    }

    #[async_std::test]
    async fn meta_hash() {
        // printf "sha384:%s" $(printf '%s' '{"contentLength":42,"contentType":"text/plain","eTag":"sha384:mqVuAfXRKap7bdgcCY5uykM6-R9GqQ8K_uxy9rx7HNQlGYl1kPzQho1wx4JwY8wC"}' | openssl dgst -sha384 -binary | openssl base64 -A | tr '/' '_' | tr '+' '-')
        const HASH: &str =
            "sha384:hF8t6NZNTsnhhFcVjYeIc1kkavoZ3HIaWI_a7Z-l1odHq32xX3YaeFPyo4Jjf6Be";
        let hash: Hash = HASH.parse().unwrap();
        let meta = r#"{"contentLength":42,"contentType":"text/plain","eTag":"sha384:mqVuAfXRKap7bdgcCY5uykM6-R9GqQ8K_uxy9rx7HNQlGYl1kPzQho1wx4JwY8wC"}"#;
        let mut read = hash.reader(meta.as_bytes());
        copy(&mut read, &mut sink()).await.unwrap();
    }
}