ipfrs_tensorlogic/
provenance.rs

1//! Provenance tracking for ML models
2//!
3//! This module provides comprehensive provenance tracking including:
4//! - Data lineage as Merkle DAG
5//! - Backward tracing to source data
6//! - Attribution metadata (contributors, datasets, licenses)
7//! - Training history and reproducibility
8
9use ipfrs_core::Cid;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12use thiserror::Error;
13
/// Errors that can occur during provenance operations
///
/// The `Display` text of each variant is the message given in its `#[error(...)]`
/// attribute (generated by `thiserror`).
#[derive(Debug, Error)]
pub enum ProvenanceError {
    /// No record is stored for the requested identifier. The payload is the
    /// identifier that was looked up; lineage tracing passes the model CID
    /// rendered as a string.
    #[error("Provenance record not found: {0}")]
    RecordNotFound(String),

    /// The same model was reached twice while following `parent_model` links
    /// during lineage tracing.
    #[error("Circular dependency detected")]
    CircularDependency,

    /// The provenance chain is invalid.
    ///
    /// NOTE(review): nothing in this file constructs this variant — confirm
    /// which callers are expected to return it.
    #[error("Invalid provenance chain")]
    InvalidChain,

    /// A required piece of metadata is absent. The payload is interpolated
    /// into the message.
    ///
    /// NOTE(review): nothing in this file constructs this variant — confirm
    /// which callers are expected to return it.
    #[error("Missing required metadata: {0}")]
    MissingMetadata(String),
}
29
/// License types for datasets and models
///
/// Derives `Hash` and `Eq`, so licenses can be collected into a `HashSet`.
/// The `Display` implementation renders each variant as the short label
/// noted on the variant below.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum License {
    /// MIT License (displayed as `MIT`)
    MIT,
    /// Apache 2.0 (displayed as `Apache-2.0`)
    Apache2,
    /// GPL v3 (displayed as `GPL-3.0`)
    GPLv3,
    /// BSD 3-Clause (displayed as `BSD-3-Clause`)
    BSD3Clause,
    /// Creative Commons Attribution (displayed as `CC-BY`)
    CCBY,
    /// Creative Commons Attribution-ShareAlike (displayed as `CC-BY-SA`)
    CCBYSA,
    /// Proprietary (displayed as `Proprietary`)
    Proprietary,
    /// Custom license; the payload is free-form text and is displayed as
    /// `Custom: <payload>`
    Custom(String),
    /// Unknown license (displayed as `Unknown`)
    Unknown,
}
52
53impl std::fmt::Display for License {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            License::MIT => write!(f, "MIT"),
57            License::Apache2 => write!(f, "Apache-2.0"),
58            License::GPLv3 => write!(f, "GPL-3.0"),
59            License::BSD3Clause => write!(f, "BSD-3-Clause"),
60            License::CCBY => write!(f, "CC-BY"),
61            License::CCBYSA => write!(f, "CC-BY-SA"),
62            License::Proprietary => write!(f, "Proprietary"),
63            License::Custom(s) => write!(f, "Custom: {}", s),
64            License::Unknown => write!(f, "Unknown"),
65        }
66    }
67}
68
/// Attribution information for contributors
///
/// `Attribution::new` fills in `name`, `role` and `timestamp`; `contact` and
/// `organization` start as `None` and are set through the `with_*` builders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Attribution {
    /// Contributor name
    pub name: String,
    /// Email or contact (optional)
    pub contact: Option<String>,
    /// Organization (optional)
    pub organization: Option<String>,
    /// Role (e.g., "data provider", "model trainer", "code contributor");
    /// stored as free-form text
    pub role: String,
    /// Contribution timestamp; `Attribution::new` sets it to
    /// `chrono::Utc::now().timestamp()` at construction time
    pub timestamp: i64,
}
83
84impl Attribution {
85    /// Create a new attribution
86    pub fn new(name: String, role: String) -> Self {
87        Self {
88            name,
89            contact: None,
90            organization: None,
91            role,
92            timestamp: chrono::Utc::now().timestamp(),
93        }
94    }
95
96    /// Add contact information
97    pub fn with_contact(mut self, contact: String) -> Self {
98        self.contact = Some(contact);
99        self
100    }
101
102    /// Add organization
103    pub fn with_organization(mut self, organization: String) -> Self {
104        self.organization = Some(organization);
105        self
106    }
107}
108
/// Dataset provenance information
///
/// Describes a single dataset: its content identifier, name and version,
/// license, contributors, and where it came from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetProvenance {
    /// Dataset CID
    ///
    /// (De)serialized through `crate::serialize_cid` / `crate::deserialize_cid`,
    /// which are defined outside this file.
    #[serde(serialize_with = "crate::serialize_cid")]
    #[serde(deserialize_with = "crate::deserialize_cid")]
    pub dataset_cid: Cid,

    /// Dataset name
    pub name: String,

    /// Dataset version (free-form text)
    pub version: String,

    /// License
    pub license: License,

    /// Attribution: contributors credited for this dataset, in insertion order
    pub attributions: Vec<Attribution>,

    /// Source URLs (if applicable); stored as plain strings, in insertion order
    pub sources: Vec<String>,

    /// Description (optional)
    pub description: Option<String>,

    /// Creation timestamp; `DatasetProvenance::new` sets it to
    /// `chrono::Utc::now().timestamp()`, i.e. when the record is built
    pub created_at: i64,
}
138
139impl DatasetProvenance {
140    /// Create a new dataset provenance record
141    pub fn new(dataset_cid: Cid, name: String, version: String, license: License) -> Self {
142        Self {
143            dataset_cid,
144            name,
145            version,
146            license,
147            attributions: Vec::new(),
148            sources: Vec::new(),
149            description: None,
150            created_at: chrono::Utc::now().timestamp(),
151        }
152    }
153
154    /// Add an attribution
155    pub fn add_attribution(mut self, attribution: Attribution) -> Self {
156        self.attributions.push(attribution);
157        self
158    }
159
160    /// Add a source URL
161    pub fn add_source(mut self, source: String) -> Self {
162        self.sources.push(source);
163        self
164    }
165
166    /// Add description
167    pub fn with_description(mut self, description: String) -> Self {
168        self.description = Some(description);
169        self
170    }
171}
172
/// Hyperparameters for training
///
/// Every named field is optional; `Hyperparameters::new` leaves them all as
/// `None` with an empty `custom` map. Anything without a dedicated field can
/// be recorded in `custom`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hyperparameters {
    /// Learning rate
    pub learning_rate: Option<f32>,
    /// Batch size
    pub batch_size: Option<usize>,
    /// Number of epochs
    pub epochs: Option<usize>,
    /// Optimizer name (free-form text)
    pub optimizer: Option<String>,
    /// Additional parameters as string key/value pairs
    pub custom: HashMap<String, String>,
}
187
188impl Hyperparameters {
189    /// Create new hyperparameters
190    pub fn new() -> Self {
191        Self {
192            learning_rate: None,
193            batch_size: None,
194            epochs: None,
195            optimizer: None,
196            custom: HashMap::new(),
197        }
198    }
199
200    /// Set learning rate
201    pub fn with_learning_rate(mut self, lr: f32) -> Self {
202        self.learning_rate = Some(lr);
203        self
204    }
205
206    /// Set batch size
207    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
208        self.batch_size = Some(batch_size);
209        self
210    }
211
212    /// Set epochs
213    pub fn with_epochs(mut self, epochs: usize) -> Self {
214        self.epochs = Some(epochs);
215        self
216    }
217
218    /// Set optimizer
219    pub fn with_optimizer(mut self, optimizer: String) -> Self {
220        self.optimizer = Some(optimizer);
221        self
222    }
223
224    /// Add custom parameter
225    pub fn add_param(mut self, key: String, value: String) -> Self {
226        self.custom.insert(key, value);
227        self
228    }
229}
230
231impl Default for Hyperparameters {
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
/// Training provenance for a model
///
/// Records what a model was trained from (datasets and optional parent
/// model), how (hyperparameters, code, hardware, framework), by whom, and
/// under which license. CID fields are (de)serialized through helper
/// functions named in their `#[serde(...)]` attributes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingProvenance {
    /// Model CID
    ///
    /// Uses `crate::serialize_cid` / `crate::deserialize_cid`, which are
    /// defined outside this file.
    #[serde(serialize_with = "crate::serialize_cid")]
    #[serde(deserialize_with = "crate::deserialize_cid")]
    pub model_cid: Cid,

    /// Parent model CID (if fine-tuning or transfer learning); lineage
    /// tracing follows this link from model to model
    #[serde(serialize_with = "serialize_optional_cid")]
    #[serde(deserialize_with = "deserialize_optional_cid")]
    pub parent_model: Option<Cid>,

    /// Training datasets
    #[serde(serialize_with = "serialize_cid_vec")]
    #[serde(deserialize_with = "deserialize_cid_vec")]
    pub training_datasets: Vec<Cid>,

    /// Validation datasets
    #[serde(serialize_with = "serialize_cid_vec")]
    #[serde(deserialize_with = "deserialize_cid_vec")]
    pub validation_datasets: Vec<Cid>,

    /// Hyperparameters
    pub hyperparameters: Hyperparameters,

    /// Training metrics (final), keyed by metric name
    pub metrics: HashMap<String, f32>,

    /// Attribution: contributors credited for this training run
    pub attributions: Vec<Attribution>,

    /// License
    pub license: License,

    /// Training start time; `TrainingProvenance::new` sets it to
    /// `chrono::Utc::now().timestamp()`, i.e. when the record is built
    pub started_at: i64,

    /// Training end time; `None` until `complete` is called
    pub completed_at: Option<i64>,

    /// Code repository (if applicable)
    pub code_repository: Option<String>,

    /// Code commit hash; set together with `code_repository` by
    /// `with_code_repository`
    pub code_commit: Option<String>,

    /// Hardware used (e.g., "8x NVIDIA A100")
    pub hardware: Option<String>,

    /// Training framework (e.g., "PyTorch 2.0")
    pub framework: Option<String>,
}
290
291impl TrainingProvenance {
292    /// Create a new training provenance record
293    pub fn new(model_cid: Cid, training_datasets: Vec<Cid>, license: License) -> Self {
294        Self {
295            model_cid,
296            parent_model: None,
297            training_datasets,
298            validation_datasets: Vec::new(),
299            hyperparameters: Hyperparameters::new(),
300            metrics: HashMap::new(),
301            attributions: Vec::new(),
302            license,
303            started_at: chrono::Utc::now().timestamp(),
304            completed_at: None,
305            code_repository: None,
306            code_commit: None,
307            hardware: None,
308            framework: None,
309        }
310    }
311
312    /// Set parent model
313    pub fn with_parent(mut self, parent_cid: Cid) -> Self {
314        self.parent_model = Some(parent_cid);
315        self
316    }
317
318    /// Add validation dataset
319    pub fn add_validation_dataset(mut self, dataset_cid: Cid) -> Self {
320        self.validation_datasets.push(dataset_cid);
321        self
322    }
323
324    /// Set hyperparameters
325    pub fn with_hyperparameters(mut self, hyperparameters: Hyperparameters) -> Self {
326        self.hyperparameters = hyperparameters;
327        self
328    }
329
330    /// Add metric
331    pub fn add_metric(mut self, name: String, value: f32) -> Self {
332        self.metrics.insert(name, value);
333        self
334    }
335
336    /// Add attribution
337    pub fn add_attribution(mut self, attribution: Attribution) -> Self {
338        self.attributions.push(attribution);
339        self
340    }
341
342    /// Mark training as complete
343    pub fn complete(mut self) -> Self {
344        self.completed_at = Some(chrono::Utc::now().timestamp());
345        self
346    }
347
348    /// Set code repository
349    pub fn with_code_repository(mut self, repo: String, commit: String) -> Self {
350        self.code_repository = Some(repo);
351        self.code_commit = Some(commit);
352        self
353    }
354
355    /// Set hardware info
356    pub fn with_hardware(mut self, hardware: String) -> Self {
357        self.hardware = Some(hardware);
358        self
359    }
360
361    /// Set framework
362    pub fn with_framework(mut self, framework: String) -> Self {
363        self.framework = Some(framework);
364        self
365    }
366}
367
/// Complete provenance graph for tracking lineage
///
/// Holds dataset and training records in two maps. Both are keyed by the
/// string form of the record's CID (`Cid::to_string`), as done by
/// `add_dataset` and `add_training`. Models link to their parents through
/// `TrainingProvenance::parent_model` and to datasets through their dataset
/// CID lists.
#[derive(Debug, Clone)]
pub struct ProvenanceGraph {
    /// Dataset provenance records, keyed by the dataset CID rendered as a string
    datasets: HashMap<String, DatasetProvenance>,

