Skip to main content

fileloft_core/
checksum.rs

1use std::pin::Pin;
2use std::str::FromStr;
3use std::task::{Context, Poll};
4
5use base64::{engine::general_purpose::STANDARD, Engine};
6use digest::DynDigest;
7use tokio::io::{AsyncRead, ReadBuf};
8
9use crate::error::TusError;
10use crate::proto::SUPPORTED_CHECKSUM_ALGORITHMS;
11
/// Checksum algorithm supported by the tus checksum extension.
///
/// The canonical lowercase names are returned by [`ChecksumAlgorithm::as_str`]
/// and accepted (case-insensitively) by the [`FromStr`] impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChecksumAlgorithm {
    /// SHA-1 (20-byte digest).
    Sha1,
    /// SHA-256 (32-byte digest).
    Sha256,
    /// MD5 (16-byte digest).
    Md5,
}
19
20impl ChecksumAlgorithm {
21    pub fn as_str(&self) -> &'static str {
22        match self {
23            Self::Sha1 => "sha1",
24            Self::Sha256 => "sha256",
25            Self::Md5 => "md5",
26        }
27    }
28
29    fn make_hasher(&self) -> Box<dyn DynDigest + Send> {
30        match self {
31            Self::Sha1 => Box::new(sha1::Sha1::default()),
32            Self::Sha256 => Box::new(sha2::Sha256::default()),
33            Self::Md5 => Box::new(md5::Md5::default()),
34        }
35    }
36}
37
38impl FromStr for ChecksumAlgorithm {
39    type Err = TusError;
40
41    fn from_str(s: &str) -> Result<Self, Self::Err> {
42        match s.to_lowercase().as_str() {
43            "sha1" => Ok(Self::Sha1),
44            "sha256" => Ok(Self::Sha256),
45            "md5" => Ok(Self::Md5),
46            other => Err(TusError::UnsupportedChecksumAlgorithm(other.to_string())),
47        }
48    }
49}
50
51/// Comma-separated list of algorithms to advertise in `Tus-Checksum-Algorithm`.
52pub fn algorithms_header() -> String {
53    SUPPORTED_CHECKSUM_ALGORITHMS.join(",")
54}
55
56/// Parse the `Upload-Checksum` header: `"<algorithm> <base64-hash>"`.
57pub fn parse_checksum_header(value: &str) -> Result<(ChecksumAlgorithm, Vec<u8>), TusError> {
58    let (alg_str, b64) = value.split_once(' ').ok_or_else(|| {
59        TusError::InvalidMetadata(
60            "malformed Upload-Checksum header (expected '<alg> <base64>')".into(),
61        )
62    })?;
63    let algorithm: ChecksumAlgorithm = alg_str.parse()?;
64    let hash = STANDARD
65        .decode(b64.trim())
66        .map_err(|e| TusError::InvalidMetadata(format!("bad base64 in Upload-Checksum: {e}")))?;
67    Ok((algorithm, hash))
68}
69
70/// Wraps any `AsyncRead`, feeding bytes through a hasher as they pass through.
71/// Call `verify()` after all bytes have been read to check against the expected hash.
72pub struct ChecksumReader<R> {
73    inner: R,
74    hasher: Box<dyn DynDigest + Send>,
75    expected: Vec<u8>,
76}
77
78impl<R: AsyncRead + Unpin> ChecksumReader<R> {
79    pub fn new(inner: R, algorithm: ChecksumAlgorithm, expected: Vec<u8>) -> Self {
80        Self {
81            inner,
82            hasher: algorithm.make_hasher(),
83            expected,
84        }
85    }
86
87    /// Compare the accumulated hash against the expected value.
88    /// Call this *after* all bytes have been read through this reader.
89    pub fn verify(self) -> Result<(), TusError> {
90        let computed = self.hasher.finalize();
91        if computed.as_ref() == self.expected.as_slice() {
92            Ok(())
93        } else {
94            Err(TusError::ChecksumMismatch)
95        }
96    }
97}
98
99impl<R: AsyncRead + Unpin> AsyncRead for ChecksumReader<R> {
100    fn poll_read(
101        self: Pin<&mut Self>,
102        cx: &mut Context<'_>,
103        buf: &mut ReadBuf<'_>,
104    ) -> Poll<std::io::Result<()>> {
105        let me = self.get_mut();
106        let before = buf.filled().len();
107        let result = Pin::new(&mut me.inner).poll_read(cx, buf);
108        if let Poll::Ready(Ok(())) = &result {
109            let filled = &buf.filled()[before..];
110            if !filled.is_empty() {
111                me.hasher.update(filled);
112            }
113        }
114        result
115    }
116}
117
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use base64::{engine::general_purpose::STANDARD, Engine};
    use sha1::{Digest, Sha1};
    use tokio::io::AsyncReadExt;

    use super::*;

    /// Raw (not base64-encoded) SHA-1 digest of `data`.
    fn sha1_digest(data: &[u8]) -> Vec<u8> {
        let mut h = Sha1::new();
        // Fully qualified calls: `DynDigest` is also in scope via `super::*`,
        // so the method-call form of `update`/`finalize` would be ambiguous.
        Digest::update(&mut h, data);
        Digest::finalize(h).to_vec()
    }

    #[tokio::test]
    async fn checksum_reader_correct_hash() {
        let data = b"hello tus";
        let expected = sha1_digest(data);
        let cursor = Cursor::new(data.to_vec());
        let mut reader = ChecksumReader::new(cursor, ChecksumAlgorithm::Sha1, expected);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(&buf, data);
        // Verify must succeed
        reader.verify().unwrap();
    }

    #[tokio::test]
    async fn checksum_reader_wrong_hash() {
        let data = b"hello tus";
        let wrong = vec![0u8; 20]; // wrong SHA1 (all zeros)
        let cursor = Cursor::new(data.to_vec());
        let mut reader = ChecksumReader::new(cursor, ChecksumAlgorithm::Sha1, wrong);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert!(matches!(reader.verify(), Err(TusError::ChecksumMismatch)));
    }

    #[tokio::test]
    async fn checksum_reader_empty_payload() {
        let expected = sha1_digest(b"");
        let cursor = Cursor::new(Vec::<u8>::new());
        let mut reader = ChecksumReader::new(cursor, ChecksumAlgorithm::Sha1, expected);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert!(buf.is_empty());
        reader.verify().unwrap();
    }

    #[test]
    fn algorithm_names_round_trip() {
        for alg in [
            ChecksumAlgorithm::Sha1,
            ChecksumAlgorithm::Sha256,
            ChecksumAlgorithm::Md5,
        ] {
            assert_eq!(alg.as_str().parse::<ChecksumAlgorithm>().unwrap(), alg);
        }
    }

    #[test]
    fn parse_checksum_header_sha1() {
        let data = b"test";
        let hash = sha1_digest(data);
        let b64 = STANDARD.encode(&hash);
        let header = format!("sha1 {b64}");
        let (alg, decoded) = parse_checksum_header(&header).unwrap();
        assert_eq!(alg, ChecksumAlgorithm::Sha1);
        assert_eq!(decoded, hash);
    }

    #[test]
    fn parse_checksum_header_uppercase_algorithm() {
        let hash = sha1_digest(b"test");
        let header = format!("SHA1 {}", STANDARD.encode(&hash));
        let (alg, decoded) = parse_checksum_header(&header).unwrap();
        assert_eq!(alg, ChecksumAlgorithm::Sha1);
        assert_eq!(decoded, hash);
    }

    #[test]
    fn parse_checksum_header_unknown_algorithm() {
        let err = parse_checksum_header("crc32 AAAA").unwrap_err();
        assert!(matches!(err, TusError::UnsupportedChecksumAlgorithm(_)));
    }

    #[test]
    fn parse_checksum_header_bad_base64() {
        let err = parse_checksum_header("sha1 not_valid!!").unwrap_err();
        assert!(matches!(err, TusError::InvalidMetadata(_)));
    }

    #[test]
    fn parse_checksum_header_missing_space() {
        let err = parse_checksum_header("sha1").unwrap_err();
        assert!(matches!(err, TusError::InvalidMetadata(_)));
    }
}