fileloft_core/
checksum.rs1use std::pin::Pin;
2use std::str::FromStr;
3use std::task::{Context, Poll};
4
5use base64::{engine::general_purpose::STANDARD, Engine};
6use digest::DynDigest;
7use tokio::io::{AsyncRead, ReadBuf};
8
9use crate::error::TusError;
10use crate::proto::SUPPORTED_CHECKSUM_ALGORITHMS;
11
/// A digest algorithm that can be named in an `Upload-Checksum` header.
///
/// The type is a plain fieldless tag, so it is `Copy` and `Hash`: it can be
/// passed around by value without cloning and used as a set/map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChecksumAlgorithm {
    /// SHA-1 (20-byte digest).
    Sha1,
    /// SHA-256 (32-byte digest).
    Sha256,
    /// MD5 (16-byte digest).
    Md5,
}
19
20impl ChecksumAlgorithm {
21 pub fn as_str(&self) -> &'static str {
22 match self {
23 Self::Sha1 => "sha1",
24 Self::Sha256 => "sha256",
25 Self::Md5 => "md5",
26 }
27 }
28
29 fn make_hasher(&self) -> Box<dyn DynDigest + Send> {
30 match self {
31 Self::Sha1 => Box::new(sha1::Sha1::default()),
32 Self::Sha256 => Box::new(sha2::Sha256::default()),
33 Self::Md5 => Box::new(md5::Md5::default()),
34 }
35 }
36}
37
38impl FromStr for ChecksumAlgorithm {
39 type Err = TusError;
40
41 fn from_str(s: &str) -> Result<Self, Self::Err> {
42 match s.to_lowercase().as_str() {
43 "sha1" => Ok(Self::Sha1),
44 "sha256" => Ok(Self::Sha256),
45 "md5" => Ok(Self::Md5),
46 other => Err(TusError::UnsupportedChecksumAlgorithm(other.to_string())),
47 }
48 }
49}
50
/// Returns the entries of `SUPPORTED_CHECKSUM_ALGORITHMS` joined with `,`
/// (no spaces), e.g. for advertising the supported checksum algorithms.
// NOTE(review): presumably emitted as the `Tus-Checksum-Algorithm` response
// header — confirm at the call site, and keep the constant in sync with the
// variants of `ChecksumAlgorithm`.
pub fn algorithms_header() -> String {
    SUPPORTED_CHECKSUM_ALGORITHMS.join(",")
}
55
56pub fn parse_checksum_header(value: &str) -> Result<(ChecksumAlgorithm, Vec<u8>), TusError> {
58 let (alg_str, b64) = value.split_once(' ').ok_or_else(|| {
59 TusError::InvalidMetadata("malformed Upload-Checksum header (expected '<alg> <base64>')".into())
60 })?;
61 let algorithm: ChecksumAlgorithm = alg_str.parse()?;
62 let hash = STANDARD
63 .decode(b64.trim())
64 .map_err(|e| TusError::InvalidMetadata(format!("bad base64 in Upload-Checksum: {e}")))?;
65 Ok((algorithm, hash))
66}
67
68pub struct ChecksumReader<R> {
71 inner: R,
72 hasher: Box<dyn DynDigest + Send>,
73 expected: Vec<u8>,
74}
75
76impl<R: AsyncRead + Unpin> ChecksumReader<R> {
77 pub fn new(inner: R, algorithm: ChecksumAlgorithm, expected: Vec<u8>) -> Self {
78 Self {
79 inner,
80 hasher: algorithm.make_hasher(),
81 expected,
82 }
83 }
84
85 pub fn verify(self) -> Result<(), TusError> {
88 let computed = self.hasher.finalize();
89 if computed.as_ref() == self.expected.as_slice() {
90 Ok(())
91 } else {
92 Err(TusError::ChecksumMismatch)
93 }
94 }
95}
96
97impl<R: AsyncRead + Unpin> AsyncRead for ChecksumReader<R> {
98 fn poll_read(
99 self: Pin<&mut Self>,
100 cx: &mut Context<'_>,
101 buf: &mut ReadBuf<'_>,
102 ) -> Poll<std::io::Result<()>> {
103 let me = self.get_mut();
104 let before = buf.filled().len();
105 let result = Pin::new(&mut me.inner).poll_read(cx, buf);
106 if let Poll::Ready(Ok(())) = &result {
107 let filled = &buf.filled()[before..];
108 if !filled.is_empty() {
109 me.hasher.update(filled);
110 }
111 }
112 result
113 }
114}
115
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use base64::{engine::general_purpose::STANDARD, Engine};
    use md5::Md5;
    use sha1::{Digest, Sha1};
    use sha2::Sha256;
    use tokio::io::AsyncReadExt;

    use super::*;

    /// Computes the raw (not base64-encoded) digest of `data` with algorithm `D`.
    fn digest_of<D: Digest>(data: &[u8]) -> Vec<u8> {
        let mut hasher = D::new();
        // Fully qualified: `DynDigest` (in scope via `super::*`) also has
        // `update`/`finalize`.
        Digest::update(&mut hasher, data);
        Digest::finalize(hasher).to_vec()
    }

    #[tokio::test]
    async fn checksum_reader_correct_hash() {
        let data = b"hello tus";
        let expected = digest_of::<Sha1>(data);
        let cursor = Cursor::new(data.to_vec());
        let mut reader = ChecksumReader::new(cursor, ChecksumAlgorithm::Sha1, expected);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(&buf, data);
        reader.verify().unwrap();
    }

    #[tokio::test]
    async fn checksum_reader_wrong_hash() {
        let data = b"hello tus";
        let wrong = vec![0u8; 20];
        let cursor = Cursor::new(data.to_vec());
        let mut reader = ChecksumReader::new(cursor, ChecksumAlgorithm::Sha1, wrong);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert!(matches!(reader.verify(), Err(TusError::ChecksumMismatch)));
    }

    #[tokio::test]
    async fn checksum_reader_supports_every_algorithm() {
        let data = b"hello tus";
        let cases = [
            (ChecksumAlgorithm::Sha1, digest_of::<Sha1>(data)),
            (ChecksumAlgorithm::Sha256, digest_of::<Sha256>(data)),
            (ChecksumAlgorithm::Md5, digest_of::<Md5>(data)),
        ];
        for (algorithm, expected) in cases {
            let name = algorithm.as_str();
            let mut reader = ChecksumReader::new(Cursor::new(data.to_vec()), algorithm, expected);
            let mut buf = Vec::new();
            reader.read_to_end(&mut buf).await.unwrap();
            assert_eq!(&buf, data);
            reader
                .verify()
                .unwrap_or_else(|e| panic!("{name} checksum failed: {e:?}"));
        }
    }

    #[tokio::test]
    async fn checksum_reader_hashes_across_many_small_reads() {
        let data = b"the quick brown fox jumps";
        let expected = digest_of::<Sha1>(data);
        let mut reader =
            ChecksumReader::new(Cursor::new(data.to_vec()), ChecksumAlgorithm::Sha1, expected);
        let mut out = Vec::new();
        // A tiny buffer forces many separate `poll_read` calls.
        let mut chunk = [0u8; 3];
        loop {
            let n = reader.read(&mut chunk).await.unwrap();
            if n == 0 {
                break;
            }
            out.extend_from_slice(&chunk[..n]);
        }
        assert_eq!(&out, data);
        reader.verify().unwrap();
    }

    #[test]
    fn parse_checksum_header_sha1() {
        let data = b"test";
        let hash = digest_of::<Sha1>(data);
        let b64 = STANDARD.encode(&hash);
        let header = format!("sha1 {b64}");
        let (alg, decoded) = parse_checksum_header(&header).unwrap();
        assert_eq!(alg, ChecksumAlgorithm::Sha1);
        assert_eq!(decoded, hash);
    }

    #[test]
    fn parse_checksum_header_algorithm_is_case_insensitive() {
        let hash = digest_of::<Sha256>(b"test");
        let header = format!("SHA256 {}", STANDARD.encode(&hash));
        let (alg, decoded) = parse_checksum_header(&header).unwrap();
        assert_eq!(alg, ChecksumAlgorithm::Sha256);
        assert_eq!(decoded, hash);
    }

    #[test]
    fn parse_checksum_header_unknown_algorithm() {
        let err = parse_checksum_header("crc32 AAAA").unwrap_err();
        assert!(matches!(err, TusError::UnsupportedChecksumAlgorithm(_)));
    }

    #[test]
    fn parse_checksum_header_bad_base64() {
        let err = parse_checksum_header("sha1 not_valid!!").unwrap_err();
        assert!(matches!(err, TusError::InvalidMetadata(_)));
    }

    #[test]
    fn parse_checksum_header_missing_space() {
        let err = parse_checksum_header("sha1").unwrap_err();
        assert!(matches!(err, TusError::InvalidMetadata(_)));
    }

    #[test]
    fn algorithm_names_round_trip() {
        for alg in [
            ChecksumAlgorithm::Sha1,
            ChecksumAlgorithm::Sha256,
            ChecksumAlgorithm::Md5,
        ] {
            assert_eq!(alg.as_str().parse::<ChecksumAlgorithm>().unwrap(), alg);
        }
    }
}