    /// Training provenance records, keyed by the model CID rendered as a string
    training_records: HashMap<String, TrainingProvenance>,
}
377
378impl ProvenanceGraph {
379    /// Create a new provenance graph
380    pub fn new() -> Self {
381        Self {
382            datasets: HashMap::new(),
383            training_records: HashMap::new(),
384        }
385    }
386
387    /// Add a dataset provenance record
388    pub fn add_dataset(&mut self, provenance: DatasetProvenance) {
389        self.datasets
390            .insert(provenance.dataset_cid.to_string(), provenance);
391    }
392
393    /// Add a training provenance record
394    pub fn add_training(&mut self, provenance: TrainingProvenance) {
395        self.training_records
396            .insert(provenance.model_cid.to_string(), provenance);
397    }
398
399    /// Get dataset provenance
400    pub fn get_dataset(&self, dataset_cid: &Cid) -> Option<&DatasetProvenance> {
401        self.datasets.get(&dataset_cid.to_string())
402    }
403
404    /// Get training provenance
405    pub fn get_training(&self, model_cid: &Cid) -> Option<&TrainingProvenance> {
406        self.training_records.get(&model_cid.to_string())
407    }
408
409    /// Trace lineage backward from a model to all source datasets
410    pub fn trace_lineage(&self, model_cid: &Cid) -> Result<LineageTrace, ProvenanceError> {
411        let mut visited = HashSet::new();
412        let mut datasets = Vec::new();
413        let mut models = Vec::new();
414
415        self.trace_recursive(model_cid, &mut visited, &mut datasets, &mut models)?;
416
417        Ok(LineageTrace {
418            target_model: *model_cid,
419            datasets,
420            models,
421        })
422    }
423
424    /// Recursive helper for tracing lineage
425    fn trace_recursive(
426        &self,
427        model_cid: &Cid,
428        visited: &mut HashSet<Cid>,
429        datasets: &mut Vec<Cid>,
430        models: &mut Vec<Cid>,
431    ) -> Result<(), ProvenanceError> {
432        if visited.contains(model_cid) {
433            return Err(ProvenanceError::CircularDependency);
434        }
435
436        visited.insert(*model_cid);
437
438        let training = self
439            .get_training(model_cid)
440            .ok_or_else(|| ProvenanceError::RecordNotFound(model_cid.to_string()))?;
441
442        models.push(*model_cid);
443
444        // Add datasets
445        for dataset_cid in &training.training_datasets {
446            if !datasets.contains(dataset_cid) {
447                datasets.push(*dataset_cid);
448            }
449        }
450
451        for dataset_cid in &training.validation_datasets {
452            if !datasets.contains(dataset_cid) {
453                datasets.push(*dataset_cid);
454            }
455        }
456
457        // Recursively trace parent model
458        if let Some(parent_cid) = training.parent_model {
459            self.trace_recursive(&parent_cid, visited, datasets, models)?;
460        }
461
462        Ok(())
463    }
464
465    /// Get all attributions for a model (including from datasets)
466    pub fn get_all_attributions(
467        &self,
468        model_cid: &Cid,
469    ) -> Result<Vec<Attribution>, ProvenanceError> {
470        let lineage = self.trace_lineage(model_cid)?;
471        let mut attributions = Vec::new();
472
473        // Get model attributions
474        if let Some(training) = self.get_training(model_cid) {
475            attributions.extend(training.attributions.clone());
476        }
477
478        // Get dataset attributions
479        for dataset_cid in &lineage.datasets {
480            if let Some(dataset) = self.get_dataset(dataset_cid) {
481                attributions.extend(dataset.attributions.clone());
482            }
483        }
484
485        Ok(attributions)
486    }
487
488    /// Get all licenses in the lineage
489    pub fn get_all_licenses(&self, model_cid: &Cid) -> Result<HashSet<License>, ProvenanceError> {
490        let lineage = self.trace_lineage(model_cid)?;
491        let mut licenses = HashSet::new();
492
493        // Get model licenses
494        for model in &lineage.models {
495            if let Some(training) = self.get_training(model) {
496                licenses.insert(training.license.clone());
497            }
498        }
499
500        // Get dataset licenses
501        for dataset_cid in &lineage.datasets {
502            if let Some(dataset) = self.get_dataset(dataset_cid) {
503                licenses.insert(dataset.license.clone());
504            }
505        }
506
507        Ok(licenses)
508    }
509
510    /// Check if lineage is reproducible (has all necessary metadata)
511    pub fn is_reproducible(&self, model_cid: &Cid) -> bool {
512        if let Some(training) = self.get_training(model_cid) {
513            // Check for required metadata
514            training.code_repository.is_some()
515                && training.code_commit.is_some()
516                && training.hyperparameters.learning_rate.is_some()
517                && !training.training_datasets.is_empty()
518        } else {
519            false
520        }
521    }
522}
523
524impl Default for ProvenanceGraph {
525    fn default() -> Self {
526        Self::new()
527    }
528}
529
/// Result of lineage tracing
///
/// Produced by `ProvenanceGraph::trace_lineage`.
#[derive(Debug, Clone)]
pub struct LineageTrace {
    /// Target model: the CID the trace was started from
    pub target_model: Cid,
    /// All datasets in the lineage (training and validation), each listed once
    pub datasets: Vec<Cid>,
    /// All models in the lineage (including target); the target comes first,
    /// followed by its ancestors in the order they were reached
    pub models: Vec<Cid>,
}
540
541impl LineageTrace {
542    /// Get the depth of the lineage (number of model generations)
543    pub fn depth(&self) -> usize {
544        self.models.len()
545    }
546
547    /// Get the number of unique datasets
548    pub fn dataset_count(&self) -> usize {
549        self.datasets.len()
550    }
551}
552
553// Helper functions for serializing/deserializing Vec<Cid>
554fn serialize_cid_vec<S>(cids: &[Cid], serializer: S) -> Result<S::Ok, S::Error>
555where
556    S: serde::Serializer,
557{
558    use serde::Serialize;
559    let strings: Vec<String> = cids.iter().map(|c| c.to_string()).collect();
560    strings.serialize(serializer)
561}
562
563fn deserialize_cid_vec<'de, D>(deserializer: D) -> Result<Vec<Cid>, D::Error>
564where
565    D: serde::Deserializer<'de>,
566{
567    use serde::Deserialize;
568    let strings = Vec::<String>::deserialize(deserializer)?;
569    strings
570        .into_iter()
571        .map(|s| s.parse().map_err(serde::de::Error::custom))
572        .collect()
573}
574
575fn serialize_optional_cid<S>(cid: &Option<Cid>, serializer: S) -> Result<S::Ok, S::Error>
576where
577    S: serde::Serializer,
578{
579    use serde::Serialize;
580    match cid {
581        Some(c) => Some(c.to_string()).serialize(serializer),
582        None => None::<String>.serialize(serializer),
583    }
584}
585
586fn deserialize_optional_cid<'de, D>(deserializer: D) -> Result<Option<Cid>, D::Error>
587where
588    D: serde::Deserializer<'de>,
589{
590    use serde::Deserialize;
591    let opt = Option::<String>::deserialize(deserializer)?;
592    opt.map(|s| s.parse().map_err(serde::de::Error::custom))
593        .transpose()
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the MIT-licensed "TestDataset" 1.0 record used by the graph tests.
    fn test_dataset(cid: Cid) -> DatasetProvenance {
        DatasetProvenance::new(
            cid,
            String::from("TestDataset"),
            String::from("1.0"),
            License::MIT,
        )
    }

    /// Builds a graph holding one dataset and one model trained on it, and
    /// returns it with the dataset CID and the model CID (both are
    /// `Cid::default()`).
    fn single_model_graph() -> (ProvenanceGraph, Cid, Cid) {
        let mut graph = ProvenanceGraph::new();

        let dataset_cid = Cid::default();
        graph.add_dataset(test_dataset(dataset_cid));

        let model_cid = Cid::default();
        graph.add_training(TrainingProvenance::new(
            model_cid,
            vec![dataset_cid],
            License::MIT,
        ));

        (graph, dataset_cid, model_cid)
    }

    #[test]
    fn test_attribution() {
        let base = Attribution::new(String::from("John Doe"), String::from("data provider"));
        let attr = base
            .with_contact(String::from("john@example.com"))
            .with_organization(String::from("Example Corp"));

        assert_eq!(attr.name, "John Doe");
        assert_eq!(attr.contact, Some(String::from("john@example.com")));
        assert_eq!(attr.organization, Some(String::from("Example Corp")));
    }

    #[test]
    fn test_dataset_provenance() {
        let creator = Attribution::new(String::from("Stanford"), String::from("creator"));
        let dataset = DatasetProvenance::new(
            Cid::default(),
            String::from("ImageNet"),
            String::from("1.0"),
            License::CCBY,
        )
        .add_attribution(creator)
        .add_source(String::from("https://example.com/imagenet"))
        .with_description(String::from("Large image dataset"));

        assert_eq!(dataset.name, "ImageNet");
        assert_eq!(dataset.license, License::CCBY);
        assert_eq!(dataset.attributions.len(), 1);
    }

    #[test]
    fn test_hyperparameters() {
        let hparams = Hyperparameters::new()
            .with_learning_rate(0.001)
            .with_batch_size(32)
            .with_epochs(10)
            .with_optimizer(String::from("Adam"))
            .add_param(String::from("weight_decay"), String::from("0.0001"));

        assert_eq!(hparams.learning_rate, Some(0.001));
        assert_eq!(hparams.batch_size, Some(32));
        assert_eq!(hparams.epochs, Some(10));
    }

    #[test]
    fn test_training_provenance() {
        let hparams = Hyperparameters::new()
            .with_learning_rate(0.001)
            .with_batch_size(32);
        let trainer = Attribution::new(String::from("Jane Doe"), String::from("trainer"));

        let training = TrainingProvenance::new(Cid::default(), vec![Cid::default()], License::MIT)
            .with_hyperparameters(hparams)
            .add_metric(String::from("accuracy"), 0.95)
            .add_attribution(trainer)
            .complete();

        assert_eq!(training.training_datasets.len(), 1);
        assert_eq!(training.metrics.len(), 1);
        assert!(training.completed_at.is_some());
    }

    #[test]
    fn test_provenance_graph() {
        let (graph, dataset_cid, model_cid) = single_model_graph();

        assert!(graph.get_dataset(&dataset_cid).is_some());
        assert!(graph.get_training(&model_cid).is_some());
    }

    #[test]
    fn test_lineage_tracing() {
        let (graph, _dataset_cid, model_cid) = single_model_graph();

        let lineage = graph.trace_lineage(&model_cid).unwrap();

        assert_eq!(lineage.depth(), 1);
        assert_eq!(lineage.dataset_count(), 1);
    }

    #[test]
    fn test_license_display() {
        let custom = License::Custom(String::from("Custom-1.0"));

        assert_eq!(License::MIT.to_string(), "MIT");
        assert_eq!(License::Apache2.to_string(), "Apache-2.0");
        assert_eq!(custom.to_string(), "Custom: Custom-1.0");
    }
}