Skip to main content

entrenar/monitor/
lineage.rs

//! Model Lineage Tracking (ENT-046)
//!
//! Track model versions and the training changes that derive each version from its parent.
//! Toyota Way 改善 (Kaizen): track improvement over time.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// Model metadata
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelMetadata {
12    /// Model identifier
13    pub model_id: String,
14    /// Semantic version
15    pub version: String,
16    /// Validation accuracy
17    pub accuracy: f64,
18    /// Creation timestamp
19    pub created_at: u64,
20    /// Configuration hash
21    pub config_hash: String,
22    /// Additional tags
23    pub tags: HashMap<String, String>,
24}
25
26/// What changed between model versions
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28pub enum ChangeType {
29    /// More training data added
30    AddData,
31    /// Hyperparameters changed
32    Hyperparams,
33    /// Architecture modified
34    Architecture,
35    /// Different training run (same config)
36    Retrain,
37    /// Fine-tuning applied
38    FineTune,
39    /// Model merged
40    Merge,
41}
42
43impl ChangeType {
44    pub fn as_str(&self) -> &'static str {
45        match self {
46            ChangeType::AddData => "add_data",
47            ChangeType::Hyperparams => "hyperparams",
48            ChangeType::Architecture => "architecture",
49            ChangeType::Retrain => "retrain",
50            ChangeType::FineTune => "fine_tune",
51            ChangeType::Merge => "merge",
52        }
53    }
54}
55
56/// Edge in the lineage graph
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct Derivation {
59    /// Parent model ID
60    pub parent_id: String,
61    /// Child model ID
62    pub child_id: String,
63    /// What changed
64    pub change_type: ChangeType,
65    /// Description of change
66    pub description: String,
67}
68
69/// Model lineage tracker
70#[derive(Debug, Default, Serialize, Deserialize)]
71pub struct ModelLineage {
72    /// All models by ID
73    models: HashMap<String, ModelMetadata>,
74    /// Derivation edges
75    derivations: Vec<Derivation>,
76}
77
78impl ModelLineage {
79    /// Create a new lineage tracker
80    pub fn new() -> Self {
81        Self::default()
82    }
83
84    /// Add a model to the lineage
85    pub fn add_model(&mut self, metadata: ModelMetadata) -> String {
86        let id = metadata.model_id.clone();
87        self.models.insert(id.clone(), metadata);
88        id
89    }
90
91    /// Add a derivation edge
92    pub fn add_derivation(
93        &mut self,
94        parent_id: &str,
95        child_id: &str,
96        change_type: ChangeType,
97        description: &str,
98    ) {
99        self.derivations.push(Derivation {
100            parent_id: parent_id.to_string(),
101            child_id: child_id.to_string(),
102            change_type,
103            description: description.to_string(),
104        });
105    }
106
107    /// Get a model by ID
108    pub fn get_model(&self, id: &str) -> Option<&ModelMetadata> {
109        self.models.get(id)
110    }
111
112    /// Get all models
113    pub fn all_models(&self) -> impl Iterator<Item = &ModelMetadata> {
114        self.models.values()
115    }
116
117    /// Get parent of a model
118    pub fn get_parent(&self, child_id: &str) -> Option<&ModelMetadata> {
119        self.derivations
120            .iter()
121            .find(|d| d.child_id == child_id)
122            .and_then(|d| self.models.get(&d.parent_id))
123    }
124
125    /// Get children of a model
126    pub fn get_children(&self, parent_id: &str) -> Vec<&ModelMetadata> {
127        self.derivations
128            .iter()
129            .filter(|d| d.parent_id == parent_id)
130            .filter_map(|d| self.models.get(&d.child_id))
131            .collect()
132    }
133
134    /// Compare two model versions
135    pub fn compare(&self, a_id: &str, b_id: &str) -> Option<ModelComparison> {
136        let a = self.models.get(a_id)?;
137        let b = self.models.get(b_id)?;
138
139        Some(ModelComparison {
140            model_a: a_id.to_string(),
141            model_b: b_id.to_string(),
142            accuracy_delta: b.accuracy - a.accuracy,
143            is_improvement: b.accuracy > a.accuracy,
144        })
145    }
146
147    /// Find what caused a regression
148    pub fn find_regression_source(&self, model_id: &str) -> Option<&Derivation> {
149        let model = self.models.get(model_id)?;
150
151        // Find parent
152        let derivation = self.derivations.iter().find(|d| d.child_id == model_id)?;
153        let parent = self.models.get(&derivation.parent_id)?;
154
155        // Check if this is a regression
156        if model.accuracy < parent.accuracy {
157            Some(derivation)
158        } else {
159            None
160        }
161    }
162
163    /// Get lineage chain from root to model
164    pub fn get_lineage_chain(&self, model_id: &str) -> Vec<String> {
165        let mut chain = vec![model_id.to_string()];
166        let mut current = model_id;
167
168        while let Some(derivation) = self.derivations.iter().find(|d| d.child_id == current) {
169            chain.push(derivation.parent_id.clone());
170            current = &derivation.parent_id;
171        }
172
173        chain.reverse();
174        chain
175    }
176
177    /// Export to JSON
178    pub fn to_json(&self) -> Result<String, serde_json::Error> {
179        serde_json::to_string_pretty(self)
180    }
181
182    /// Load from JSON
183    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
184        serde_json::from_str(json)
185    }
186}
187
/// Result of comparing model `b` against baseline model `a`
/// (produced by [`ModelLineage::compare`]).
#[derive(Debug, Clone, PartialEq)]
pub struct ModelComparison {
    /// ID of the baseline model.
    pub model_a: String,
    /// ID of the model compared against the baseline.
    pub model_b: String,
    /// `accuracy(b) - accuracy(a)`; positive means `b` is more accurate.
    pub accuracy_delta: f64,
    /// `true` only when `b` is strictly more accurate than `a`.
    pub is_improvement: bool,
}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a model with only the fields these tests care about set.
    fn make_model(id: &str, version: &str, accuracy: f64) -> ModelMetadata {
        ModelMetadata {
            model_id: id.to_string(),
            version: version.to_string(),
            accuracy,
            created_at: 0,
            config_hash: String::new(),
            tags: HashMap::new(),
        }
    }

    #[test]
    fn test_lineage_new() {
        let lineage = ModelLineage::new();
        assert_eq!(lineage.models.len(), 0);
    }

    #[test]
    fn test_add_model() {
        let mut lineage = ModelLineage::new();
        let id = lineage.add_model(make_model("v1", "1.0.0", 0.85));
        assert_eq!(id, "v1");
        assert!(lineage.get_model("v1").is_some());
    }

    #[test]
    fn test_add_derivation() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));
        lineage.add_model(make_model("v2", "2.0.0", 0.87));
        lineage.add_derivation("v1", "v2", ChangeType::AddData, "Added 1000 samples");

        assert_eq!(lineage.derivations.len(), 1);
    }

    #[test]
    fn test_get_parent() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));
        lineage.add_model(make_model("v2", "2.0.0", 0.87));
        lineage.add_derivation("v1", "v2", ChangeType::AddData, "More data");

        let parent = lineage.get_parent("v2").expect("operation should succeed");
        assert_eq!(parent.model_id, "v1");
    }

    #[test]
    fn test_get_children() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));
        lineage.add_model(make_model("v2a", "2.0.0", 0.87));
        lineage.add_model(make_model("v2b", "2.1.0", 0.86));
        lineage.add_derivation("v1", "v2a", ChangeType::AddData, "Branch A");
        lineage.add_derivation("v1", "v2b", ChangeType::Hyperparams, "Branch B");

        let children = lineage.get_children("v1");
        assert_eq!(children.len(), 2);
    }

    #[test]
    fn test_compare_improvement() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));
        lineage.add_model(make_model("v2", "2.0.0", 0.87));

        let cmp = lineage.compare("v1", "v2").expect("operation should succeed");
        assert!(cmp.is_improvement);
        assert!((cmp.accuracy_delta - 0.02).abs() < 1e-6);
    }

    #[test]
    fn test_compare_regression() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.87));
        lineage.add_model(make_model("v2", "2.0.0", 0.82));

        let cmp = lineage.compare("v1", "v2").expect("operation should succeed");
        assert!(!cmp.is_improvement);
    }

    #[test]
    fn test_find_regression_source() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.87));
        lineage.add_model(make_model("v2", "2.0.0", 0.82));
        lineage.add_derivation("v1", "v2", ChangeType::Hyperparams, "Changed LR");

        let source = lineage.find_regression_source("v2").expect("operation should succeed");
        assert_eq!(source.change_type, ChangeType::Hyperparams);
    }

    #[test]
    fn test_lineage_chain() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));
        lineage.add_model(make_model("v2", "2.0.0", 0.85));
        lineage.add_model(make_model("v3", "3.0.0", 0.87));
        lineage.add_derivation("v1", "v2", ChangeType::AddData, "");
        lineage.add_derivation("v2", "v3", ChangeType::FineTune, "");

        let chain = lineage.get_lineage_chain("v3");
        assert_eq!(chain, vec!["v1", "v2", "v3"]);
    }

    #[test]
    fn test_json_roundtrip() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));

        let json = lineage.to_json().expect("operation should succeed");
        let loaded = ModelLineage::from_json(&json).expect("load should succeed");
        assert!(loaded.get_model("v1").is_some());
    }

    // =========================================================================
    // Additional Coverage Tests
    // =========================================================================

    #[test]
    fn test_change_type_as_str() {
        assert_eq!(ChangeType::AddData.as_str(), "add_data");
        assert_eq!(ChangeType::Hyperparams.as_str(), "hyperparams");
        assert_eq!(ChangeType::Architecture.as_str(), "architecture");
        assert_eq!(ChangeType::Retrain.as_str(), "retrain");
        assert_eq!(ChangeType::FineTune.as_str(), "fine_tune");
        assert_eq!(ChangeType::Merge.as_str(), "merge");
    }

    #[test]
    fn test_all_models() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));
        lineage.add_model(make_model("v2", "2.0.0", 0.85));
        lineage.add_model(make_model("v3", "3.0.0", 0.90));

        let models: Vec<_> = lineage.all_models().collect();
        assert_eq!(models.len(), 3);
    }

    #[test]
    fn test_get_parent_no_parent() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));

        assert!(lineage.get_parent("v1").is_none());
    }

    #[test]
    fn test_find_regression_source_no_regression() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));
        lineage.add_model(make_model("v2", "2.0.0", 0.85)); // Improvement
        lineage.add_derivation("v1", "v2", ChangeType::AddData, "More data");

        // v2 is an improvement, so no regression
        assert!(lineage.find_regression_source("v2").is_none());
    }

    #[test]
    fn test_find_regression_source_nonexistent() {
        let lineage = ModelLineage::new();
        assert!(lineage.find_regression_source("v99").is_none());
    }

    #[test]
    fn test_compare_nonexistent_models() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));

        assert!(lineage.compare("v1", "v99").is_none());
        assert!(lineage.compare("v99", "v1").is_none());
    }

    #[test]
    fn test_get_children_no_children() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));

        let children = lineage.get_children("v1");
        assert!(children.is_empty());
    }

    #[test]
    fn test_get_model_nonexistent() {
        let lineage = ModelLineage::new();
        assert!(lineage.get_model("v99").is_none());
    }

    #[test]
    fn test_lineage_chain_single() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));

        let chain = lineage.get_lineage_chain("v1");
        assert_eq!(chain, vec!["v1"]);
    }

    #[test]
    fn test_model_metadata_with_tags() {
        let mut tags = HashMap::new();
        tags.insert("env".to_string(), "production".to_string());
        tags.insert("owner".to_string(), "team-ml".to_string());

        let model = ModelMetadata {
            model_id: "v1".to_string(),
            version: "1.0.0".to_string(),
            accuracy: 0.95,
            created_at: 1700000000,
            config_hash: "abc123".to_string(),
            tags,
        };

        assert_eq!(model.tags.len(), 2);
        assert_eq!(model.created_at, 1700000000);
    }

    #[test]
    fn test_derivation_clone() {
        let d = Derivation {
            parent_id: "v1".to_string(),
            child_id: "v2".to_string(),
            change_type: ChangeType::Merge,
            description: "merged models".to_string(),
        };
        let cloned = d.clone();
        assert_eq!(d.parent_id, cloned.parent_id);
        assert_eq!(d.change_type, cloned.change_type);
    }

    #[test]
    fn test_model_comparison_clone() {
        let cmp = ModelComparison {
            model_a: "v1".to_string(),
            model_b: "v2".to_string(),
            accuracy_delta: 0.05,
            is_improvement: true,
        };
        let cloned = cmp.clone();
        assert_eq!(cmp.accuracy_delta, cloned.accuracy_delta);
    }

    #[test]
    fn test_model_lineage_default() {
        let lineage = ModelLineage::default();
        assert!(lineage.models.is_empty());
        assert!(lineage.derivations.is_empty());
    }

    // =========================================================================
    // Serialization, unregistered-ID, and boundary coverage
    // =========================================================================

    #[test]
    fn test_json_roundtrip_preserves_derivations() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));
        lineage.add_model(make_model("v2", "2.0.0", 0.80));
        lineage.add_derivation("v1", "v2", ChangeType::Architecture, "Swapped encoder");

        let json = lineage.to_json().expect("serialization should succeed");
        let loaded = ModelLineage::from_json(&json).expect("load should succeed");

        assert_eq!(loaded.derivations.len(), 1);
        assert_eq!(
            loaded.get_parent("v2").map(|m| m.model_id.as_str()),
            Some("v1")
        );
        let source = loaded
            .find_regression_source("v2")
            .expect("v2 regressed from v1");
        assert_eq!(source.change_type, ChangeType::Architecture);
        assert_eq!(source.description, "Swapped encoder");
    }

    #[test]
    fn test_from_json_invalid() {
        assert!(ModelLineage::from_json("not json").is_err());
    }

    #[test]
    fn test_unregistered_parent() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v2", "2.0.0", 0.85));
        lineage.add_derivation("ghost", "v2", ChangeType::Retrain, "");

        // The edge exists, but its parent was never registered as a model.
        assert!(lineage.get_parent("v2").is_none());
        assert!(lineage.find_regression_source("v2").is_none());
        // The chain still reports the unregistered ancestor by ID.
        assert_eq!(lineage.get_lineage_chain("v2"), vec!["ghost", "v2"]);
    }

    #[test]
    fn test_find_regression_source_equal_accuracy() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.85));
        lineage.add_model(make_model("v2", "2.0.0", 0.85));
        lineage.add_derivation("v1", "v2", ChangeType::Retrain, "Same config");

        // Only a strict accuracy drop counts as a regression.
        assert!(lineage.find_regression_source("v2").is_none());
    }

    #[test]
    fn test_add_model_replaces_existing() {
        let mut lineage = ModelLineage::new();
        lineage.add_model(make_model("v1", "1.0.0", 0.80));
        lineage.add_model(make_model("v1", "1.0.1", 0.90));

        assert_eq!(lineage.all_models().count(), 1);
        let model = lineage.get_model("v1").expect("v1 is registered");
        assert_eq!(model.version, "1.0.1");
        assert!((model.accuracy - 0.90).abs() < 1e-12);
    }
}