Skip to main content

pacha/lineage/
mod.rs

1//! Model lineage tracking.
2//!
3//! Tracks how models are derived from other models through various operations.
4
5use crate::model::ModelId;
6use crate::recipe::RecipeId;
7use serde::{Deserialize, Serialize};
8
9/// Types of model lineage relationships.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(tag = "type", rename_all = "snake_case")]
12pub enum ModelLineageEdge {
13    /// Model was fine-tuned from parent.
14    FineTuned {
15        /// Parent model ID.
16        parent: ModelId,
17        /// Recipe used for fine-tuning.
18        recipe: RecipeId,
19    },
20    /// Model was distilled from teacher.
21    Distilled {
22        /// Teacher model ID.
23        teacher: ModelId,
24        /// Distillation temperature.
25        temperature: f32,
26    },
27    /// Model was merged from multiple sources.
28    Merged {
29        /// Source model IDs.
30        sources: Vec<ModelId>,
31        /// Merge weights.
32        weights: Vec<f32>,
33    },
34    /// Model was quantized from source.
35    Quantized {
36        /// Source model ID.
37        source: ModelId,
38        /// Quantization type.
39        quantization: QuantizationType,
40    },
41    /// Model was pruned from source.
42    Pruned {
43        /// Source model ID.
44        source: ModelId,
45        /// Target sparsity (0.0 to 1.0).
46        sparsity: f32,
47    },
48}
49
50/// Types of quantization.
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
52#[serde(rename_all = "lowercase")]
53pub enum QuantizationType {
54    /// 8-bit integer quantization.
55    Int8,
56    /// 4-bit integer quantization.
57    Int4,
58    /// 16-bit floating point.
59    Fp16,
60    /// Brain floating point 16.
61    Bf16,
62    /// Dynamic quantization.
63    Dynamic,
64}
65
66impl std::fmt::Display for QuantizationType {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        let s = match self {
69            Self::Int8 => "int8",
70            Self::Int4 => "int4",
71            Self::Fp16 => "fp16",
72            Self::Bf16 => "bf16",
73            Self::Dynamic => "dynamic",
74        };
75        write!(f, "{s}")
76    }
77}
78
79/// A node in the lineage graph.
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct LineageNode {
82    /// Model ID.
83    pub model_id: ModelId,
84    /// Model name.
85    pub model_name: String,
86    /// Model version string.
87    pub model_version: String,
88}
89
90/// A lineage graph showing model derivation history.
91#[derive(Debug, Clone, Default, Serialize, Deserialize)]
92pub struct LineageGraph {
93    /// Nodes in the graph.
94    pub nodes: Vec<LineageNode>,
95    /// Edges representing derivation relationships.
96    pub edges: Vec<LineageEdgeRecord>,
97}
98
99/// A recorded edge in the lineage graph.
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct LineageEdgeRecord {
102    /// Source node index.
103    pub from_idx: usize,
104    /// Target node index.
105    pub to_idx: usize,
106    /// Edge type and metadata.
107    pub edge: ModelLineageEdge,
108}
109
110impl LineageGraph {
111    /// Create an empty lineage graph.
112    #[must_use]
113    pub fn new() -> Self {
114        Self::default()
115    }
116
117    /// Add a node to the graph.
118    pub fn add_node(&mut self, node: LineageNode) -> usize {
119        let idx = self.nodes.len();
120        self.nodes.push(node);
121        idx
122    }
123
124    /// Add an edge to the graph.
125    pub fn add_edge(&mut self, from_idx: usize, to_idx: usize, edge: ModelLineageEdge) {
126        self.edges.push(LineageEdgeRecord { from_idx, to_idx, edge });
127    }
128
129    /// Get the number of nodes.
130    #[must_use]
131    pub fn node_count(&self) -> usize {
132        self.nodes.len()
133    }
134
135    /// Get the number of edges.
136    #[must_use]
137    pub fn edge_count(&self) -> usize {
138        self.edges.len()
139    }
140
141    /// Get ancestors of a node (models it was derived from).
142    #[must_use]
143    pub fn ancestors(&self, node_idx: usize) -> Vec<usize> {
144        self.edges.iter().filter(|e| e.to_idx == node_idx).map(|e| e.from_idx).collect()
145    }
146
147    /// Get descendants of a node (models derived from it).
148    #[must_use]
149    pub fn descendants(&self, node_idx: usize) -> Vec<usize> {
150        self.edges.iter().filter(|e| e.from_idx == node_idx).map(|e| e.to_idx).collect()
151    }
152
153    /// Find node index by model ID.
154    #[must_use]
155    pub fn find_node(&self, model_id: &ModelId) -> Option<usize> {
156        self.nodes.iter().position(|n| &n.model_id == model_id)
157    }
158
159    /// Get all ancestors of a node (full transitive closure).
160    ///
161    /// Returns all nodes from which this model was derived, recursively.
162    #[must_use]
163    pub fn all_ancestors(&self, node_idx: usize) -> Vec<usize> {
164        let mut visited = std::collections::HashSet::new();
165        let mut result = Vec::new();
166        self.collect_ancestors(node_idx, &mut visited, &mut result);
167        result
168    }
169
170    fn collect_ancestors(
171        &self,
172        node_idx: usize,
173        visited: &mut std::collections::HashSet<usize>,
174        result: &mut Vec<usize>,
175    ) {
176        for parent_idx in self.ancestors(node_idx) {
177            if visited.insert(parent_idx) {
178                result.push(parent_idx);
179                self.collect_ancestors(parent_idx, visited, result);
180            }
181        }
182    }
183
184    /// Get all descendants of a node (full transitive closure).
185    ///
186    /// Returns all nodes derived from this model, recursively.
187    #[must_use]
188    pub fn all_descendants(&self, node_idx: usize) -> Vec<usize> {
189        let mut visited = std::collections::HashSet::new();
190        let mut result = Vec::new();
191        self.collect_descendants(node_idx, &mut visited, &mut result);
192        result
193    }
194
195    fn collect_descendants(
196        &self,
197        node_idx: usize,
198        visited: &mut std::collections::HashSet<usize>,
199        result: &mut Vec<usize>,
200    ) {
201        for child_idx in self.descendants(node_idx) {
202            if visited.insert(child_idx) {
203                result.push(child_idx);
204                self.collect_descendants(child_idx, visited, result);
205            }
206        }
207    }
208
209    /// Get root models (models with no parents).
210    #[must_use]
211    pub fn root_nodes(&self) -> Vec<usize> {
212        (0..self.nodes.len()).filter(|&idx| self.ancestors(idx).is_empty()).collect()
213    }
214
215    /// Get leaf models (models with no children).
216    #[must_use]
217    pub fn leaf_nodes(&self) -> Vec<usize> {
218        (0..self.nodes.len()).filter(|&idx| self.descendants(idx).is_empty()).collect()
219    }
220
221    /// Find path between two nodes (returns node indices from source to target).
222    ///
223    /// Uses BFS to find shortest path. Returns `None` if no path exists.
224    #[must_use]
225    pub fn path_between(&self, from_idx: usize, to_idx: usize) -> Option<Vec<usize>> {
226        use std::collections::{HashMap, VecDeque};
227
228        if from_idx == to_idx {
229            return Some(vec![from_idx]);
230        }
231
232        let mut queue = VecDeque::new();
233        let mut parent_map: HashMap<usize, usize> = HashMap::new();
234
235        queue.push_back(from_idx);
236
237        while let Some(current) = queue.pop_front() {
238            for child_idx in self.descendants(current) {
239                if !parent_map.contains_key(&child_idx) {
240                    parent_map.insert(child_idx, current);
241                    if child_idx == to_idx {
242                        // Reconstruct path
243                        let mut path = vec![to_idx];
244                        let mut node = to_idx;
245                        while let Some(&parent) = parent_map.get(&node) {
246                            path.push(parent);
247                            node = parent;
248                        }
249                        path.reverse();
250                        return Some(path);
251                    }
252                    queue.push_back(child_idx);
253                }
254            }
255        }
256
257        None
258    }
259
260    /// Perform topological sort of the graph.
261    ///
262    /// Returns nodes in order such that parents come before children.
263    /// Returns `None` if the graph has a cycle.
264    #[must_use]
265    pub fn topological_sort(&self) -> Option<Vec<usize>> {
266        use std::collections::HashMap;
267
268        let n = self.nodes.len();
269        if n == 0 {
270            return Some(Vec::new());
271        }
272
273        // Calculate in-degree for each node
274        let mut in_degree: HashMap<usize, usize> = (0..n).map(|i| (i, 0)).collect();
275        for edge in &self.edges {
276            *in_degree.entry(edge.to_idx).or_insert(0) += 1;
277        }
278
279        // Start with nodes that have no incoming edges
280        let mut queue: Vec<usize> = in_degree
281            .iter()
282            .filter_map(|(&node, &degree)| if degree == 0 { Some(node) } else { None })
283            .collect();
284
285        let mut result = Vec::with_capacity(n);
286
287        while let Some(node) = queue.pop() {
288            result.push(node);
289
290            for child in self.descendants(node) {
291                if let Some(degree) = in_degree.get_mut(&child) {
292                    *degree -= 1;
293                    if *degree == 0 {
294                        queue.push(child);
295                    }
296                }
297            }
298        }
299
300        // If we didn't visit all nodes, there's a cycle
301        if result.len() == n {
302            Some(result)
303        } else {
304            None
305        }
306    }
307
308    /// Get depth of a node (longest path from any root).
309    #[must_use]
310    pub fn depth(&self, node_idx: usize) -> usize {
311        let ancestors = self.ancestors(node_idx);
312        if ancestors.is_empty() {
313            0
314        } else {
315            ancestors.iter().map(|&a| self.depth(a) + 1).max().unwrap_or(0)
316        }
317    }
318
319    /// Get the edges connecting two specific nodes.
320    #[must_use]
321    pub fn edges_between(&self, from_idx: usize, to_idx: usize) -> Vec<&LineageEdgeRecord> {
322        self.edges.iter().filter(|e| e.from_idx == from_idx && e.to_idx == to_idx).collect()
323    }
324
325    /// Check if the graph is a DAG (directed acyclic graph).
326    #[must_use]
327    pub fn is_dag(&self) -> bool {
328        self.topological_sort().is_some()
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    // -------------------------------------------------------------------------
    // Helpers
    // -------------------------------------------------------------------------

    /// Build a `FineTuned` edge from `parent` with a fresh recipe ID.
    fn finetuned_from(parent: &ModelId) -> ModelLineageEdge {
        ModelLineageEdge::FineTuned { parent: parent.clone(), recipe: RecipeId::new() }
    }

    /// Create a graph with `count` unconnected nodes named `model-{i}`,
    /// returning it together with the node IDs (index `i` holds node `i`'s ID).
    fn graph_with_nodes(count: usize) -> (LineageGraph, Vec<ModelId>) {
        let mut graph = LineageGraph::new();
        let ids: Vec<ModelId> = (0..count).map(|_| ModelId::new()).collect();

        for (i, id) in ids.iter().enumerate() {
            graph.add_node(LineageNode {
                model_id: id.clone(),
                model_name: format!("model-{i}"),
                model_version: "1.0.0".to_string(),
            });
        }

        (graph, ids)
    }

    // -------------------------------------------------------------------------
    // Basic API and serialization
    // -------------------------------------------------------------------------

    #[test]
    fn test_quantization_type_display() {
        assert_eq!(QuantizationType::Int8.to_string(), "int8");
        assert_eq!(QuantizationType::Int4.to_string(), "int4");
        assert_eq!(QuantizationType::Fp16.to_string(), "fp16");
        assert_eq!(QuantizationType::Bf16.to_string(), "bf16");
        assert_eq!(QuantizationType::Dynamic.to_string(), "dynamic");
    }

    #[test]
    fn test_lineage_graph_basic() {
        let mut graph = LineageGraph::new();

        let base_id = ModelId::new();
        let finetuned_id = ModelId::new();

        let base_idx = graph.add_node(LineageNode {
            model_id: base_id.clone(),
            model_name: "base-model".to_string(),
            model_version: "1.0.0".to_string(),
        });

        let finetuned_idx = graph.add_node(LineageNode {
            model_id: finetuned_id,
            model_name: "finetuned-model".to_string(),
            model_version: "1.0.0".to_string(),
        });

        graph.add_edge(base_idx, finetuned_idx, finetuned_from(&base_id));

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
        assert_eq!(graph.ancestors(finetuned_idx), vec![base_idx]);
        assert_eq!(graph.descendants(base_idx), vec![finetuned_idx]);
    }

    #[test]
    fn test_lineage_graph_find_node() {
        let mut graph = LineageGraph::new();
        let model_id = ModelId::new();

        graph.add_node(LineageNode {
            model_id: model_id.clone(),
            model_name: "test-model".to_string(),
            model_version: "1.0.0".to_string(),
        });

        assert_eq!(graph.find_node(&model_id), Some(0));
        assert_eq!(graph.find_node(&ModelId::new()), None);
    }

    #[test]
    fn test_lineage_edge_serialization() {
        let edge = ModelLineageEdge::Quantized {
            source: ModelId::new(),
            quantization: QuantizationType::Int8,
        };

        let json = serde_json::to_string(&edge).unwrap();
        // Pin the internal tag and the lowercase enum encoding, not just substrings.
        assert!(json.contains(r#""type":"quantized""#), "unexpected JSON: {json}");
        assert!(json.contains(r#""quantization":"int8""#), "unexpected JSON: {json}");

        let deserialized: ModelLineageEdge = serde_json::from_str(&json).unwrap();
        if let ModelLineageEdge::Quantized { quantization, .. } = deserialized {
            assert_eq!(quantization, QuantizationType::Int8);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_merged_lineage() {
        let sources = vec![ModelId::new(), ModelId::new(), ModelId::new()];
        let weights = vec![0.5, 0.3, 0.2];

        let edge = ModelLineageEdge::Merged { sources: sources.clone(), weights: weights.clone() };

        let json = serde_json::to_string(&edge).unwrap();
        assert!(json.contains(r#""type":"merged""#), "unexpected JSON: {json}");
        let deserialized: ModelLineageEdge = serde_json::from_str(&json).unwrap();

        if let ModelLineageEdge::Merged { sources: s, weights: w } = deserialized {
            assert_eq!(s, sources);
            assert_eq!(w.len(), weights.len());
            for (got, want) in w.iter().zip(&weights) {
                assert!((got - want).abs() < 1e-6, "weight {got} != {want}");
            }
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_lineage_edge_pruned() {
        let edge = ModelLineageEdge::Pruned { source: ModelId::new(), sparsity: 0.5 };

        let json = serde_json::to_string(&edge).unwrap();
        assert!(json.contains(r#""type":"pruned""#), "unexpected JSON: {json}");
        assert!(json.contains(r#""sparsity":0.5"#), "unexpected JSON: {json}");

        let deserialized: ModelLineageEdge = serde_json::from_str(&json).unwrap();
        if let ModelLineageEdge::Pruned { sparsity, .. } = deserialized {
            assert!((sparsity - 0.5).abs() < f32::EPSILON);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_lineage_edge_distilled() {
        let edge = ModelLineageEdge::Distilled { teacher: ModelId::new(), temperature: 2.0 };

        let json = serde_json::to_string(&edge).unwrap();
        assert!(json.contains(r#""type":"distilled""#), "unexpected JSON: {json}");
        // Anchor on the field name: a bare "2" could match inside a serialized ID.
        assert!(json.contains(r#""temperature":2.0"#), "unexpected JSON: {json}");

        let deserialized: ModelLineageEdge = serde_json::from_str(&json).unwrap();
        if let ModelLineageEdge::Distilled { temperature, .. } = deserialized {
            assert!((temperature - 2.0).abs() < f32::EPSILON);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_all_quantization_types() {
        let types = [
            QuantizationType::Int8,
            QuantizationType::Int4,
            QuantizationType::Fp16,
            QuantizationType::Bf16,
            QuantizationType::Dynamic,
        ];

        for qt in types {
            let edge = ModelLineageEdge::Quantized { source: ModelId::new(), quantization: qt };

            let json = serde_json::to_string(&edge).unwrap();
            // The serde name must match the Display name.
            assert!(json.contains(&format!(r#""quantization":"{qt}""#)), "unexpected JSON: {json}");

            let deserialized: ModelLineageEdge = serde_json::from_str(&json).unwrap();
            if let ModelLineageEdge::Quantized { quantization, .. } = deserialized {
                assert_eq!(quantization, qt);
            } else {
                panic!("Wrong variant");
            }
        }
    }

    // -------------------------------------------------------------------------
    // Full Traversal Tests
    // -------------------------------------------------------------------------

    fn build_chain_graph() -> (LineageGraph, Vec<ModelId>) {
        // Creates: A -> B -> C -> D
        let (mut graph, ids) = graph_with_nodes(4);

        for (i, id) in ids.iter().enumerate().take(3) {
            graph.add_edge(i, i + 1, finetuned_from(id));
        }

        (graph, ids)
    }

    fn build_diamond_graph() -> (LineageGraph, Vec<ModelId>) {
        // Creates:
        //     A
        //    / \
        //   B   C
        //    \ /
        //     D
        let mut graph = LineageGraph::new();
        let ids: Vec<ModelId> = (0..4).map(|_| ModelId::new()).collect();

        let names = ["A", "B", "C", "D"];
        for (i, (id, name)) in ids.iter().zip(names.iter()).enumerate() {
            graph.add_node(LineageNode {
                model_id: id.clone(),
                model_name: (*name).to_string(),
                model_version: format!("1.{i}.0"),
            });
        }

        // A -> B
        graph.add_edge(0, 1, finetuned_from(&ids[0]));
        // A -> C
        graph.add_edge(
            0,
            2,
            ModelLineageEdge::Quantized {
                source: ids[0].clone(),
                quantization: QuantizationType::Int8,
            },
        );
        // B -> D
        graph.add_edge(1, 3, finetuned_from(&ids[1]));
        // C -> D
        graph.add_edge(
            2,
            3,
            ModelLineageEdge::Merged {
                sources: vec![ids[1].clone(), ids[2].clone()],
                weights: vec![0.5, 0.5],
            },
        );

        (graph, ids)
    }

    fn build_two_node_cycle() -> LineageGraph {
        // Creates: A -> B -> A (not a valid lineage, but representable)
        let (mut graph, ids) = graph_with_nodes(2);
        graph.add_edge(0, 1, finetuned_from(&ids[0]));
        graph.add_edge(1, 0, finetuned_from(&ids[1]));
        graph
    }

    #[test]
    fn test_all_ancestors_chain() {
        let (graph, _) = build_chain_graph();

        // D (idx 3) should have ancestors A, B, C
        let ancestors = graph.all_ancestors(3);
        assert_eq!(ancestors.len(), 3);
        assert!(ancestors.contains(&0));
        assert!(ancestors.contains(&1));
        assert!(ancestors.contains(&2));

        // A (idx 0) should have no ancestors
        assert!(graph.all_ancestors(0).is_empty());

        // B (idx 1) should have only A
        assert_eq!(graph.all_ancestors(1), vec![0]);
    }

    #[test]
    fn test_all_descendants_chain() {
        let (graph, _) = build_chain_graph();

        // A (idx 0) should have descendants B, C, D
        let descendants = graph.all_descendants(0);
        assert_eq!(descendants.len(), 3);
        assert!(descendants.contains(&1));
        assert!(descendants.contains(&2));
        assert!(descendants.contains(&3));

        // D (idx 3) should have no descendants
        assert!(graph.all_descendants(3).is_empty());
    }

    #[test]
    fn test_all_ancestors_diamond() {
        let (graph, _) = build_diamond_graph();

        // D has ancestors A, B, C, with the shared ancestor A reported once.
        let ancestors = graph.all_ancestors(3);
        assert_eq!(ancestors.len(), 3);
        assert!(ancestors.contains(&0));
        assert!(ancestors.contains(&1));
        assert!(ancestors.contains(&2));
    }

    #[test]
    fn test_root_nodes() {
        let (chain, _) = build_chain_graph();
        assert_eq!(chain.root_nodes(), vec![0]);

        let (diamond, _) = build_diamond_graph();
        assert_eq!(diamond.root_nodes(), vec![0]);
    }

    #[test]
    fn test_leaf_nodes() {
        let (chain, _) = build_chain_graph();
        assert_eq!(chain.leaf_nodes(), vec![3]);

        let (diamond, _) = build_diamond_graph();
        assert_eq!(diamond.leaf_nodes(), vec![3]);
    }

    #[test]
    fn test_root_and_leaf_nodes_with_multiple_roots() {
        // Creates: A -> C <- B   (two independent roots merged into one model)
        let (mut graph, ids) = graph_with_nodes(3);
        let merge = || ModelLineageEdge::Merged {
            sources: vec![ids[0].clone(), ids[1].clone()],
            weights: vec![0.5, 0.5],
        };
        graph.add_edge(0, 2, merge());
        graph.add_edge(1, 2, merge());

        assert_eq!(graph.root_nodes(), vec![0, 1]);
        assert_eq!(graph.leaf_nodes(), vec![2]);
        assert_eq!(graph.depth(2), 1);
    }

    #[test]
    fn test_path_between() {
        let (graph, _) = build_chain_graph();

        // Path from A to D
        assert_eq!(graph.path_between(0, 3), Some(vec![0, 1, 2, 3]));

        // Path from B to D
        assert_eq!(graph.path_between(1, 3), Some(vec![1, 2, 3]));

        // Same node
        assert_eq!(graph.path_between(2, 2), Some(vec![2]));

        // No path (wrong direction)
        assert!(graph.path_between(3, 0).is_none());
    }

    #[test]
    fn test_path_between_diamond() {
        let (graph, _) = build_diamond_graph();

        // Path from A to D (could go through B or C)
        let path = graph.path_between(0, 3).unwrap();
        assert_eq!(path.len(), 3); // A -> B/C -> D
        assert_eq!(path[0], 0);
        assert_eq!(*path.last().unwrap(), 3);

        // Siblings are not connected in either direction.
        assert!(graph.path_between(1, 2).is_none());
        assert!(graph.path_between(2, 1).is_none());
    }

    #[test]
    fn test_topological_sort() {
        let (graph, _) = build_chain_graph();
        let sorted = graph.topological_sort().unwrap();

        // A chain admits exactly one topological order.
        assert_eq!(sorted, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_topological_sort_diamond() {
        let (graph, _) = build_diamond_graph();
        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted.len(), 4);

        let pos = |node: usize| sorted.iter().position(|&x| x == node).unwrap();

        // A should come before B and C
        assert!(pos(0) < pos(1));
        assert!(pos(0) < pos(2));
        // B and C should come before D
        assert!(pos(1) < pos(3));
        assert!(pos(2) < pos(3));
    }

    #[test]
    fn test_topological_sort_empty() {
        let graph = LineageGraph::new();
        assert_eq!(graph.topological_sort(), Some(vec![]));
    }

    #[test]
    fn test_topological_sort_detects_cycles() {
        let graph = build_two_node_cycle();
        assert_eq!(graph.topological_sort(), None);

        // A self-loop is also a cycle.
        let (mut self_loop, ids) = graph_with_nodes(1);
        self_loop.add_edge(0, 0, finetuned_from(&ids[0]));
        assert_eq!(self_loop.topological_sort(), None);
    }

    #[test]
    fn test_transitive_traversal_terminates_on_cycle() {
        let graph = build_two_node_cycle();

        // Each node reaches the other and, through the cycle, itself.
        assert_eq!(graph.all_ancestors(0), vec![1, 0]);
        assert_eq!(graph.all_descendants(0), vec![1, 0]);
    }

    #[test]
    fn test_depth() {
        let (graph, _) = build_chain_graph();

        assert_eq!(graph.depth(0), 0); // A is root
        assert_eq!(graph.depth(1), 1); // B
        assert_eq!(graph.depth(2), 2); // C
        assert_eq!(graph.depth(3), 3); // D
    }

    #[test]
    fn test_depth_diamond() {
        let (graph, _) = build_diamond_graph();

        assert_eq!(graph.depth(0), 0); // A is root
        assert_eq!(graph.depth(1), 1); // B
        assert_eq!(graph.depth(2), 1); // C
        assert_eq!(graph.depth(3), 2); // D (longest path is A->B->D or A->C->D)
    }

    #[test]
    fn test_edges_between() {
        let (graph, _) = build_diamond_graph();

        // A -> B has one edge
        let edges = graph.edges_between(0, 1);
        assert_eq!(edges.len(), 1);
        assert!(matches!(edges[0].edge, ModelLineageEdge::FineTuned { .. }));

        // A -> C has one edge
        let edges = graph.edges_between(0, 2);
        assert_eq!(edges.len(), 1);
        assert!(matches!(edges[0].edge, ModelLineageEdge::Quantized { .. }));

        // No edge between B and C
        assert!(graph.edges_between(1, 2).is_empty());

        // D has edges from both B and C
        assert_eq!(graph.edges_between(1, 3).len(), 1);
        assert_eq!(graph.edges_between(2, 3).len(), 1);

        // Direction matters.
        assert!(graph.edges_between(3, 1).is_empty());
    }

    #[test]
    fn test_is_dag() {
        let (graph, _) = build_chain_graph();
        assert!(graph.is_dag());

        let (graph, _) = build_diamond_graph();
        assert!(graph.is_dag());

        // Empty graph is a DAG
        let empty = LineageGraph::new();
        assert!(empty.is_dag());

        // A cycle is not.
        assert!(!build_two_node_cycle().is_dag());
    }
}