oci_unpack/
digest.rs

1use std::{
2    fmt,
3    io::{self, Read},
4};
5
6use sha2::Digest as _;
7
/// Algorithm to compute the hash value.
///
/// See [`Digest`] for an example.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum DigestAlgorithm {
    /// SHA-256; its hash value is written as 64 hexadecimal digits.
    SHA256,
    /// SHA-512; its hash value is written as 128 hexadecimal digits.
    SHA512,
}
17
18/// A digest to validate a blob.
19///
20/// It contains the algorithm (like `SHA256`) and its expected value as
21/// a hexadecimal string.
22///
23/// # Examples
24///
25/// ```
26/// # use oci_unpack::*;
27/// const DIGEST: &str = "123456789012345678901234567890123456789012345678901234567890ABCD";
28///
29/// let digest = Digest::try_from(format!("sha256:{}", DIGEST)).unwrap();
30/// assert_eq!(digest.algorithm(), DigestAlgorithm::SHA256);
31/// assert_eq!(digest.hash_value(), DIGEST);
32/// ```
33#[derive(Clone, Debug, PartialEq, serde::Deserialize)]
34#[serde(try_from = "String")]
35pub struct Digest {
36    hash: String,
37    algorithm: DigestAlgorithm,
38}
39
40/// Errors from the digest parser.
41#[derive(thiserror::Error, Debug)]
42pub enum DigestError {
43    #[error("Invalid digest algorithm.")]
44    InvalidAlgorithm,
45
46    #[error("Invalid digest value.")]
47    InvalidValue,
48}
49
50impl Digest {
51    /// Original string to build this instance (`algorithm:hash_value`).
52    pub fn source(&self) -> &str {
53        &self.hash
54    }
55
56    pub fn hash_value(&self) -> &str {
57        self.hash
58            .split_once(':')
59            .map(|(_, h)| h)
60            .unwrap_or_default()
61    }
62
63    pub fn algorithm(&self) -> DigestAlgorithm {
64        self.algorithm
65    }
66
67    /// Return a `Read` instance to compute its digest.
68    ///
69    /// When all data from `reader` is consumed, it verifies that the
70    /// computed digest is the expected one. If not, it returns an
71    /// [`InvalidData`](::std::io::ErrorKind::InvalidData)
72    /// error.
73    pub fn wrap_reader<R: Read>(&self, reader: R) -> impl Read {
74        let hasher: Box<dyn digest::DynDigest> = match self.algorithm {
75            DigestAlgorithm::SHA256 => Box::new(sha2::Sha256::new()),
76            DigestAlgorithm::SHA512 => Box::new(sha2::Sha512::new()),
77        };
78
79        DigestReader {
80            hasher,
81            expected: self.hash_value().to_owned(),
82            reader,
83        }
84    }
85}
86
87impl TryFrom<String> for Digest {
88    type Error = DigestError;
89
90    fn try_from(hash: String) -> Result<Self, Self::Error> {
91        let (algorithm, value, expected_size) = {
92            if let Some(h) = hash.strip_prefix("sha256:") {
93                (DigestAlgorithm::SHA256, h, 256 / 8 * 2)
94            } else if let Some(h) = hash.strip_prefix("sha512:") {
95                (DigestAlgorithm::SHA512, h, 512 / 8 * 2)
96            } else {
97                return Err(DigestError::InvalidAlgorithm);
98            }
99        };
100
101        // Validate that the hash value is a string with the expected length,
102        // and it only contains hexadecimal digits.
103        if value.len() == expected_size && value.chars().all(|c| c.is_ascii_hexdigit()) {
104            Ok(Digest { hash, algorithm })
105        } else {
106            Err(DigestError::InvalidValue)
107        }
108    }
109}
110
111struct DigestReader<R> {
112    hasher: Box<dyn digest::DynDigest>,
113    expected: String,
114    reader: R,
115}
116
117impl<R: Read> Read for DigestReader<R> {
118    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
119        let buf_len = buf.len();
120        let n = self.reader.read(buf)?;
121
122        if n == 0 && buf_len > 0 {
123            // On EOF, compare the computed digest with the expected one.
124            return self.check_hash();
125        }
126
127        self.hasher.update(&buf[..n]);
128
129        Ok(n)
130    }
131}
132
133impl<R> DigestReader<R> {
134    fn check_hash(&mut self) -> io::Result<usize> {
135        const MAX_DIGEST_SIZE: usize = 512 / 8;
136
137        debug_assert_eq!(self.hasher.output_size() * 2, self.expected.len());
138
139        let mut buffer = [0u8; MAX_DIGEST_SIZE];
140        let out = &mut buffer[..self.hasher.output_size()];
141
142        self.hasher
143            .finalize_into_reset(out)
144            .map_err(io::Error::other)?;
145
146        let mut expected = self.expected.as_str();
147        for hash_byte in out.iter() {
148            match expected
149                .split_at_checked(2)
150                .map(|(b, t)| (u8::from_str_radix(b, 16), t))
151            {
152                Some((Ok(byte), t)) if byte == *hash_byte => expected = t,
153
154                _ => {
155                    return Err(io::Error::new(
156                        io::ErrorKind::InvalidData,
157                        format!(
158                            "Invalid digest. Expected {}, got {}.",
159                            self.expected,
160                            HexString(out)
161                        ),
162                    ))
163                }
164            }
165        }
166
167        Ok(0)
168    }
169}
170
/// Display adapter that writes a byte buffer as a lowercase hexadecimal
/// string, with two digits per byte.
pub(crate) struct HexString<T>(pub T);

