drawbridge_type/digest/
verifier.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use super::{ContentDigest, Reader};
4
5use std::io::{self, Error, ErrorKind};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use futures::AsyncRead;
10
11/// A verifying reader
12///
13/// This type is exactly the same as [`Reader`](super::Reader) except that it
14/// additionally verifies the expected hashes. When the end-of-file condition
15/// is reached, if the actual hashes do not match the expected hashes, an error
16/// is produced.
17#[allow(missing_debug_implementations)] // Reader does not implement Debug
18pub struct Verifier<T, H>
19where
20    H: AsRef<[u8]> + From<Vec<u8>>,
21{
22    reader: Reader<T>,
23    hashes: ContentDigest<H>,
24}
25
26#[allow(unsafe_code)]
27unsafe impl<T, H> Sync for Verifier<T, H>
28where
29    T: Sync,
30    H: Sync + AsRef<[u8]> + From<Vec<u8>>,
31{
32}
33
34#[allow(unsafe_code)]
35unsafe impl<T, H> Send for Verifier<T, H>
36where
37    T: Send,
38    H: Send + AsRef<[u8]> + From<Vec<u8>>,
39{
40}
41
42impl<T, H> Verifier<T, H>
43where
44    H: AsRef<[u8]> + From<Vec<u8>>,
45{
46    pub(crate) fn new(reader: Reader<T>, hashes: ContentDigest<H>) -> Self {
47        Self { reader, hashes }
48    }
49
50    pub fn digests(&self) -> ContentDigest<Box<[u8]>> {
51        self.reader.digests()
52    }
53}
54
55impl<T: Unpin, H> Unpin for Verifier<T, H> where H: AsRef<[u8]> + From<Vec<u8>> {}
56
57impl<T: AsyncRead + Unpin, H> AsyncRead for Verifier<T, H>
58where
59    H: AsRef<[u8]> + From<Vec<u8>>,
60{
61    fn poll_read(
62        mut self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64        buf: &mut [u8],
65    ) -> Poll<io::Result<usize>> {
66        Pin::new(&mut self.reader)
67            .poll_read(cx, buf)
68            .map(|r| match r? {
69                0 if self.reader.digests() != self.hashes => {
70                    Err(Error::new(ErrorKind::InvalidData, "hash mismatch"))
71                }
72                n => Ok(n),
73            })
74    }
75}
76
77impl<T: io::Read, H> io::Read for Verifier<T, H>
78where
79    H: AsRef<[u8]> + From<Vec<u8>>,
80{
81    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
82        match self.reader.read(buf)? {
83            0 if self.reader.digests() != self.hashes => {
84                Err(Error::new(ErrorKind::InvalidData, "hash mismatch"))
85            }
86            n => Ok(n),
87        }
88    }
89}
90
91#[cfg(test)]
92mod tests {
93    use futures::io::{copy, sink};
94
95    use super::*;
96
97    #[async_std::test]
98    async fn read_success() {
99        let rdr = &b"foo"[..];
100        let content_digest = "sha-224=:CAj2TmDViXn8tnbJbsk4Jw3qQkRa7vzTpOb42w==:,sha-256=:LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564=:,sha-384=:mMEf/f3VQGdrGhN8saIrKnA1DJpEFx1rEYDGvly7LuP3nVMsih3Z7y6OCOdSo7q7:,sha-512=:9/u6bgY2+JDlb7vzKD5STG+jIErimDgtYkdB0NxmODJuKCxBvl5CVNiCB3LFUYosWowMf37aGVlKfrU5RT4e1w==:"
101                .parse::<ContentDigest>()
102                .unwrap();
103
104        assert_eq!(
105            copy(&mut content_digest.clone().verifier(rdr), &mut sink())
106                .await
107                .unwrap(),
108            "foo".len() as u64,
109        );
110        assert_eq!(
111            std::io::copy(&mut content_digest.verifier(rdr), &mut std::io::sink()).unwrap(),
112            "foo".len() as u64,
113        );
114    }
115
116    #[async_std::test]
117    async fn read_failure() {
118        let rdr = &b"bar"[..];
119        let content_digest = "sha-224=:CAj2TmDViXn8tnbJbsk4Jw3qQkRa7vzTpOb42w==:,sha-256=:LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564=:,sha-384=:mMEf/f3VQGdrGhN8saIrKnA1DJpEFx1rEYDGvly7LuP3nVMsih3Z7y6OCOdSo7q7:,sha-512=:9/u6bgY2+JDlb7vzKD5STG+jIErimDgtYkdB0NxmODJuKCxBvl5CVNiCB3LFUYosWowMf37aGVlKfrU5RT4e1w==:"
120                .parse::<ContentDigest>()
121                .unwrap();
122
123        assert_eq!(
124            copy(&mut content_digest.clone().verifier(rdr), &mut sink())
125                .await
126                .unwrap_err()
127                .kind(),
128            ErrorKind::InvalidData,
129        );
130        assert_eq!(
131            std::io::copy(&mut content_digest.verifier(rdr), &mut std::io::sink())
132                .unwrap_err()
133                .kind(),
134            ErrorKind::InvalidData,
135        );
136    }
137}