Skip to main content

fileloft_core/
checksum.rs

1use std::pin::Pin;
2use std::str::FromStr;
3use std::task::{Context, Poll};
4
5use base64::{engine::general_purpose::STANDARD, Engine};
6use digest::DynDigest;
7use tokio::io::{AsyncRead, ReadBuf};
8
9use crate::error::TusError;
10use crate::proto::SUPPORTED_CHECKSUM_ALGORITHMS;
11
/// Checksum algorithm supported by the tus checksum extension.
///
/// The enum carries no data, so it derives `Copy` (pass it by value without
/// cloning) and `Hash` (usable as a map/set key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChecksumAlgorithm {
    /// SHA-1, named `sha1` on the wire.
    Sha1,
    /// SHA-256, named `sha256` on the wire.
    Sha256,
    /// MD5, named `md5` on the wire.
    Md5,
}
19
20impl ChecksumAlgorithm {
21    pub fn as_str(&self) -> &'static str {
22        match self {
23            Self::Sha1 => "sha1",
24            Self::Sha256 => "sha256",
25            Self::Md5 => "md5",
26        }
27    }
28
29    fn make_hasher(&self) -> Box<dyn DynDigest + Send> {
30        match self {
31            Self::Sha1 => Box::new(sha1::Sha1::default()),
32            Self::Sha256 => Box::new(sha2::Sha256::default()),
33            Self::Md5 => Box::new(md5::Md5::default()),
34        }
35    }
36}
37
38impl FromStr for ChecksumAlgorithm {
39    type Err = TusError;
40
41    fn from_str(s: &str) -> Result<Self, Self::Err> {
42        match s.to_lowercase().as_str() {
43            "sha1" => Ok(Self::Sha1),
44            "sha256" => Ok(Self::Sha256),
45            "md5" => Ok(Self::Md5),
46            other => Err(TusError::UnsupportedChecksumAlgorithm(other.to_string())),
47        }
48    }
49}
50
51/// Comma-separated list of algorithms to advertise in `Tus-Checksum-Algorithm`.
52pub fn algorithms_header() -> String {
53    SUPPORTED_CHECKSUM_ALGORITHMS.join(",")
54}
55
56/// Parse the `Upload-Checksum` header: `"<algorithm> <base64-hash>"`.
57pub fn parse_checksum_header(value: &str) -> Result<(ChecksumAlgorithm, Vec<u8>), TusError> {
58    let (alg_str, b64) = value.split_once(' ').ok_or_else(|| {
59        TusError::InvalidMetadata("malformed Upload-Checksum header (expected '<alg> <base64>')".into())
60    })?;
61    let algorithm: ChecksumAlgorithm = alg_str.parse()?;
62    let hash = STANDARD
63        .decode(b64.trim())
64        .map_err(|e| TusError::InvalidMetadata(format!("bad base64 in Upload-Checksum: {e}")))?;
65    Ok((algorithm, hash))
66}
67
68/// Wraps any `AsyncRead`, feeding bytes through a hasher as they pass through.
69/// Call `verify()` after all bytes have been read to check against the expected hash.
70pub struct ChecksumReader<R> {
71    inner: R,
72    hasher: Box<dyn DynDigest + Send>,
73    expected: Vec<u8>,
74}
75
76impl<R: AsyncRead + Unpin> ChecksumReader<R> {
77    pub fn new(inner: R, algorithm: ChecksumAlgorithm, expected: Vec<u8>) -> Self {
78        Self {
79            inner,
80            hasher: algorithm.make_hasher(),
81            expected,
82        }
83    }
84
85    /// Compare the accumulated hash against the expected value.
86    /// Call this *after* all bytes have been read through this reader.
87    pub fn verify(self) -> Result<(), TusError> {
88        let computed = self.hasher.finalize();
89        if computed.as_ref() == self.expected.as_slice() {
90            Ok(())
91        } else {
92            Err(TusError::ChecksumMismatch)
93        }
94    }
95}
96
97impl<R: AsyncRead + Unpin> AsyncRead for ChecksumReader<R> {
98    fn poll_read(
99        self: Pin<&mut Self>,
100        cx: &mut Context<'_>,
101        buf: &mut ReadBuf<'_>,
102    ) -> Poll<std::io::Result<()>> {
103        let me = self.get_mut();
104        let before = buf.filled().len();
105        let result = Pin::new(&mut me.inner).poll_read(cx, buf);
106        if let Poll::Ready(Ok(())) = &result {
107            let filled = &buf.filled()[before..];
108            if !filled.is_empty() {
109                me.hasher.update(filled);
110            }
111        }
112        result
113    }
114}
115
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use base64::{engine::general_purpose::STANDARD, Engine};
    use sha1::{Digest, Sha1};
    use tokio::io::AsyncReadExt;

    use super::*;

    /// Raw (not base64-encoded) SHA-1 digest of `data`.
    fn sha1_digest(data: &[u8]) -> Vec<u8> {
        let mut h = Sha1::new();
        // Fully qualified calls avoid any ambiguity with `DynDigest`, which
        // `super::*` may also bring into scope.
        Digest::update(&mut h, data);
        Digest::finalize(h).to_vec()
    }

    #[tokio::test]
    async fn checksum_reader_correct_hash() {
        let data = b"hello tus";
        let expected = sha1_digest(data);
        let cursor = Cursor::new(data.to_vec());
        let mut reader = ChecksumReader::new(cursor, ChecksumAlgorithm::Sha1, expected);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(&buf, data);
        // Verify must succeed
        reader.verify().unwrap();
    }

    #[tokio::test]
    async fn checksum_reader_wrong_hash() {
        let data = b"hello tus";
        let wrong = vec![0u8; 20]; // wrong SHA1 (all zeros)
        let cursor = Cursor::new(data.to_vec());
        let mut reader = ChecksumReader::new(cursor, ChecksumAlgorithm::Sha1, wrong);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert!(matches!(reader.verify(), Err(TusError::ChecksumMismatch)));
    }

    /// The hash must cover every chunk when the body arrives over many small
    /// reads, not just the first or the last one.
    #[tokio::test]
    async fn checksum_reader_hashes_across_small_reads() {
        let data = b"hello tus, streamed in tiny pieces";
        let expected = sha1_digest(data);
        let cursor = Cursor::new(data.to_vec());
        let mut reader = ChecksumReader::new(cursor, ChecksumAlgorithm::Sha1, expected);
        let mut out = Vec::new();
        let mut chunk = [0u8; 3];
        loop {
            let n = reader.read(&mut chunk).await.unwrap();
            if n == 0 {
                break;
            }
            out.extend_from_slice(&chunk[..n]);
        }
        assert_eq!(out, data);
        reader.verify().unwrap();
    }

    #[test]
    fn algorithm_round_trips_through_as_str() {
        for alg in [
            ChecksumAlgorithm::Sha1,
            ChecksumAlgorithm::Sha256,
            ChecksumAlgorithm::Md5,
        ] {
            let parsed: ChecksumAlgorithm = alg.as_str().parse().unwrap();
            assert_eq!(parsed, alg);
        }
    }

    #[test]
    fn algorithm_parse_ignores_ascii_case() {
        assert_eq!(
            "SHA256".parse::<ChecksumAlgorithm>().unwrap(),
            ChecksumAlgorithm::Sha256
        );
        assert_eq!(
            "Md5".parse::<ChecksumAlgorithm>().unwrap(),
            ChecksumAlgorithm::Md5
        );
    }

    #[test]
    fn parse_checksum_header_sha1() {
        let data = b"test";
        let hash = sha1_digest(data);
        let b64 = STANDARD.encode(&hash);
        let header = format!("sha1 {b64}");
        let (alg, decoded) = parse_checksum_header(&header).unwrap();
        assert_eq!(alg, ChecksumAlgorithm::Sha1);
        assert_eq!(decoded, hash);
    }

    #[test]
    fn parse_checksum_header_unknown_algorithm() {
        let err = parse_checksum_header("crc32 AAAA").unwrap_err();
        assert!(matches!(err, TusError::UnsupportedChecksumAlgorithm(_)));
    }

    #[test]
    fn parse_checksum_header_bad_base64() {
        let err = parse_checksum_header("sha1 not_valid!!").unwrap_err();
        assert!(matches!(err, TusError::InvalidMetadata(_)));
    }

    #[test]
    fn parse_checksum_header_missing_space() {
        let err = parse_checksum_header("sha1").unwrap_err();
        assert!(matches!(err, TusError::InvalidMetadata(_)));
    }
}