dynamo_llm/common/checked_file.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{
5    fmt::Display,
6    path::{Path, PathBuf},
7    str::FromStr,
8};
9
10use either::Either;
11use serde::{
12    Deserialize, Deserializer, Serialize, Serializer,
13    de::{self, Visitor},
14    ser::SerializeStruct as _,
15};
16use url::Url;
17
18#[derive(Clone, Debug)]
19pub struct CheckedFile {
20    /// Either a path on local disk or a remote URL (usually nats object store)
21    path: Either<PathBuf, Url>,
22
23    /// Checksum of the contents of path
24    checksum: Checksum,
25}
26
27#[derive(Debug, Clone, Eq, PartialEq)]
28pub struct Checksum {
29    /// The checksum is a hex encoded string of the file's content
30    hash: String,
31
32    /// Checksum algorithm
33    algorithm: CryptographicHashMethods,
34}
35
36#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
37pub enum CryptographicHashMethods {
38    #[serde(rename = "blake3")]
39    BLAKE3,
40}
41
42impl CheckedFile {
43    pub fn from_disk<P: Into<PathBuf>>(filepath: P) -> anyhow::Result<Self> {
44        let path: PathBuf = filepath.into();
45        if !path.exists() {
46            anyhow::bail!("File not found: {}", path.display());
47        }
48        if !path.is_file() {
49            anyhow::bail!("Not a file: {}", path.display());
50        }
51        let hash = b3sum(&path)?;
52
53        Ok(CheckedFile {
54            path: Either::Left(path),
55            checksum: Checksum::blake3(hash),
56        })
57    }
58
59    /// Replace the local disk path with a remote URL.
60    /// Just updates the field, doesn't move any files.
61    pub fn move_to_url(&mut self, u: url::Url) {
62        self.path = Either::Right(u);
63    }
64
65    /// Replace a remove URL with local disk path.
66    /// Just updates the field, doesn't move any files.
67    pub fn move_to_disk<P: Into<PathBuf>>(&mut self, p: P) {
68        self.path = Either::Left(p.into());
69    }
70
71    pub fn path(&self) -> Option<&Path> {
72        match self.path.as_ref() {
73            Either::Left(p) => Some(p),
74            Either::Right(_) => None,
75        }
76    }
77
78    pub fn url(&self) -> Option<&Url> {
79        match self.path.as_ref() {
80            Either::Left(_) => None,
81            Either::Right(u) => Some(u),
82        }
83    }
84
85    pub fn is_nats_url(&self) -> bool {
86        matches!(self.path.as_ref(), Either::Right(u) if u.scheme() == "nats")
87    }
88
89    pub fn checksum(&self) -> &Checksum {
90        &self.checksum
91    }
92
93    /// Does the given file checksum to the same value as this CheckedFile?
94    pub fn checksum_matches<P: AsRef<Path> + std::fmt::Debug>(&self, disk_file: P) -> bool {
95        match b3sum(&disk_file) {
96            Ok(h) => Checksum::blake3(h) == self.checksum,
97            Err(error) => {
98                tracing::error!(disk_file = %disk_file.as_ref().display(), checked_file = self.to_string(), %error, "Checksum does not match");
99                false
100            }
101        }
102    }
103}
104
105impl Display for CheckedFile {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        let p = match &self.path {
108            Either::Left(local) => local.display().to_string(),
109            Either::Right(url) => url.to_string(),
110        };
111        write!(f, "({p}, {})", self.checksum)
112    }
113}
114
115impl Serialize for CheckedFile {
116    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
117    where
118        S: Serializer,
119    {
120        let mut cf = serializer.serialize_struct("CheckedFile", 2)?;
121        match &self.path {
122            Either::Left(path) => cf.serialize_field("path", &path)?,
123            Either::Right(url) => cf.serialize_field("path", &url)?,
124        };
125        cf.serialize_field("checksum", &self.checksum)?;
126        cf.end()
127    }
128}
129
130/// Internal type to simplify deserializing
131#[derive(Deserialize)]
132struct WireCheckedFile {
133    path: String,
134    checksum: Checksum,
135}
136
137// Convert from the temporary struct to CheckedFile with path type logic.
138impl From<WireCheckedFile> for CheckedFile {
139    fn from(temp: WireCheckedFile) -> Self {
140        // Try to parse as a URL; if successful, use Either::Right(Url), else use Either::Left(PathBuf).
141        match Url::parse(&temp.path) {
142            Ok(url) => CheckedFile {
143                path: Either::Right(url),
144                checksum: temp.checksum,
145            },
146            Err(_) => CheckedFile {
147                path: Either::Left(PathBuf::from(temp.path)),
148                checksum: temp.checksum,
149            },
150        }
151    }
152}
153
154// Implement Deserialize for CheckedFile using the temporary struct.
155impl<'de> Deserialize<'de> for CheckedFile {
156    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
157    where
158        D: Deserializer<'de>,
159    {
160        // Deserialize into WireCheckedFile, then convert to CheckedFile.
161        let temp = WireCheckedFile::deserialize(deserializer)?;
162        Ok(CheckedFile::from(temp))
163    }
164}
165
166fn b3sum<T: AsRef<Path> + std::fmt::Debug>(path: T) -> anyhow::Result<String> {
167    let path = path.as_ref();
168    let metadata = std::fs::metadata(path)?;
169    let filesize = metadata.len();
170    let mut hasher = blake3::Hasher::new();
171
172    if filesize > 128_000 {
173        // multithreaded. blake3 recommend this above 128 KiB.
174        hasher.update_mmap_rayon(path)?;
175    } else {
176        // Uses mmap above 16 KiB, normal load otherwise.
177        hasher.update_mmap(path)?;
178    }
179
180    let hash = hasher.finalize();
181    Ok(hash.to_string())
182}
183
184impl Checksum {
185    pub fn blake3(hash: impl Into<String>) -> Self {
186        Self::new(hash, CryptographicHashMethods::BLAKE3)
187    }
188
189    pub fn new(hash: impl Into<String>, algorithm: CryptographicHashMethods) -> Self {
190        Self {
191            hash: hash.into(),
192            algorithm,
193        }
194    }
195}
196
197impl Serialize for Checksum {
198    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
199    where
200        S: Serializer,
201    {
202        let serialized_str = format!("{}:{}", self.algorithm, self.hash);
203        serializer.serialize_str(&serialized_str)
204    }
205}
206
207impl<'de> Deserialize<'de> for Checksum {
208    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
209    where
210        D: Deserializer<'de>,
211    {
212        struct ChecksumVisitor;
213
214        impl Visitor<'_> for ChecksumVisitor {
215            type Value = Checksum;
216
217            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
218                formatter.write_str("a string in the format `{algo}:{hash}`")
219            }
220
221            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
222            where
223                E: de::Error,
224            {
225                let parts: Vec<&str> = value.split(':').collect();
226                if parts.len() != 2 {
227                    return Err(de::Error::invalid_value(de::Unexpected::Str(value), &self));
228                }
229
230                let algorithm = parts[0].parse().map_err(|_| {
231                    de::Error::invalid_value(de::Unexpected::Str(parts[0]), &"invalid algorithm")
232                })?;
233
234                Ok(Checksum::new(parts[1], algorithm))
235            }
236        }
237
238        deserializer.deserialize_str(ChecksumVisitor)
239    }
240}
241
242impl TryFrom<&str> for Checksum {
243    type Error = anyhow::Error;
244
245    fn try_from(value: &str) -> Result<Self, Self::Error> {
246        let parts: Vec<&str> = value.split(':').collect();
247        if parts.len() != 2 {
248            anyhow::bail!("Invalid checksum format; expect `algo:hash`; got: {value}");
249        }
250
251        let algo = match parts[0] {
252            "blake3" => CryptographicHashMethods::BLAKE3,
253            _ => {
254                anyhow::bail!("Unsupported cryptographic hash method: {}", parts[0]);
255            }
256        };
257
258        Ok(Checksum::new(parts[1], algo))
259    }
260}
261
262impl FromStr for CryptographicHashMethods {
263    type Err = String;
264
265    fn from_str(s: &str) -> Result<Self, Self::Err> {
266        match s {
267            "blake3" => Ok(CryptographicHashMethods::BLAKE3),
268            _ => Err(format!("Unsupported algorithm: {}", s)),
269        }
270    }
271}
272
273impl Display for CryptographicHashMethods {
274    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275        match self {
276            CryptographicHashMethods::BLAKE3 => write!(f, "blake3"),
277        }
278    }
279}
280
281impl Display for Checksum {
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        write!(f, "{}:{}", self.algorithm, self.hash)
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialization_blake3() {
        let checksum = Checksum::blake3("a12c3d4");

        let serialized = serde_json::to_string(&checksum).unwrap();
        assert_eq!(serialized.trim(), "\"blake3:a12c3d4\"");
    }

    #[test]
    fn test_deserialization_blake3() {
        let s = "\"blake3:abcd1234\"";
        let deserialized: Checksum = serde_json::from_str(s).unwrap();

        assert_eq!(deserialized.algorithm, CryptographicHashMethods::BLAKE3);
        assert_eq!(deserialized.hash, "abcd1234");
    }

    #[test]
    fn test_deserialization_invalid_format() {
        let s = "\"invalidformat\"";
        let result: Result<Checksum, _> = serde_json::from_str(s);

        assert!(result.is_err());

        let s = "\"blake3:invalid:format\"";
        let result: Result<Checksum, _> = serde_json::from_str(s);

        assert!(result.is_err());

        let s = "\"md5:abcd\"";
        let result: Result<Checksum, _> = serde_json::from_str(s);

        assert!(result.is_err());
    }

    #[test]
    fn test_checksum_try_from_str() {
        let checksum = Checksum::try_from("blake3:abcd").unwrap();
        assert_eq!(checksum, Checksum::blake3("abcd"));

        assert!(Checksum::try_from("abcd").is_err());
        assert!(Checksum::try_from("blake3:a:b").is_err());
        assert!(Checksum::try_from("sha256:abcd").is_err());
    }

    #[test]
    fn test_checked_file_serde_round_trip_url() {
        let json = r#"{"path":"nats://host/bucket/config.json","checksum":"blake3:abcd"}"#;
        let cf: CheckedFile = serde_json::from_str(json).unwrap();

        assert!(cf.is_nats_url());
        assert!(cf.path().is_none());
        assert_eq!(cf.url().map(Url::as_str), Some("nats://host/bucket/config.json"));
        assert_eq!(
            cf.to_string(),
            "(nats://host/bucket/config.json, blake3:abcd)"
        );
        assert_eq!(serde_json::to_string(&cf).unwrap(), json);
    }

    #[test]
    fn test_checked_file_serde_round_trip_path() {
        let json = r#"{"path":"/tmp/model/config.json","checksum":"blake3:abcd"}"#;
        let cf: CheckedFile = serde_json::from_str(json).unwrap();

        assert!(cf.url().is_none());
        assert!(!cf.is_nats_url());
        assert_eq!(
            cf.path(),
            Some(std::path::Path::new("/tmp/model/config.json"))
        );
        assert_eq!(serde_json::to_string(&cf).unwrap(), json);

        // Relative paths are not valid URLs, so they stay local paths.
        let json = r#"{"path":"models/config.json","checksum":"blake3:abcd"}"#;
        let cf: CheckedFile = serde_json::from_str(json).unwrap();
        assert_eq!(cf.path(), Some(std::path::Path::new("models/config.json")));
    }

    #[test]
    fn test_checked_file_from_disk() {
        let root = env!("CARGO_MANIFEST_DIR"); // ${WORKSPACE}/lib/llm
        let full_path = format!("{root}/tests/data/sample-models/TinyLlama_v1.1/config.json");
        let cf = CheckedFile::from_disk(full_path).unwrap();
        let expected =
            Checksum::blake3("62bc124be974d3a25db05bedc99422660c26715e5bbda0b37d14bd84a0c65ab2");
        assert_eq!(expected, *cf.checksum());
    }
}