1use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256, Sha384, Sha512};
8use sha3::{Sha3_256, Sha3_512};
9use std::fmt;
10use std::io::Read;
11use std::str::FromStr;
12
13use crate::{Error, Result};
14
15#[derive(
17 Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default, strum::Display,
18)]
19#[serde(rename_all = "lowercase")]
20pub enum HashAlgorithm {
21 #[default]
23 #[serde(rename = "sha256")]
24 #[strum(serialize = "sha256")]
25 Sha256,
26
27 #[serde(rename = "sha384")]
29 #[strum(serialize = "sha384")]
30 Sha384,
31
32 #[serde(rename = "sha512")]
34 #[strum(serialize = "sha512")]
35 Sha512,
36
37 #[serde(rename = "sha3-256")]
39 #[strum(serialize = "sha3-256")]
40 Sha3_256,
41
42 #[serde(rename = "sha3-512")]
44 #[strum(serialize = "sha3-512")]
45 Sha3_512,
46
47 #[serde(rename = "blake3")]
49 #[strum(serialize = "blake3")]
50 Blake3,
51}
52
53impl HashAlgorithm {
54 #[must_use]
56 pub const fn as_str(&self) -> &'static str {
57 match self {
58 Self::Sha256 => "sha256",
59 Self::Sha384 => "sha384",
60 Self::Sha512 => "sha512",
61 Self::Sha3_256 => "sha3-256",
62 Self::Sha3_512 => "sha3-512",
63 Self::Blake3 => "blake3",
64 }
65 }
66
67 #[must_use]
69 pub const fn output_size(&self) -> usize {
70 match self {
71 Self::Sha256 | Self::Sha3_256 | Self::Blake3 => 32,
72 Self::Sha384 => 48,
73 Self::Sha512 | Self::Sha3_512 => 64,
74 }
75 }
76}
77
78impl FromStr for HashAlgorithm {
79 type Err = Error;
80
81 fn from_str(s: &str) -> Result<Self> {
82 match s.to_lowercase().as_str() {
83 "sha256" => Ok(Self::Sha256),
84 "sha384" => Ok(Self::Sha384),
85 "sha512" => Ok(Self::Sha512),
86 "sha3-256" => Ok(Self::Sha3_256),
87 "sha3-512" => Ok(Self::Sha3_512),
88 "blake3" => Ok(Self::Blake3),
89 _ => Err(Error::UnsupportedHashAlgorithm {
90 algorithm: s.to_string(),
91 }),
92 }
93 }
94}
95
96#[derive(Debug, Clone, PartialEq, Eq, Hash)]
109pub struct DocumentId {
110 algorithm: HashAlgorithm,
111 digest: Vec<u8>,
112}
113
114impl DocumentId {
115 #[must_use]
117 pub fn new(algorithm: HashAlgorithm, digest: Vec<u8>) -> Self {
118 Self { algorithm, digest }
119 }
120
121 #[must_use]
123 pub const fn algorithm(&self) -> HashAlgorithm {
124 self.algorithm
125 }
126
127 #[must_use]
129 pub fn digest(&self) -> &[u8] {
130 &self.digest
131 }
132
133 #[must_use]
135 pub fn hex_digest(&self) -> String {
136 hex_encode(&self.digest)
137 }
138
139 #[must_use]
141 pub fn is_pending(&self) -> bool {
142 self.digest.is_empty()
143 }
144
145 #[must_use]
147 pub fn pending() -> Self {
148 Self {
149 algorithm: HashAlgorithm::default(),
150 digest: Vec::new(),
151 }
152 }
153}
154
155impl fmt::Display for DocumentId {
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 if self.is_pending() {
158 write!(f, "pending")
159 } else {
160 write!(f, "{}:{}", self.algorithm, self.hex_digest())
161 }
162 }
163}
164
165impl FromStr for DocumentId {
166 type Err = Error;
167
168 fn from_str(s: &str) -> Result<Self> {
169 if s == "pending" {
170 return Ok(Self::pending());
171 }
172
173 let (alg_str, hex_str) = s.split_once(':').ok_or_else(|| Error::InvalidHashFormat {
174 value: s.to_string(),
175 })?;
176
177 let algorithm: HashAlgorithm = alg_str.parse()?;
178 let digest = hex_decode(hex_str).map_err(|()| Error::InvalidHashFormat {
179 value: s.to_string(),
180 })?;
181
182 if digest.len() != algorithm.output_size() {
184 return Err(Error::InvalidHashFormat {
185 value: s.to_string(),
186 });
187 }
188
189 Ok(Self { algorithm, digest })
190 }
191}
192
193impl Serialize for DocumentId {
194 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
195 where
196 S: serde::Serializer,
197 {
198 serializer.serialize_str(&self.to_string())
199 }
200}
201
202impl<'de> Deserialize<'de> for DocumentId {
203 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
204 where
205 D: serde::Deserializer<'de>,
206 {
207 let s = String::deserialize(deserializer)?;
208 s.parse().map_err(serde::de::Error::custom)
209 }
210}
211
212pub struct Hasher {
214 algorithm: HashAlgorithm,
215 state: HasherState,
216}
217
218enum HasherState {
219 Sha256(Sha256),
220 Sha384(Sha384),
221 Sha512(Sha512),
222 Sha3_256(Sha3_256),
223 Sha3_512(Sha3_512),
224 Blake3(Box<blake3::Hasher>),
225}
226
227impl Hasher {
228 #[must_use]
230 pub fn new(algorithm: HashAlgorithm) -> Self {
231 let state = match algorithm {
232 HashAlgorithm::Sha256 => HasherState::Sha256(Sha256::new()),
233 HashAlgorithm::Sha384 => HasherState::Sha384(Sha384::new()),
234 HashAlgorithm::Sha512 => HasherState::Sha512(Sha512::new()),
235 HashAlgorithm::Sha3_256 => HasherState::Sha3_256(Sha3_256::new()),
236 HashAlgorithm::Sha3_512 => HasherState::Sha3_512(Sha3_512::new()),
237 HashAlgorithm::Blake3 => HasherState::Blake3(Box::new(blake3::Hasher::new())),
238 };
239 Self { algorithm, state }
240 }
241
242 #[must_use]
244 pub fn default_algorithm() -> Self {
245 Self::new(HashAlgorithm::default())
246 }
247
248 pub fn update(&mut self, data: &[u8]) {
250 match &mut self.state {
251 HasherState::Sha256(h) => h.update(data),
252 HasherState::Sha384(h) => h.update(data),
253 HasherState::Sha512(h) => h.update(data),
254 HasherState::Sha3_256(h) => h.update(data),
255 HasherState::Sha3_512(h) => h.update(data),
256 HasherState::Blake3(h) => {
257 h.update(data);
258 }
259 }
260 }
261
262 #[must_use]
264 pub fn finalize(self) -> DocumentId {
265 let digest = match self.state {
266 HasherState::Sha256(h) => h.finalize().to_vec(),
267 HasherState::Sha384(h) => h.finalize().to_vec(),
268 HasherState::Sha512(h) => h.finalize().to_vec(),
269 HasherState::Sha3_256(h) => h.finalize().to_vec(),
270 HasherState::Sha3_512(h) => h.finalize().to_vec(),
271 HasherState::Blake3(h) => h.finalize().as_bytes().to_vec(),
272 };
273 DocumentId::new(self.algorithm, digest)
274 }
275
276 #[must_use]
278 pub fn hash(algorithm: HashAlgorithm, data: &[u8]) -> DocumentId {
279 let mut hasher = Self::new(algorithm);
280 hasher.update(data);
281 hasher.finalize()
282 }
283
284 pub fn hash_reader<R: Read>(algorithm: HashAlgorithm, reader: &mut R) -> Result<DocumentId> {
290 let mut hasher = Self::new(algorithm);
291 let mut buffer = [0u8; 8192];
292 loop {
293 let n = reader.read(&mut buffer)?;
294 if n == 0 {
295 break;
296 }
297 hasher.update(&buffer[..n]);
298 }
299 Ok(hasher.finalize())
300 }
301}
302
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
313
/// Decodes a hexadecimal string into bytes. Upper- and lowercase digits are
/// both accepted.
///
/// Decoding works on raw bytes rather than `&str` slices. Input that contains
/// multi-byte UTF-8 characters is therefore rejected instead of panicking on a
/// non-char-boundary slice. Only `0-9`, `a-f` and `A-F` are accepted, whereas
/// `u8::from_str_radix` would also accept a leading `+` sign.
///
/// # Errors
///
/// Returns `Err(())` if the input has an odd number of bytes or contains any
/// byte that is not a hex digit.
fn hex_decode(s: &str) -> std::result::Result<Vec<u8>, ()> {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn nibble(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let pairs = s.as_bytes().chunks_exact(2);
    // A leftover byte means the input length is odd.
    if !pairs.remainder().is_empty() {
        return Err(());
    }
    pairs
        .map(|pair| {
            nibble(pair[0])
                .zip(nibble(pair[1]))
                .map(|(hi, lo)| (hi << 4) | lo)
                .ok_or(())
        })
        .collect()
}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// SHA-256 of `b"hello world"`, lowercase hex.
    const HELLO_WORLD_SHA256: &str =
        "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";

    #[test]
    fn test_sha256_hash() {
        let id = Hasher::hash(HashAlgorithm::Sha256, b"hello world");
        assert_eq!(id.algorithm(), HashAlgorithm::Sha256);
        assert_eq!(id.hex_digest(), HELLO_WORLD_SHA256);
    }

    #[test]
    fn test_document_id_parsing() {
        let text = format!("sha256:{HELLO_WORLD_SHA256}");
        let id = text.parse::<DocumentId>().unwrap();
        assert_eq!(id.algorithm(), HashAlgorithm::Sha256);
        assert_eq!(id.to_string(), text);
    }

    #[test]
    fn test_pending_id() {
        let id = DocumentId::pending();
        assert!(id.is_pending());
        assert_eq!(id.to_string(), "pending");

        let parsed = "pending".parse::<DocumentId>().unwrap();
        assert!(parsed.is_pending());
    }

    #[test]
    fn test_invalid_hash_format() {
        for bad in ["invalid", "sha256:xyz", "sha256:ab"] {
            assert!(
                bad.parse::<DocumentId>().is_err(),
                "{bad:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_blake3_hash() {
        let id = Hasher::hash(HashAlgorithm::Blake3, b"hello world");
        assert_eq!(id.algorithm(), HashAlgorithm::Blake3);
        assert_eq!(id.digest().len(), 32);
    }

    #[test]
    fn test_streaming_hash() {
        let mut hasher = Hasher::new(HashAlgorithm::Sha256);
        for chunk in [b"hello ".as_slice(), b"world".as_slice()] {
            hasher.update(chunk);
        }
        let streamed = hasher.finalize();

        assert_eq!(
            streamed,
            Hasher::hash(HashAlgorithm::Sha256, b"hello world")
        );
    }

    #[test]
    fn test_serialization() {
        let id = Hasher::hash(HashAlgorithm::Sha256, b"test");
        let json = serde_json::to_string(&id).unwrap();
        let back = serde_json::from_str::<DocumentId>(&json).unwrap();
        assert_eq!(id, back);
    }
}
390
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Hashing identical input twice yields identical ids.
        #[test]
        fn hash_is_deterministic(data: Vec<u8>) {
            let first = Hasher::hash(HashAlgorithm::Sha256, &data);
            let second = Hasher::hash(HashAlgorithm::Sha256, &data);
            prop_assert_eq!(first, second);
        }

        // `Display` output parses back to an equal id.
        #[test]
        fn document_id_roundtrip(data: Vec<u8>) {
            let original = Hasher::hash(HashAlgorithm::Sha256, &data);
            let parsed = original.to_string().parse::<DocumentId>().unwrap();
            prop_assert_eq!(original, parsed);
        }

        // Distinct inputs are expected to produce distinct ids.
        #[test]
        fn different_inputs_different_hashes(a: Vec<u8>, b: Vec<u8>) {
            prop_assume!(a != b);
            let id_a = Hasher::hash(HashAlgorithm::Sha256, &a);
            let id_b = Hasher::hash(HashAlgorithm::Sha256, &b);
            prop_assert_ne!(id_a, id_b);
        }

        // A single `update` + `finalize` matches the one-shot helper.
        #[test]
        fn streaming_equals_oneshot(data: Vec<u8>) {
            let expected = Hasher::hash(HashAlgorithm::Sha256, &data);

            let mut incremental = Hasher::new(HashAlgorithm::Sha256);
            incremental.update(&data);
            let actual = incremental.finalize();

            prop_assert_eq!(expected, actual);
        }

        // Hex encoding followed by decoding is the identity.
        #[test]
        fn hex_roundtrip(data: Vec<u8>) {
            let decoded = hex_decode(&hex_encode(&data)).unwrap();
            prop_assert_eq!(data, decoded);
        }

        // serde JSON serialization round-trips an id.
        #[test]
        fn json_roundtrip(data: Vec<u8>) {
            let id = Hasher::hash(HashAlgorithm::Sha256, &data);
            let encoded = serde_json::to_string(&id).unwrap();
            let decoded = serde_json::from_str::<DocumentId>(&encoded).unwrap();
            prop_assert_eq!(id, decoded);
        }

        // BLAKE3 hashing is deterministic too.
        #[test]
        fn blake3_deterministic(data: Vec<u8>) {
            let first = Hasher::hash(HashAlgorithm::Blake3, &data);
            let second = Hasher::hash(HashAlgorithm::Blake3, &data);
            prop_assert_eq!(first, second);
        }
    }
}