fileloft_core/
checksum.rs1use std::pin::Pin;
2use std::str::FromStr;
3use std::task::{Context, Poll};
4
5use base64::{engine::general_purpose::STANDARD, Engine};
6use digest::DynDigest;
7use tokio::io::{AsyncRead, ReadBuf};
8
9use crate::error::TusError;
10use crate::proto::SUPPORTED_CHECKSUM_ALGORITHMS;
11
/// A hash algorithm that can be named in an `Upload-Checksum` header.
///
/// The enum is fieldless, so it is `Copy` and `Hash`: pass it by value and use it as a
/// map/set key freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChecksumAlgorithm {
    /// SHA-1 (20-byte digest).
    Sha1,
    /// SHA-256 (32-byte digest).
    Sha256,
    /// MD5 (16-byte digest).
    Md5,
}
19
20impl ChecksumAlgorithm {
21 pub fn as_str(&self) -> &'static str {
22 match self {
23 Self::Sha1 => "sha1",
24 Self::Sha256 => "sha256",
25 Self::Md5 => "md5",
26 }
27 }
28
29 fn make_hasher(&self) -> Box<dyn DynDigest + Send> {
30 match self {
31 Self::Sha1 => Box::new(sha1::Sha1::default()),
32 Self::Sha256 => Box::new(sha2::Sha256::default()),
33 Self::Md5 => Box::new(md5::Md5::default()),
34 }
35 }
36}
37
38impl FromStr for ChecksumAlgorithm {
39 type Err = TusError;
40
41 fn from_str(s: &str) -> Result<Self, Self::Err> {
42 match s.to_lowercase().as_str() {
43 "sha1" => Ok(Self::Sha1),
44 "sha256" => Ok(Self::Sha256),
45 "md5" => Ok(Self::Md5),
46 other => Err(TusError::UnsupportedChecksumAlgorithm(other.to_string())),
47 }
48 }
49}
50
51pub fn algorithms_header() -> String {
53 SUPPORTED_CHECKSUM_ALGORITHMS.join(",")
54}
55
56pub fn parse_checksum_header(value: &str) -> Result<(ChecksumAlgorithm, Vec<u8>), TusError> {
58 let (alg_str, b64) = value.split_once(' ').ok_or_else(|| {
59 TusError::InvalidMetadata(
60 "malformed Upload-Checksum header (expected '<alg> <base64>')".into(),
61 )
62 })?;
63 let algorithm: ChecksumAlgorithm = alg_str.parse()?;
64 let hash = STANDARD
65 .decode(b64.trim())
66 .map_err(|e| TusError::InvalidMetadata(format!("bad base64 in Upload-Checksum: {e}")))?;
67 Ok((algorithm, hash))
68}
69
70pub struct ChecksumReader<R> {
73 inner: R,
74 hasher: Box<dyn DynDigest + Send>,
75 expected: Vec<u8>,
76}
77
78impl<R: AsyncRead + Unpin> ChecksumReader<R> {
79 pub fn new(inner: R, algorithm: ChecksumAlgorithm, expected: Vec<u8>) -> Self {
80 Self {
81 inner,
82 hasher: algorithm.make_hasher(),
83 expected,
84 }
85 }
86
87 pub fn verify(self) -> Result<(), TusError> {
90 let computed = self.hasher.finalize();
91 if computed.as_ref() == self.expected.as_slice() {
92 Ok(())
93 } else {
94 Err(TusError::ChecksumMismatch)
95 }
96 }
97}
98
99impl<R: AsyncRead + Unpin> AsyncRead for ChecksumReader<R> {
100 fn poll_read(
101 self: Pin<&mut Self>,
102 cx: &mut Context<'_>,
103 buf: &mut ReadBuf<'_>,
104 ) -> Poll<std::io::Result<()>> {
105 let me = self.get_mut();
106 let before = buf.filled().len();
107 let result = Pin::new(&mut me.inner).poll_read(cx, buf);
108 if let Poll::Ready(Ok(())) = &result {
109 let filled = &buf.filled()[before..];
110 if !filled.is_empty() {
111 me.hasher.update(filled);
112 }
113 }
114 result
115 }
116}
117
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use base64::{engine::general_purpose::STANDARD, Engine};
    use sha1::{Digest, Sha1};
    use tokio::io::AsyncReadExt;

    use super::*;

    /// Raw (not base64-encoded) SHA-1 digest of `data`.
    fn sha1_digest(data: &[u8]) -> Vec<u8> {
        let mut h = Sha1::new();
        Digest::update(&mut h, data);
        Digest::finalize(h).to_vec()
    }

    #[tokio::test]
    async fn checksum_reader_correct_hash() {
        let data = b"hello tus";
        let expected = sha1_digest(data);
        let cursor = Cursor::new(data.to_vec());
        let mut reader = ChecksumReader::new(cursor, ChecksumAlgorithm::Sha1, expected);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert_eq!(&buf, data);
        reader.verify().unwrap();
    }

    #[tokio::test]
    async fn checksum_reader_wrong_hash() {
        let data = b"hello tus";
        let wrong = vec![0u8; 20];
        let cursor = Cursor::new(data.to_vec());
        let mut reader = ChecksumReader::new(cursor, ChecksumAlgorithm::Sha1, wrong);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert!(matches!(reader.verify(), Err(TusError::ChecksumMismatch)));
    }

    #[tokio::test]
    async fn checksum_reader_empty_input() {
        let expected = sha1_digest(b"");
        let cursor = Cursor::new(Vec::new());
        let mut reader = ChecksumReader::new(cursor, ChecksumAlgorithm::Sha1, expected);
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).await.unwrap();
        assert!(buf.is_empty());
        reader.verify().unwrap();
    }

    #[test]
    fn algorithm_parse_is_case_insensitive() {
        assert_eq!("SHA1".parse::<ChecksumAlgorithm>().unwrap(), ChecksumAlgorithm::Sha1);
        assert_eq!("Sha256".parse::<ChecksumAlgorithm>().unwrap(), ChecksumAlgorithm::Sha256);
        assert_eq!("MD5".parse::<ChecksumAlgorithm>().unwrap(), ChecksumAlgorithm::Md5);
    }

    #[test]
    fn algorithm_round_trips_through_as_str() {
        for alg in [ChecksumAlgorithm::Sha1, ChecksumAlgorithm::Sha256, ChecksumAlgorithm::Md5] {
            assert_eq!(alg.as_str().parse::<ChecksumAlgorithm>().unwrap(), alg);
        }
    }

    #[test]
    fn parse_checksum_header_sha1() {
        let data = b"test";
        let hash = sha1_digest(data);
        let b64 = STANDARD.encode(&hash);
        let header = format!("sha1 {b64}");
        let (alg, decoded) = parse_checksum_header(&header).unwrap();
        assert_eq!(alg, ChecksumAlgorithm::Sha1);
        assert_eq!(decoded, hash);
    }

    #[test]
    fn parse_checksum_header_tolerates_extra_spaces() {
        let hash = sha1_digest(b"test");
        let header = format!("sha1   {}", STANDARD.encode(&hash));
        let (alg, decoded) = parse_checksum_header(&header).unwrap();
        assert_eq!(alg, ChecksumAlgorithm::Sha1);
        assert_eq!(decoded, hash);
    }

    #[test]
    fn parse_checksum_header_unknown_algorithm() {
        let err = parse_checksum_header("crc32 AAAA").unwrap_err();
        assert!(matches!(err, TusError::UnsupportedChecksumAlgorithm(_)));
    }

    #[test]
    fn parse_checksum_header_bad_base64() {
        let err = parse_checksum_header("sha1 not_valid!!").unwrap_err();
        assert!(matches!(err, TusError::InvalidMetadata(_)));
    }

    #[test]
    fn parse_checksum_header_missing_space() {
        let err = parse_checksum_header("sha1").unwrap_err();
        assert!(matches!(err, TusError::InvalidMetadata(_)));
    }
}