1use std::{
5 fmt::Display,
6 path::{Path, PathBuf},
7 str::FromStr,
8};
9
10use either::Either;
11use serde::{
12 Deserialize, Deserializer, Serialize, Serializer,
13 de::{self, Visitor},
14 ser::SerializeStruct as _,
15};
16use url::Url;
17
18#[derive(Clone, Debug)]
19pub struct CheckedFile {
20 path: Either<PathBuf, Url>,
22
23 checksum: Checksum,
25}
26
27#[derive(Debug, Clone, Eq, PartialEq)]
28pub struct Checksum {
29 hash: String,
31
32 algorithm: CryptographicHashMethods,
34}
35
36#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
37pub enum CryptographicHashMethods {
38 #[serde(rename = "blake3")]
39 BLAKE3,
40}
41
42impl CheckedFile {
43 pub fn from_disk<P: Into<PathBuf>>(filepath: P) -> anyhow::Result<Self> {
44 let path: PathBuf = filepath.into();
45 if !path.exists() {
46 anyhow::bail!("File not found: {}", path.display());
47 }
48 if !path.is_file() {
49 anyhow::bail!("Not a file: {}", path.display());
50 }
51 let hash = b3sum(&path)?;
52
53 Ok(CheckedFile {
54 path: Either::Left(path),
55 checksum: Checksum::blake3(hash),
56 })
57 }
58
59 pub fn move_to_url(&mut self, u: url::Url) {
62 self.path = Either::Right(u);
63 }
64
65 pub fn move_to_disk<P: Into<PathBuf>>(&mut self, p: P) {
68 self.path = Either::Left(p.into());
69 }
70
71 pub fn path(&self) -> Option<&Path> {
72 match self.path.as_ref() {
73 Either::Left(p) => Some(p),
74 Either::Right(_) => None,
75 }
76 }
77
78 pub fn url(&self) -> Option<&Url> {
79 match self.path.as_ref() {
80 Either::Left(_) => None,
81 Either::Right(u) => Some(u),
82 }
83 }
84
85 pub fn is_nats_url(&self) -> bool {
86 matches!(self.path.as_ref(), Either::Right(u) if u.scheme() == "nats")
87 }
88
89 pub fn checksum(&self) -> &Checksum {
90 &self.checksum
91 }
92
93 pub fn checksum_matches<P: AsRef<Path> + std::fmt::Debug>(&self, disk_file: P) -> bool {
95 match b3sum(&disk_file) {
96 Ok(h) => Checksum::blake3(h) == self.checksum,
97 Err(error) => {
98 tracing::error!(disk_file = %disk_file.as_ref().display(), checked_file = self.to_string(), %error, "Checksum does not match");
99 false
100 }
101 }
102 }
103}
104
105impl Display for CheckedFile {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 let p = match &self.path {
108 Either::Left(local) => local.display().to_string(),
109 Either::Right(url) => url.to_string(),
110 };
111 write!(f, "({p}, {})", self.checksum)
112 }
113}
114
115impl Serialize for CheckedFile {
116 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
117 where
118 S: Serializer,
119 {
120 let mut cf = serializer.serialize_struct("CheckedFile", 2)?;
121 match &self.path {
122 Either::Left(path) => cf.serialize_field("path", &path)?,
123 Either::Right(url) => cf.serialize_field("path", &url)?,
124 };
125 cf.serialize_field("checksum", &self.checksum)?;
126 cf.end()
127 }
128}
129
130#[derive(Deserialize)]
132struct WireCheckedFile {
133 path: String,
134 checksum: Checksum,
135}
136
137impl From<WireCheckedFile> for CheckedFile {
139 fn from(temp: WireCheckedFile) -> Self {
140 match Url::parse(&temp.path) {
142 Ok(url) => CheckedFile {
143 path: Either::Right(url),
144 checksum: temp.checksum,
145 },
146 Err(_) => CheckedFile {
147 path: Either::Left(PathBuf::from(temp.path)),
148 checksum: temp.checksum,
149 },
150 }
151 }
152}
153
154impl<'de> Deserialize<'de> for CheckedFile {
156 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
157 where
158 D: Deserializer<'de>,
159 {
160 let temp = WireCheckedFile::deserialize(deserializer)?;
162 Ok(CheckedFile::from(temp))
163 }
164}
165
166fn b3sum<T: AsRef<Path> + std::fmt::Debug>(path: T) -> anyhow::Result<String> {
167 let path = path.as_ref();
168 let metadata = std::fs::metadata(path)?;
169 let filesize = metadata.len();
170 let mut hasher = blake3::Hasher::new();
171
172 if filesize > 128_000 {
173 hasher.update_mmap_rayon(path)?;
175 } else {
176 hasher.update_mmap(path)?;
178 }
179
180 let hash = hasher.finalize();
181 Ok(hash.to_string())
182}
183
184impl Checksum {
185 pub fn blake3(hash: impl Into<String>) -> Self {
186 Self::new(hash, CryptographicHashMethods::BLAKE3)
187 }
188
189 pub fn new(hash: impl Into<String>, algorithm: CryptographicHashMethods) -> Self {
190 Self {
191 hash: hash.into(),
192 algorithm,
193 }
194 }
195}
196
197impl Serialize for Checksum {
198 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
199 where
200 S: Serializer,
201 {
202 let serialized_str = format!("{}:{}", self.algorithm, self.hash);
203 serializer.serialize_str(&serialized_str)
204 }
205}
206
207impl<'de> Deserialize<'de> for Checksum {
208 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
209 where
210 D: Deserializer<'de>,
211 {
212 struct ChecksumVisitor;
213
214 impl Visitor<'_> for ChecksumVisitor {
215 type Value = Checksum;
216
217 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
218 formatter.write_str("a string in the format `{algo}:{hash}`")
219 }
220
221 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
222 where
223 E: de::Error,
224 {
225 let parts: Vec<&str> = value.split(':').collect();
226 if parts.len() != 2 {
227 return Err(de::Error::invalid_value(de::Unexpected::Str(value), &self));
228 }
229
230 let algorithm = parts[0].parse().map_err(|_| {
231 de::Error::invalid_value(de::Unexpected::Str(parts[0]), &"invalid algorithm")
232 })?;
233
234 Ok(Checksum::new(parts[1], algorithm))
235 }
236 }
237
238 deserializer.deserialize_str(ChecksumVisitor)
239 }
240}
241
242impl TryFrom<&str> for Checksum {
243 type Error = anyhow::Error;
244
245 fn try_from(value: &str) -> Result<Self, Self::Error> {
246 let parts: Vec<&str> = value.split(':').collect();
247 if parts.len() != 2 {
248 anyhow::bail!("Invalid checksum format; expect `algo:hash`; got: {value}");
249 }
250
251 let algo = match parts[0] {
252 "blake3" => CryptographicHashMethods::BLAKE3,
253 _ => {
254 anyhow::bail!("Unsupported cryptographic hash method: {}", parts[0]);
255 }
256 };
257
258 Ok(Checksum::new(parts[1], algo))
259 }
260}
261
262impl FromStr for CryptographicHashMethods {
263 type Err = String;
264
265 fn from_str(s: &str) -> Result<Self, Self::Err> {
266 match s {
267 "blake3" => Ok(CryptographicHashMethods::BLAKE3),
268 _ => Err(format!("Unsupported algorithm: {}", s)),
269 }
270 }
271}
272
273impl Display for CryptographicHashMethods {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 match self {
276 CryptographicHashMethods::BLAKE3 => write!(f, "blake3"),
277 }
278 }
279}
280
281impl Display for Checksum {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 write!(f, "{}:{}", self.algorithm, self.hash)
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// A BLAKE3 checksum serializes to the JSON string `"blake3:<hash>"`.
    #[test]
    fn test_serialization_blake3() {
        let json = serde_json::to_string(&Checksum::blake3("a12c3d4")).unwrap();

        assert_eq!(json.trim(), "\"blake3:a12c3d4\"");
    }

    /// A well-formed JSON string deserializes into algorithm and hash.
    #[test]
    fn test_deserialization_blake3() {
        let parsed: Checksum = serde_json::from_str("\"blake3:abcd1234\"").unwrap();

        assert_eq!(parsed.algorithm, CryptographicHashMethods::BLAKE3);
        assert_eq!(parsed.hash, "abcd1234");
    }

    /// Inputs with no separator or with more than one separator are rejected.
    #[test]
    fn test_deserialization_invalid_format() {
        let parse = |raw: &str| serde_json::from_str::<Checksum>(raw);

        assert!(parse("\"invalidformat\"").is_err());
        assert!(parse("\"blake3:invalid:format\"").is_err());
    }

    /// Hashing the sample config file yields the checksum pinned below.
    #[test]
    fn test_checked_file_from_disk() {
        let root = env!("CARGO_MANIFEST_DIR");
        let full_path = format!("{root}/tests/data/sample-models/TinyLlama_v1.1/config.json");

        let checked = CheckedFile::from_disk(full_path).unwrap();

        let expected =
            Checksum::blake3("62bc124be974d3a25db05bedc99422660c26715e5bbda0b37d14bd84a0c65ab2");
        assert_eq!(expected, *checked.checksum());
    }
}