1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelMetadata {
12 pub model_id: String,
14 pub version: String,
16 pub accuracy: f64,
18 pub created_at: u64,
20 pub config_hash: String,
22 pub tags: HashMap<String, String>,
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28pub enum ChangeType {
29 AddData,
31 Hyperparams,
33 Architecture,
35 Retrain,
37 FineTune,
39 Merge,
41}
42
43impl ChangeType {
44 pub fn as_str(&self) -> &'static str {
45 match self {
46 ChangeType::AddData => "add_data",
47 ChangeType::Hyperparams => "hyperparams",
48 ChangeType::Architecture => "architecture",
49 ChangeType::Retrain => "retrain",
50 ChangeType::FineTune => "fine_tune",
51 ChangeType::Merge => "merge",
52 }
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct Derivation {
59 pub parent_id: String,
61 pub child_id: String,
63 pub change_type: ChangeType,
65 pub description: String,
67}
68
69#[derive(Debug, Default, Serialize, Deserialize)]
71pub struct ModelLineage {
72 models: HashMap<String, ModelMetadata>,
74 derivations: Vec<Derivation>,
76}
77
78impl ModelLineage {
79 pub fn new() -> Self {
81 Self::default()
82 }
83
84 pub fn add_model(&mut self, metadata: ModelMetadata) -> String {
86 let id = metadata.model_id.clone();
87 self.models.insert(id.clone(), metadata);
88 id
89 }
90
91 pub fn add_derivation(
93 &mut self,
94 parent_id: &str,
95 child_id: &str,
96 change_type: ChangeType,
97 description: &str,
98 ) {
99 self.derivations.push(Derivation {
100 parent_id: parent_id.to_string(),
101 child_id: child_id.to_string(),
102 change_type,
103 description: description.to_string(),
104 });
105 }
106
107 pub fn get_model(&self, id: &str) -> Option<&ModelMetadata> {
109 self.models.get(id)
110 }
111
112 pub fn all_models(&self) -> impl Iterator<Item = &ModelMetadata> {
114 self.models.values()
115 }
116
117 pub fn get_parent(&self, child_id: &str) -> Option<&ModelMetadata> {
119 self.derivations
120 .iter()
121 .find(|d| d.child_id == child_id)
122 .and_then(|d| self.models.get(&d.parent_id))
123 }
124
125 pub fn get_children(&self, parent_id: &str) -> Vec<&ModelMetadata> {
127 self.derivations
128 .iter()
129 .filter(|d| d.parent_id == parent_id)
130 .filter_map(|d| self.models.get(&d.child_id))
131 .collect()
132 }
133
134 pub fn compare(&self, a_id: &str, b_id: &str) -> Option<ModelComparison> {
136 let a = self.models.get(a_id)?;
137 let b = self.models.get(b_id)?;
138
139 Some(ModelComparison {
140 model_a: a_id.to_string(),
141 model_b: b_id.to_string(),
142 accuracy_delta: b.accuracy - a.accuracy,
143 is_improvement: b.accuracy > a.accuracy,
144 })
145 }
146
147 pub fn find_regression_source(&self, model_id: &str) -> Option<&Derivation> {
149 let model = self.models.get(model_id)?;
150
151 let derivation = self.derivations.iter().find(|d| d.child_id == model_id)?;
153 let parent = self.models.get(&derivation.parent_id)?;
154
155 if model.accuracy < parent.accuracy {
157 Some(derivation)
158 } else {
159 None
160 }
161 }
162
163 pub fn get_lineage_chain(&self, model_id: &str) -> Vec<String> {
165 let mut chain = vec![model_id.to_string()];
166 let mut current = model_id;
167
168 while let Some(derivation) = self.derivations.iter().find(|d| d.child_id == current) {
169 chain.push(derivation.parent_id.clone());
170 current = &derivation.parent_id;
171 }
172
173 chain.reverse();
174 chain
175 }
176
177 pub fn to_json(&self) -> Result<String, serde_json::Error> {
179 serde_json::to_string_pretty(self)
180 }
181
182 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
184 serde_json::from_str(json)
185 }
186}
187
/// Result of comparing model `model_b` against baseline `model_a`.
///
/// `PartialEq` is derived so comparison results can be checked directly.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelComparison {
    /// Id of the baseline model.
    pub model_a: String,
    /// Id of the model compared against the baseline.
    pub model_b: String,
    /// `model_b` accuracy minus `model_a` accuracy; positive means `model_b` scored higher.
    pub accuracy_delta: f64,
    /// `true` only when `model_b` is strictly more accurate than `model_a`.
    pub is_improvement: bool,
}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `ModelMetadata` with a zero timestamp, empty config
    /// hash and no tags.
    fn make_model(id: &str, version: &str, accuracy: f64) -> ModelMetadata {
        ModelMetadata {
            model_id: id.to_string(),
            version: version.to_string(),
            accuracy,
            created_at: 0,
            config_hash: String::new(),
            tags: HashMap::new(),
        }
    }

    #[test]
    fn test_lineage_new() {
        let lineage = ModelLineage::new();
        assert_eq!(lineage.models.len(), 0);
    }

    #[test]
    fn test_add_model() {
        let mut lineage = ModelLineage::new();
        let id = lineage.add_model(make_model("v1", "1.0.0", 0.85));
        assert_eq!(id, "v1");
        assert!(lineage.get_model("v1").is_some());
    }

    #[test]
    fn test_add_model_replaces_existing() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));
        lineage.add_model(make_model("v1", "1.0.1", 0.90));

        assert_eq!(lineage.all_models().count(), 1);
        let model = lineage.get_model("v1").expect("v1 should be registered");
        assert_eq!(model.version, "1.0.1");
    }

    #[test]
    fn test_add_derivation() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));
        lineage.add_model(make_model("v2", "2.0.0", 0.87));
        lineage.add_derivation("v1", "v2", ChangeType::AddData, "Added 1000 samples");

        assert_eq!(lineage.derivations.len(), 1);
    }

    #[test]
    fn test_get_parent() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));
        lineage.add_model(make_model("v2", "2.0.0", 0.87));
        lineage.add_derivation("v1", "v2", ChangeType::AddData, "More data");

        let parent = lineage.get_parent("v2").expect("operation should succeed");
        assert_eq!(parent.model_id, "v1");
    }

    #[test]
    fn test_get_children() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));
        lineage.add_model(make_model("v2a", "2.0.0", 0.87));
        lineage.add_model(make_model("v2b", "2.1.0", 0.86));
        lineage.add_derivation("v1", "v2a", ChangeType::AddData, "Branch A");
        lineage.add_derivation("v1", "v2b", ChangeType::Hyperparams, "Branch B");

        let children = lineage.get_children("v1");
        assert_eq!(children.len(), 2);
    }

    #[test]
    fn test_get_children_skips_unregistered() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));
        lineage.add_derivation("v1", "ghost", ChangeType::Retrain, "");

        assert!(lineage.get_children("v1").is_empty());
    }

    #[test]
    fn test_compare_improvement() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));
        lineage.add_model(make_model("v2", "2.0.0", 0.87));

        let cmp = lineage.compare("v1", "v2").expect("operation should succeed");
        assert!(cmp.is_improvement);
        assert!((cmp.accuracy_delta - 0.02).abs() < 1e-6);
    }

    #[test]
    fn test_compare_regression() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.87));
        lineage.add_model(make_model("v2", "2.0.0", 0.82));

        let cmp = lineage.compare("v1", "v2").expect("operation should succeed");
        assert!(!cmp.is_improvement);
    }

    #[test]
    fn test_find_regression_source() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.87));
        lineage.add_model(make_model("v2", "2.0.0", 0.82));
        lineage.add_derivation("v1", "v2", ChangeType::Hyperparams, "Changed LR");

        let source = lineage.find_regression_source("v2").expect("operation should succeed");
        assert_eq!(source.change_type, ChangeType::Hyperparams);
    }

    #[test]
    fn test_lineage_chain() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));
        lineage.add_model(make_model("v2", "2.0.0", 0.85));
        lineage.add_model(make_model("v3", "3.0.0", 0.87));
        lineage.add_derivation("v1", "v2", ChangeType::AddData, "");
        lineage.add_derivation("v2", "v3", ChangeType::FineTune, "");

        let chain = lineage.get_lineage_chain("v3");
        assert_eq!(chain, vec!["v1", "v2", "v3"]);
    }

    #[test]
    fn test_lineage_chain_includes_unregistered_ancestors() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));
        lineage.add_derivation("v0", "v1", ChangeType::Architecture, "");

        assert_eq!(lineage.get_lineage_chain("v1"), vec!["v0", "v1"]);
        // The ancestor id is known, but it is not a registered model.
        assert!(lineage.get_parent("v1").is_none());
    }

    #[test]
    fn test_json_roundtrip() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));

        let json = lineage.to_json().expect("operation should succeed");
        let loaded = ModelLineage::from_json(&json).expect("load should succeed");
        assert!(loaded.get_model("v1").is_some());
    }

    #[test]
    fn test_json_roundtrip_preserves_derivations() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.87));
        lineage.add_model(make_model("v2", "2.0.0", 0.82));
        lineage.add_derivation("v1", "v2", ChangeType::Hyperparams, "Changed LR");

        let json = lineage.to_json().expect("operation should succeed");
        let loaded = ModelLineage::from_json(&json).expect("load should succeed");

        assert_eq!(loaded.derivations.len(), 1);
        let source = loaded
            .find_regression_source("v2")
            .expect("regression should survive the round trip");
        assert_eq!(source.change_type, ChangeType::Hyperparams);
        assert_eq!(source.description, "Changed LR");
    }

    #[test]
    fn test_change_type_as_str() {
        assert_eq!(ChangeType::AddData.as_str(), "add_data");
        assert_eq!(ChangeType::Hyperparams.as_str(), "hyperparams");
        assert_eq!(ChangeType::Architecture.as_str(), "architecture");
        assert_eq!(ChangeType::Retrain.as_str(), "retrain");
        assert_eq!(ChangeType::FineTune.as_str(), "fine_tune");
        assert_eq!(ChangeType::Merge.as_str(), "merge");
    }

    #[test]
    fn test_all_models() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));
        lineage.add_model(make_model("v2", "2.0.0", 0.85));
        lineage.add_model(make_model("v3", "3.0.0", 0.90));

        let models: Vec<_> = lineage.all_models().collect();
        assert_eq!(models.len(), 3);
    }

    #[test]
    fn test_get_parent_no_parent() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));

        assert!(lineage.get_parent("v1").is_none());
    }

    #[test]
    fn test_find_regression_source_no_regression() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));
        lineage.add_model(make_model("v2", "2.0.0", 0.85));
        lineage.add_derivation("v1", "v2", ChangeType::AddData, "More data");

        assert!(lineage.find_regression_source("v2").is_none());
    }

    #[test]
    fn test_find_regression_source_nonexistent() {
        let lineage = ModelLineage::new();
        assert!(lineage.find_regression_source("v99").is_none());
    }

    #[test]
    fn test_compare_nonexistent_models() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));

        assert!(lineage.compare("v1", "v99").is_none());
        assert!(lineage.compare("v99", "v1").is_none());
    }

    #[test]
    fn test_get_children_no_children() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));

        let children = lineage.get_children("v1");
        assert!(children.is_empty());
    }

    #[test]
    fn test_get_model_nonexistent() {
        let lineage = ModelLineage::new();
        assert!(lineage.get_model("v99").is_none());
    }

    #[test]
    fn test_lineage_chain_single() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));

        let chain = lineage.get_lineage_chain("v1");
        assert_eq!(chain, vec!["v1"]);
    }

    #[test]
    fn test_model_metadata_with_tags() {
        let mut tags = HashMap::new();
        tags.insert("env".to_string(), "production".to_string());
        tags.insert("owner".to_string(), "team-ml".to_string());

        let model = ModelMetadata {
            model_id: "v1".to_string(),
            version: "1.0.0".to_string(),
            accuracy: 0.95,
            created_at: 1700000000,
            config_hash: "abc123".to_string(),
            tags,
        };

        assert_eq!(model.tags.len(), 2);
        assert_eq!(model.created_at, 1700000000);
    }

    #[test]
    fn test_derivation_clone() {
        let d = Derivation {
            parent_id: "v1".to_string(),
            child_id: "v2".to_string(),
            change_type: ChangeType::Merge,
            description: "merged models".to_string(),
        };
        let cloned = d.clone();
        assert_eq!(d.parent_id, cloned.parent_id);
        assert_eq!(d.change_type, cloned.change_type);
    }

    #[test]
    fn test_model_comparison_clone() {
        let cmp = ModelComparison {
            model_a: "v1".to_string(),
            model_b: "v2".to_string(),
            accuracy_delta: 0.05,
            is_improvement: true,
        };
        let cloned = cmp.clone();
        assert_eq!(cmp.accuracy_delta, cloned.accuracy_delta);
    }

    #[test]
    fn test_model_lineage_default() {
        let lineage = ModelLineage::default();
        assert!(lineage.models.is_empty());
        assert!(lineage.derivations.is_empty());
    }
}
448}