impl<T: AsRef<[u8]>> fmt::Display for HexString<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for byte in self.0.as_ref() {
            write!(f, "{byte:02x}")?;
        }
        Ok(())
    }
}
182
#[test]
fn encode_hex_bytes() {
    // Each byte becomes two lowercase digits, zero-padded when needed.
    let input: &[u8] = b"\x01\x20\xf0";
    let encoded = HexString(input).to_string();
    assert_eq!(encoded, "0120f0");
}
187
#[test]
fn reject_invalid_digest() {
    use std::io::Cursor;

    /// SHA-256 digest for `abc`
    const DIGEST: &str = "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad";

    /// SHA-256 digest for `abcx`
    const DIGEST_ABCX: &str = "7571ce1f8e21c6b13dd7ec2c5ec7c9e4dd9852e209869511853f2f1f74b17927";

    /// SHA-512 digest for `abc` (FIPS 180-2 test vector)
    const DIGEST_512: &str = "ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a\
                              2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f";

    let digest = Digest::try_from(format!("sha256:{DIGEST}")).unwrap();
    let mut output = Vec::new();

    // Accept a valid digest.
    digest
        .wrap_reader(Cursor::new("abc"))
        .read_to_end(&mut output)
        .unwrap();

    assert_eq!(output, b"abc");

    // Reject an invalid digest.
    output.clear();
    let err = digest
        .wrap_reader(Cursor::new("abcx"))
        .read_to_end(&mut output)
        .unwrap_err();

    assert_eq!(err.kind(), io::ErrorKind::InvalidData);

    let msg = err.into_inner().unwrap().to_string().to_lowercase();
    assert!(msg.contains(DIGEST));
    assert!(msg.contains(DIGEST_ABCX));

    // SHA-512 uses a different hasher and a 128-digit hash value.
    let digest = Digest::try_from(format!("sha512:{DIGEST_512}")).unwrap();
    assert_eq!(digest.algorithm(), DigestAlgorithm::SHA512);

    output.clear();
    digest
        .wrap_reader(Cursor::new("abc"))
        .read_to_end(&mut output)
        .unwrap();
    assert_eq!(output, b"abc");

    let err = digest
        .wrap_reader(Cursor::new("abcx"))
        .read_to_end(&mut Vec::new())
        .unwrap_err();
    assert_eq!(err.kind(), io::ErrorKind::InvalidData);

    // Malformed digest strings are rejected by the parser.
    assert!(matches!(
        Digest::try_from(format!("md5:{DIGEST}")),
        Err(DigestError::InvalidAlgorithm)
    ));
    assert!(matches!(
        Digest::try_from(format!("sha256:{DIGEST_512}")),
        Err(DigestError::InvalidValue)
    ));
    assert!(matches!(
        Digest::try_from(format!("sha512:{DIGEST}")),
        Err(DigestError::InvalidValue)
    ));
    assert!(matches!(
        Digest::try_from(format!("sha256:{}", &DIGEST[1..])),
        Err(DigestError::InvalidValue)
    ));
    assert!(matches!(
        Digest::try_from(format!("sha256:{}", "g".repeat(64))),
        Err(DigestError::InvalidValue)
    ));
}