hashing_reader/
lib.rs

1// Copyright 2023 Cisco Systems, Inc.
2// Use of this source code is governed by an MIT-style
3// license that can be found in the LICENSE file or at
4// https://opensource.org/licenses/MIT.
5
6#![doc = include_str!("../README.md")]
7
8use digest::Digest;
9use pin_project::pin_project;
10use std::io::{self, ErrorKind, Read};
11use std::sync::mpsc::{channel, Receiver, SendError, Sender};
12#[cfg(feature = "tokio")]
13use {
14    std::pin::Pin,
15    std::task::{Context, Poll},
16    tokio::io::AsyncRead,
17};
18
19#[cfg(test)]
20mod test;
21
22#[pin_project]
23pub struct HashingReader<R, H: Digest> {
24    #[pin]
25    reader: R,
26    hasher: H,
27    chan: Sender<Option<Vec<u8>>>,
28}
29
30impl<R, H> HashingReader<R, H>
31where
32    H: Digest,
33{
34    pub fn new(reader: R) -> (Self, Receiver<Option<Vec<u8>>>) {
35        let (tx, rx) = channel::<Option<Vec<u8>>>();
36        let hr: HashingReader<R, H> = HashingReader {
37            reader,
38            hasher: H::new(),
39            chan: tx,
40        };
41        (hr, rx)
42    }
43}
44
45impl<R, H> Read for HashingReader<R, H>
46where
47    R: Read,
48    H: Digest,
49{
50    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
51        let len = match self.reader.read(buf) {
52            Ok(len) => len,
53            Err(e) => {
54                self.chan.send(None).map_err(channel_error)?;
55                return Err(e);
56            }
57        };
58        if len == 0 {
59            let hasher = std::mem::replace(&mut self.hasher, H::new());
60            self.chan
61                .send(Some(hasher.finalize().as_slice().to_vec()))
62                .map_err(channel_error)?;
63        } else {
64            self.hasher.update(&buf[..len]);
65        }
66        Ok(len)
67    }
68}
69
// Async pass-through: every byte placed into the caller's buffer is also hashed.
// Only `R: AsyncRead` and `H: Digest` are required; the body needs neither
// `Unpin`/`Send` on the reader (it is polled through the pin projection) nor
// `digest::Reset` (a fresh hasher is created with `H::new()`).
#[cfg(feature = "tokio")]
impl<R, H> AsyncRead for HashingReader<R, H>
where
    R: AsyncRead,
    H: Digest,
{
    /// Polls the wrapped reader and hashes whatever it appended to `buf`.
    ///
    /// * If bytes were appended, they are fed to the hasher.
    /// * If nothing was appended although `buf` had remaining capacity, this
    ///   is end of input: the digest is finalized and sent as `Some(bytes)`,
    ///   and the hasher is replaced by a fresh one.
    /// * If the wrapped reader fails, `None` is sent (best effort) and the
    ///   original error is returned unchanged.
    ///
    /// Returns an `ErrorKind::Other` error if end of input was reached but the
    /// digest could not be sent because the receiver is gone.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let mut this = self.project();
        let filled_before = buf.filled().len();
        // Must be sampled before polling: a buffer with no room left also
        // produces "no new bytes", which is not end of input.
        let had_capacity = buf.remaining() > 0;

        match this.reader.as_mut().poll_read(cx, buf) {
            Poll::Ready(Ok(())) => {
                let filled_after = buf.filled().len();
                if filled_after > filled_before {
                    let newly_filled = &buf.filled()[filled_before..filled_after];
                    this.hasher.update(newly_filled);
                } else if had_capacity {
                    let hasher = std::mem::replace(this.hasher, H::new());
                    this.chan
                        .send(Some(hasher.finalize().as_slice().to_vec()))
                        .map_err(channel_error)?;
                }
                Poll::Ready(Ok(()))
            }
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => {
                // Best effort: a dropped receiver must not hide the real I/O
                // error from the caller.
                let _ = this.chan.send(None);
                Poll::Ready(Err(e))
            }
        }
    }
}
105
/// Converts a failed channel send into an `io::Error` of kind
/// `ErrorKind::Other`, embedding the `Debug` rendering of the send error in
/// the message.
fn channel_error<T>(e: SendError<T>) -> io::Error {
    let message = format!("EOF reached but was unable to send hash: {:?}", e);
    io::Error::new(ErrorKind::Other, message)
}