1use std::{
2 fmt,
3 io::{self, Read},
4};
5
6use sha2::Digest as _;
7
/// Hash algorithms a [`Digest`] can be verified with.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a wildcard arm when
/// matching on it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum DigestAlgorithm {
    /// SHA-256: 32-byte output, written as 64 hex characters.
    SHA256,
    /// SHA-512: 64-byte output, written as 128 hex characters.
    SHA512,
}
17
18#[derive(Clone, Debug, PartialEq, serde::Deserialize)]
34#[serde(try_from = "String")]
35pub struct Digest {
36 hash: String,
37 algorithm: DigestAlgorithm,
38}
39
40#[derive(thiserror::Error, Debug)]
42pub enum DigestError {
43 #[error("Invalid digest algorithm.")]
44 InvalidAlgorithm,
45
46 #[error("Invalid digest value.")]
47 InvalidValue,
48}
49
50impl Digest {
51 pub fn source(&self) -> &str {
53 &self.hash
54 }
55
56 pub fn hash_value(&self) -> &str {
57 self.hash
58 .split_once(':')
59 .map(|(_, h)| h)
60 .unwrap_or_default()
61 }
62
63 pub fn algorithm(&self) -> DigestAlgorithm {
64 self.algorithm
65 }
66
67 pub fn wrap_reader<R: Read>(&self, reader: R) -> impl Read {
74 let hasher: Box<dyn digest::DynDigest> = match self.algorithm {
75 DigestAlgorithm::SHA256 => Box::new(sha2::Sha256::new()),
76 DigestAlgorithm::SHA512 => Box::new(sha2::Sha512::new()),
77 };
78
79 DigestReader {
80 hasher,
81 expected: self.hash_value().to_owned(),
82 reader,
83 }
84 }
85}
86
87impl TryFrom<String> for Digest {
88 type Error = DigestError;
89
90 fn try_from(hash: String) -> Result<Self, Self::Error> {
91 let (algorithm, value, expected_size) = {
92 if let Some(h) = hash.strip_prefix("sha256:") {
93 (DigestAlgorithm::SHA256, h, 256 / 8 * 2)
94 } else if let Some(h) = hash.strip_prefix("sha512:") {
95 (DigestAlgorithm::SHA512, h, 512 / 8 * 2)
96 } else {
97 return Err(DigestError::InvalidAlgorithm);
98 }
99 };
100
101 if value.len() == expected_size && value.chars().all(|c| c.is_ascii_hexdigit()) {
104 Ok(Digest { hash, algorithm })
105 } else {
106 Err(DigestError::InvalidValue)
107 }
108 }
109}
110
/// Reader adapter built by `Digest::wrap_reader`.
///
/// It forwards bytes from `reader`, feeds each chunk it returns into `hasher`, and checks the
/// result against `expected` once `reader` signals end of input.
struct DigestReader<R> {
    /// Running hash over every byte returned from `reader` so far.
    hasher: Box<dyn digest::DynDigest>,
    /// Expected digest as a hex string, without the `algorithm:` prefix.
    expected: String,
    /// The wrapped source of data.
    reader: R,
}
116
117impl<R: Read> Read for DigestReader<R> {
118 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
119 let buf_len = buf.len();
120 let n = self.reader.read(buf)?;
121
122 if n == 0 && buf_len > 0 {
123 return self.check_hash();
125 }
126
127 self.hasher.update(&buf[..n]);
128
129 Ok(n)
130 }
131}
132
133impl<R> DigestReader<R> {
134 fn check_hash(&mut self) -> io::Result<usize> {
135 const MAX_DIGEST_SIZE: usize = 512 / 8;
136
137 debug_assert_eq!(self.hasher.output_size() * 2, self.expected.len());
138
139 let mut buffer = [0u8; MAX_DIGEST_SIZE];
140 let out = &mut buffer[..self.hasher.output_size()];
141
142 self.hasher
143 .finalize_into_reset(out)
144 .map_err(io::Error::other)?;
145
146 let mut expected = self.expected.as_str();
147 for hash_byte in out.iter() {
148 match expected
149 .split_at_checked(2)
150 .map(|(b, t)| (u8::from_str_radix(b, 16), t))
151 {
152 Some((Ok(byte), t)) if byte == *hash_byte => expected = t,
153
154 _ => {
155 return Err(io::Error::new(
156 io::ErrorKind::InvalidData,
157 format!(
158 "Invalid digest. Expected {}, got {}.",
159 self.expected,
160 HexString(out)
161 ),
162 ))
163 }
164 }
165 }
166
167 Ok(0)
168 }
169}
170
/// Displays a byte sequence as lowercase hexadecimal: two digits per byte, no separators.
pub(crate) struct HexString<T>(pub T);

impl<T: AsRef<[u8]>> fmt::Display for HexString<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for octet in self.0.as_ref() {
            // `02` keeps the leading zero for bytes below 0x10.
            write!(f, "{octet:02x}")?;
        }
        Ok(())
    }
}
182
#[test]
fn encode_hex_bytes() {
    // Each byte becomes exactly two lowercase hex digits, leading zeros included.
    let bytes: &[u8] = &[0x01, 0x20, 0xf0];
    let encoded = HexString(bytes).to_string();
    assert_eq!(encoded, "0120f0");
}
187
#[test]
fn reject_invalid_digest() {
    use std::io::Cursor;

    // SHA-256 of "abc".
    const DIGEST: &str = "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad";
    // SHA-256 of "abcx".
    const ABCX_DIGEST: &str = "7571ce1f8e21c6b13dd7ec2c5ec7c9e4dd9852e209869511853f2f1f74b17927";

    let digest = Digest::try_from(format!("sha256:{DIGEST}")).unwrap();

    // Drains `input` through a verifying reader and returns everything read on success.
    let verify = |input: &'static str| -> io::Result<Vec<u8>> {
        let mut sink = Vec::new();
        digest.wrap_reader(Cursor::new(input)).read_to_end(&mut sink)?;
        Ok(sink)
    };

    // Matching data passes through untouched.
    assert_eq!(verify("abc").unwrap(), b"abc");

    // One extra byte changes the digest, which must surface as InvalidData.
    let err = verify("abcx").unwrap_err();
    assert_eq!(err.kind(), io::ErrorKind::InvalidData);

    let msg = err.into_inner().unwrap().to_string().to_lowercase();
    assert!(msg.contains(DIGEST));
    assert!(msg.contains(ABCX_DIGEST));
}