use crate::model::ModelId;
use crate::recipe::RecipeId;
use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(tag = "type", rename_all = "snake_case")]
12pub enum ModelLineageEdge {
13 FineTuned {
15 parent: ModelId,
17 recipe: RecipeId,
19 },
20 Distilled {
22 teacher: ModelId,
24 temperature: f32,
26 },
27 Merged {
29 sources: Vec<ModelId>,
31 weights: Vec<f32>,
33 },
34 Quantized {
36 source: ModelId,
38 quantization: QuantizationType,
40 },
41 Pruned {
43 source: ModelId,
45 sparsity: f32,
47 },
48}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
52#[serde(rename_all = "lowercase")]
53pub enum QuantizationType {
54 Int8,
56 Int4,
58 Fp16,
60 Bf16,
62 Dynamic,
64}
65
66impl std::fmt::Display for QuantizationType {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 let s = match self {
69 Self::Int8 => "int8",
70 Self::Int4 => "int4",
71 Self::Fp16 => "fp16",
72 Self::Bf16 => "bf16",
73 Self::Dynamic => "dynamic",
74 };
75 write!(f, "{s}")
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct LineageNode {
82 pub model_id: ModelId,
84 pub model_name: String,
86 pub model_version: String,
88}
89
90#[derive(Debug, Clone, Default, Serialize, Deserialize)]
92pub struct LineageGraph {
93 pub nodes: Vec<LineageNode>,
95 pub edges: Vec<LineageEdgeRecord>,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct LineageEdgeRecord {
102 pub from_idx: usize,
104 pub to_idx: usize,
106 pub edge: ModelLineageEdge,
108}
109
110impl LineageGraph {
111 #[must_use]
113 pub fn new() -> Self {
114 Self::default()
115 }
116
117 pub fn add_node(&mut self, node: LineageNode) -> usize {
119 let idx = self.nodes.len();
120 self.nodes.push(node);
121 idx
122 }
123
124 pub fn add_edge(&mut self, from_idx: usize, to_idx: usize, edge: ModelLineageEdge) {
126 self.edges.push(LineageEdgeRecord { from_idx, to_idx, edge });
127 }
128
129 #[must_use]
131 pub fn node_count(&self) -> usize {
132 self.nodes.len()
133 }
134
135 #[must_use]
137 pub fn edge_count(&self) -> usize {
138 self.edges.len()
139 }
140
141 #[must_use]
143 pub fn ancestors(&self, node_idx: usize) -> Vec<usize> {
144 self.edges.iter().filter(|e| e.to_idx == node_idx).map(|e| e.from_idx).collect()
145 }
146
147 #[must_use]
149 pub fn descendants(&self, node_idx: usize) -> Vec<usize> {
150 self.edges.iter().filter(|e| e.from_idx == node_idx).map(|e| e.to_idx).collect()
151 }
152
153 #[must_use]
155 pub fn find_node(&self, model_id: &ModelId) -> Option<usize> {
156 self.nodes.iter().position(|n| &n.model_id == model_id)
157 }
158
159 #[must_use]
163 pub fn all_ancestors(&self, node_idx: usize) -> Vec<usize> {
164 let mut visited = std::collections::HashSet::new();
165 let mut result = Vec::new();
166 self.collect_ancestors(node_idx, &mut visited, &mut result);
167 result
168 }
169
170 fn collect_ancestors(
171 &self,
172 node_idx: usize,
173 visited: &mut std::collections::HashSet<usize>,
174 result: &mut Vec<usize>,
175 ) {
176 for parent_idx in self.ancestors(node_idx) {
177 if visited.insert(parent_idx) {
178 result.push(parent_idx);
179 self.collect_ancestors(parent_idx, visited, result);
180 }
181 }
182 }
183
184 #[must_use]
188 pub fn all_descendants(&self, node_idx: usize) -> Vec<usize> {
189 let mut visited = std::collections::HashSet::new();
190 let mut result = Vec::new();
191 self.collect_descendants(node_idx, &mut visited, &mut result);
192 result
193 }
194
195 fn collect_descendants(
196 &self,
197 node_idx: usize,
198 visited: &mut std::collections::HashSet<usize>,
199 result: &mut Vec<usize>,
200 ) {
201 for child_idx in self.descendants(node_idx) {
202 if visited.insert(child_idx) {
203 result.push(child_idx);
204 self.collect_descendants(child_idx, visited, result);
205 }
206 }
207 }
208
209 #[must_use]
211 pub fn root_nodes(&self) -> Vec<usize> {
212 (0..self.nodes.len()).filter(|&idx| self.ancestors(idx).is_empty()).collect()
213 }
214
215 #[must_use]
217 pub fn leaf_nodes(&self) -> Vec<usize> {
218 (0..self.nodes.len()).filter(|&idx| self.descendants(idx).is_empty()).collect()
219 }
220
221 #[must_use]
225 pub fn path_between(&self, from_idx: usize, to_idx: usize) -> Option<Vec<usize>> {
226 use std::collections::{HashMap, VecDeque};
227
228 if from_idx == to_idx {
229 return Some(vec![from_idx]);
230 }
231
232 let mut queue = VecDeque::new();
233 let mut parent_map: HashMap<usize, usize> = HashMap::new();
234
235 queue.push_back(from_idx);
236
237 while let Some(current) = queue.pop_front() {
238 for child_idx in self.descendants(current) {
239 if !parent_map.contains_key(&child_idx) {
240 parent_map.insert(child_idx, current);
241 if child_idx == to_idx {
242 let mut path = vec![to_idx];
244 let mut node = to_idx;
245 while let Some(&parent) = parent_map.get(&node) {
246 path.push(parent);
247 node = parent;
248 }
249 path.reverse();
250 return Some(path);
251 }
252 queue.push_back(child_idx);
253 }
254 }
255 }
256
257 None
258 }
259
260 #[must_use]
265 pub fn topological_sort(&self) -> Option<Vec<usize>> {
266 use std::collections::HashMap;
267
268 let n = self.nodes.len();
269 if n == 0 {
270 return Some(Vec::new());
271 }
272
273 let mut in_degree: HashMap<usize, usize> = (0..n).map(|i| (i, 0)).collect();
275 for edge in &self.edges {
276 *in_degree.entry(edge.to_idx).or_insert(0) += 1;
277 }
278
279 let mut queue: Vec<usize> = in_degree
281 .iter()
282 .filter_map(|(&node, °ree)| if degree == 0 { Some(node) } else { None })
283 .collect();
284
285 let mut result = Vec::with_capacity(n);
286
287 while let Some(node) = queue.pop() {
288 result.push(node);
289
290 for child in self.descendants(node) {
291 if let Some(degree) = in_degree.get_mut(&child) {
292 *degree -= 1;
293 if *degree == 0 {
294 queue.push(child);
295 }
296 }
297 }
298 }
299
300 if result.len() == n {
302 Some(result)
303 } else {
304 None
305 }
306 }
307
308 #[must_use]
310 pub fn depth(&self, node_idx: usize) -> usize {
311 let ancestors = self.ancestors(node_idx);
312 if ancestors.is_empty() {
313 0
314 } else {
315 ancestors.iter().map(|&a| self.depth(a) + 1).max().unwrap_or(0)
316 }
317 }
318
319 #[must_use]
321 pub fn edges_between(&self, from_idx: usize, to_idx: usize) -> Vec<&LineageEdgeRecord> {
322 self.edges.iter().filter(|e| e.from_idx == from_idx && e.to_idx == to_idx).collect()
323 }
324
325 #[must_use]
327 pub fn is_dag(&self) -> bool {
328 self.topological_sort().is_some()
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_type_display() {
        assert_eq!(QuantizationType::Int8.to_string(), "int8");
        assert_eq!(QuantizationType::Fp16.to_string(), "fp16");
    }

    #[test]
    fn test_lineage_graph_basic() {
        let mut graph = LineageGraph::new();

        let base_id = ModelId::new();
        let finetuned_id = ModelId::new();

        let base_idx = graph.add_node(LineageNode {
            model_id: base_id.clone(),
            model_name: "base-model".to_string(),
            model_version: "1.0.0".to_string(),
        });

        let finetuned_idx = graph.add_node(LineageNode {
            model_id: finetuned_id,
            model_name: "finetuned-model".to_string(),
            model_version: "1.0.0".to_string(),
        });

        graph.add_edge(
            base_idx,
            finetuned_idx,
            ModelLineageEdge::FineTuned { parent: base_id, recipe: RecipeId::new() },
        );

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
        assert_eq!(graph.ancestors(finetuned_idx), vec![base_idx]);
        assert_eq!(graph.descendants(base_idx), vec![finetuned_idx]);
    }

    #[test]
    fn test_lineage_graph_find_node() {
        let mut graph = LineageGraph::new();
        let model_id = ModelId::new();

        graph.add_node(LineageNode {
            model_id: model_id.clone(),
            model_name: "test-model".to_string(),
            model_version: "1.0.0".to_string(),
        });

        assert_eq!(graph.find_node(&model_id), Some(0));
        assert_eq!(graph.find_node(&ModelId::new()), None);
    }

    #[test]
    fn test_lineage_edge_serialization() {
        let edge = ModelLineageEdge::Quantized {
            source: ModelId::new(),
            quantization: QuantizationType::Int8,
        };

        let json = serde_json::to_string(&edge).unwrap();
        assert!(json.contains("quantized"));
        assert!(json.contains("int8"));

        let deserialized: ModelLineageEdge = serde_json::from_str(&json).unwrap();
        if let ModelLineageEdge::Quantized { quantization, .. } = deserialized {
            assert_eq!(quantization, QuantizationType::Int8);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_merged_lineage() {
        let sources = vec![ModelId::new(), ModelId::new(), ModelId::new()];
        let weights = vec![0.5, 0.3, 0.2];

        let edge = ModelLineageEdge::Merged { sources, weights };

        let json = serde_json::to_string(&edge).unwrap();
        let deserialized: ModelLineageEdge = serde_json::from_str(&json).unwrap();

        if let ModelLineageEdge::Merged { sources: s, weights: w } = deserialized {
            assert_eq!(s.len(), 3);
            assert_eq!(w.len(), 3);
        } else {
            panic!("Wrong variant");
        }
    }

    /// Builds the four-node chain `0 -> 1 -> 2 -> 3` and returns it with the model ids.
    fn build_chain_graph() -> (LineageGraph, Vec<ModelId>) {
        let mut graph = LineageGraph::new();
        let ids: Vec<ModelId> = (0..4).map(|_| ModelId::new()).collect();

        for (i, id) in ids.iter().enumerate() {
            graph.add_node(LineageNode {
                model_id: id.clone(),
                model_name: format!("model-{i}"),
                model_version: "1.0.0".to_string(),
            });
        }

        for (i, id) in ids.iter().enumerate().take(3) {
            graph.add_edge(
                i,
                i + 1,
                ModelLineageEdge::FineTuned { parent: id.clone(), recipe: RecipeId::new() },
            );
        }

        (graph, ids)
    }

    /// Builds the diamond `A(0) -> B(1)`, `A(0) -> C(2)`, `B(1) -> D(3)`, `C(2) -> D(3)`
    /// and returns it with the model ids.
    fn build_diamond_graph() -> (LineageGraph, Vec<ModelId>) {
        let mut graph = LineageGraph::new();
        let ids: Vec<ModelId> = (0..4).map(|_| ModelId::new()).collect();

        let names = ["A", "B", "C", "D"];
        for (i, (id, name)) in ids.iter().zip(names.iter()).enumerate() {
            graph.add_node(LineageNode {
                model_id: id.clone(),
                model_name: (*name).to_string(),
                model_version: format!("1.{i}.0"),
            });
        }

        // A -> B
        graph.add_edge(
            0,
            1,
            ModelLineageEdge::FineTuned { parent: ids[0].clone(), recipe: RecipeId::new() },
        );
        // A -> C
        graph.add_edge(
            0,
            2,
            ModelLineageEdge::Quantized {
                source: ids[0].clone(),
                quantization: QuantizationType::Int8,
            },
        );
        // B -> D
        graph.add_edge(
            1,
            3,
            ModelLineageEdge::FineTuned { parent: ids[1].clone(), recipe: RecipeId::new() },
        );
        // C -> D
        graph.add_edge(
            2,
            3,
            ModelLineageEdge::Merged {
                sources: vec![ids[1].clone(), ids[2].clone()],
                weights: vec![0.5, 0.5],
            },
        );

        (graph, ids)
    }

    #[test]
    fn test_all_ancestors_chain() {
        let (graph, _) = build_chain_graph();

        // The last node descends from every other node.
        let ancestors = graph.all_ancestors(3);
        assert_eq!(ancestors.len(), 3);
        assert!(ancestors.contains(&0));
        assert!(ancestors.contains(&1));
        assert!(ancestors.contains(&2));

        // The first node has no ancestors.
        assert!(graph.all_ancestors(0).is_empty());

        let ancestors = graph.all_ancestors(1);
        assert_eq!(ancestors.len(), 1);
        assert!(ancestors.contains(&0));
    }

    #[test]
    fn test_all_descendants_chain() {
        let (graph, _) = build_chain_graph();

        // Every other node descends from the first one.
        let descendants = graph.all_descendants(0);
        assert_eq!(descendants.len(), 3);
        assert!(descendants.contains(&1));
        assert!(descendants.contains(&2));
        assert!(descendants.contains(&3));

        // The last node has no descendants.
        assert!(graph.all_descendants(3).is_empty());
    }

    #[test]
    fn test_all_ancestors_diamond() {
        let (graph, _) = build_diamond_graph();

        // The shared ancestor 0 is reachable through both branches but listed once.
        let ancestors = graph.all_ancestors(3);
        assert_eq!(ancestors.len(), 3);
        assert!(ancestors.contains(&0));
        assert!(ancestors.contains(&1));
        assert!(ancestors.contains(&2));
    }

    #[test]
    fn test_root_nodes() {
        let (chain, _) = build_chain_graph();
        assert_eq!(chain.root_nodes(), vec![0]);

        let (diamond, _) = build_diamond_graph();
        assert_eq!(diamond.root_nodes(), vec![0]);
    }

    #[test]
    fn test_leaf_nodes() {
        let (chain, _) = build_chain_graph();
        assert_eq!(chain.leaf_nodes(), vec![3]);

        let (diamond, _) = build_diamond_graph();
        assert_eq!(diamond.leaf_nodes(), vec![3]);
    }

    #[test]
    fn test_path_between() {
        let (graph, _) = build_chain_graph();

        let path = graph.path_between(0, 3).unwrap();
        assert_eq!(path, vec![0, 1, 2, 3]);

        let path = graph.path_between(1, 3).unwrap();
        assert_eq!(path, vec![1, 2, 3]);

        // A node is trivially connected to itself.
        let path = graph.path_between(2, 2).unwrap();
        assert_eq!(path, vec![2]);

        // Edges are directed, so there is no way back.
        assert!(graph.path_between(3, 0).is_none());
    }

    #[test]
    fn test_path_between_diamond() {
        let (graph, _) = build_diamond_graph();

        // Either branch is acceptable; both have three nodes.
        let path = graph.path_between(0, 3).unwrap();
        assert_eq!(path.len(), 3);
        assert_eq!(path[0], 0);
        assert_eq!(*path.last().unwrap(), 3);
    }

    #[test]
    fn test_topological_sort() {
        let (graph, _) = build_chain_graph();
        let sorted = graph.topological_sort().unwrap();

        let pos_a = sorted.iter().position(|&x| x == 0).unwrap();
        let pos_b = sorted.iter().position(|&x| x == 1).unwrap();
        let pos_c = sorted.iter().position(|&x| x == 2).unwrap();
        let pos_d = sorted.iter().position(|&x| x == 3).unwrap();

        assert!(pos_a < pos_b);
        assert!(pos_b < pos_c);
        assert!(pos_c < pos_d);
    }

    #[test]
    fn test_topological_sort_diamond() {
        let (graph, _) = build_diamond_graph();
        let sorted = graph.topological_sort().unwrap();

        let pos_a = sorted.iter().position(|&x| x == 0).unwrap();
        let pos_b = sorted.iter().position(|&x| x == 1).unwrap();
        let pos_c = sorted.iter().position(|&x| x == 2).unwrap();
        let pos_d = sorted.iter().position(|&x| x == 3).unwrap();

        // Every edge must point from an earlier position to a later one.
        assert!(pos_a < pos_b);
        assert!(pos_a < pos_c);
        assert!(pos_b < pos_d);
        assert!(pos_c < pos_d);
    }

    #[test]
    fn test_topological_sort_empty() {
        let graph = LineageGraph::new();
        assert_eq!(graph.topological_sort(), Some(vec![]));
    }

    #[test]
    fn test_depth() {
        let (graph, _) = build_chain_graph();

        assert_eq!(graph.depth(0), 0);
        assert_eq!(graph.depth(1), 1);
        assert_eq!(graph.depth(2), 2);
        assert_eq!(graph.depth(3), 3);
    }

    #[test]
    fn test_depth_diamond() {
        let (graph, _) = build_diamond_graph();

        assert_eq!(graph.depth(0), 0);
        assert_eq!(graph.depth(1), 1);
        assert_eq!(graph.depth(2), 1);
        assert_eq!(graph.depth(3), 2);
    }

    #[test]
    fn test_edges_between() {
        let (graph, _) = build_diamond_graph();

        let edges = graph.edges_between(0, 1);
        assert_eq!(edges.len(), 1);
        assert!(matches!(edges[0].edge, ModelLineageEdge::FineTuned { .. }));

        let edges = graph.edges_between(0, 2);
        assert_eq!(edges.len(), 1);
        assert!(matches!(edges[0].edge, ModelLineageEdge::Quantized { .. }));

        // The two branches are not connected to each other.
        assert!(graph.edges_between(1, 2).is_empty());

        assert_eq!(graph.edges_between(1, 3).len(), 1);
        assert_eq!(graph.edges_between(2, 3).len(), 1);
    }

    #[test]
    fn test_is_dag() {
        let (graph, _) = build_chain_graph();
        assert!(graph.is_dag());

        let (graph, _) = build_diamond_graph();
        assert!(graph.is_dag());

        let empty = LineageGraph::new();
        assert!(empty.is_dag());
    }

    #[test]
    fn test_is_dag_rejects_cycle() {
        let (mut graph, ids) = build_chain_graph();

        // Close the chain into a loop: 3 -> 0.
        graph.add_edge(
            3,
            0,
            ModelLineageEdge::FineTuned { parent: ids[3].clone(), recipe: RecipeId::new() },
        );

        assert!(!graph.is_dag());
        assert!(graph.topological_sort().is_none());
    }

    #[test]
    fn test_lineage_edge_pruned() {
        let edge = ModelLineageEdge::Pruned { source: ModelId::new(), sparsity: 0.5 };

        let json = serde_json::to_string(&edge).unwrap();
        assert!(json.contains("pruned"));
        assert!(json.contains("0.5"));

        let deserialized: ModelLineageEdge = serde_json::from_str(&json).unwrap();
        if let ModelLineageEdge::Pruned { sparsity, .. } = deserialized {
            assert!((sparsity - 0.5).abs() < f32::EPSILON);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_lineage_edge_distilled() {
        let edge = ModelLineageEdge::Distilled { teacher: ModelId::new(), temperature: 2.0 };

        let json = serde_json::to_string(&edge).unwrap();
        assert!(json.contains("distilled"));
        // Accepts both "2.0" and "2"; the round-trip below checks the exact value.
        assert!(json.contains("2"));

        let deserialized: ModelLineageEdge = serde_json::from_str(&json).unwrap();
        if let ModelLineageEdge::Distilled { temperature, .. } = deserialized {
            assert!((temperature - 2.0).abs() < f32::EPSILON);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_all_quantization_types() {
        let types = [
            QuantizationType::Int8,
            QuantizationType::Int4,
            QuantizationType::Fp16,
            QuantizationType::Bf16,
            QuantizationType::Dynamic,
        ];

        // Every variant must survive a serialize/deserialize round-trip.
        for qt in types {
            let edge = ModelLineageEdge::Quantized { source: ModelId::new(), quantization: qt };

            let json = serde_json::to_string(&edge).unwrap();
            let _: ModelLineageEdge = serde_json::from_str(&json).unwrap();
        }
    }
}