1use ipfrs_core::Cid;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12use thiserror::Error;
13
/// Errors produced by provenance tracking and lineage queries.
#[derive(Debug, Error)]
pub enum ProvenanceError {
    /// No provenance record exists for the CID carried in the payload (its
    /// string form). Raised during lineage tracing when a model in the chain
    /// has no registered training record.
    #[error("Provenance record not found: {0}")]
    RecordNotFound(String),

    /// Walking a model's parent chain led back to a model already visited.
    #[error("Circular dependency detected")]
    CircularDependency,

    /// The provenance chain is malformed.
    /// NOTE(review): not constructed anywhere in this file — confirm which callers raise it.
    #[error("Invalid provenance chain")]
    InvalidChain,

    /// A required metadata field is missing; the payload presumably names the
    /// field. NOTE(review): not constructed anywhere in this file.
    #[error("Missing required metadata: {0}")]
    MissingMetadata(String),
}
29
/// License under which a dataset or a trained model is distributed.
///
/// The `Display` implementation renders the short identifier noted on each
/// variant below.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum License {
    /// MIT License (`MIT`).
    MIT,
    /// Apache License 2.0 (`Apache-2.0`).
    Apache2,
    /// GNU General Public License v3 (`GPL-3.0`).
    GPLv3,
    /// BSD 3-Clause License (`BSD-3-Clause`).
    BSD3Clause,
    /// Creative Commons Attribution (`CC-BY`); the CC version is not recorded.
    CCBY,
    /// Creative Commons Attribution-ShareAlike (`CC-BY-SA`); the CC version is not recorded.
    CCBYSA,
    /// Proprietary / closed license (`Proprietary`).
    Proprietary,
    /// Any other license, identified by free-form text (`Custom: <text>`).
    Custom(String),
    /// The license is not known (`Unknown`).
    Unknown,
}
52
53impl std::fmt::Display for License {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 match self {
56 License::MIT => write!(f, "MIT"),
57 License::Apache2 => write!(f, "Apache-2.0"),
58 License::GPLv3 => write!(f, "GPL-3.0"),
59 License::BSD3Clause => write!(f, "BSD-3-Clause"),
60 License::CCBY => write!(f, "CC-BY"),
61 License::CCBYSA => write!(f, "CC-BY-SA"),
62 License::Proprietary => write!(f, "Proprietary"),
63 License::Custom(s) => write!(f, "Custom: {}", s),
64 License::Unknown => write!(f, "Unknown"),
65 }
66 }
67}
68
/// Credit given to a person (optionally within an organization) who
/// contributed to a dataset or a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Attribution {
    /// Name of the contributor.
    pub name: String,
    /// Optional contact details (the tests use an email address).
    pub contact: Option<String>,
    /// Optional organization the contributor belongs to.
    pub organization: Option<String>,
    /// Free-form description of the contribution, e.g. `"creator"` or `"trainer"`.
    pub role: String,
    /// Unix timestamp in seconds; `Attribution::new` sets it from `chrono::Utc::now()`.
    pub timestamp: i64,
}
83
84impl Attribution {
85 pub fn new(name: String, role: String) -> Self {
87 Self {
88 name,
89 contact: None,
90 organization: None,
91 role,
92 timestamp: chrono::Utc::now().timestamp(),
93 }
94 }
95
96 pub fn with_contact(mut self, contact: String) -> Self {
98 self.contact = Some(contact);
99 self
100 }
101
102 pub fn with_organization(mut self, organization: String) -> Self {
104 self.organization = Some(organization);
105 self
106 }
107}
108
/// Provenance metadata describing a single dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetProvenance {
    /// Content identifier of the dataset. It is (de)serialized through the
    /// crate-level CID helpers (presumably as its string form — confirm in crate root).
    #[serde(serialize_with = "crate::serialize_cid")]
    #[serde(deserialize_with = "crate::deserialize_cid")]
    pub dataset_cid: Cid,

    /// Human-readable dataset name.
    pub name: String,

    /// Free-form version string, e.g. `"1.0"`.
    pub version: String,

    /// License the dataset is distributed under.
    pub license: License,

    /// People and organizations credited for the dataset.
    pub attributions: Vec<Attribution>,

    /// Free-form source locations the data came from (the tests use URLs).
    pub sources: Vec<String>,

    /// Optional prose description.
    pub description: Option<String>,

    /// Unix timestamp in seconds; `DatasetProvenance::new` sets it to the construction time.
    pub created_at: i64,
}
138
139impl DatasetProvenance {
140 pub fn new(dataset_cid: Cid, name: String, version: String, license: License) -> Self {
142 Self {
143 dataset_cid,
144 name,
145 version,
146 license,
147 attributions: Vec::new(),
148 sources: Vec::new(),
149 description: None,
150 created_at: chrono::Utc::now().timestamp(),
151 }
152 }
153
154 pub fn add_attribution(mut self, attribution: Attribution) -> Self {
156 self.attributions.push(attribution);
157 self
158 }
159
160 pub fn add_source(mut self, source: String) -> Self {
162 self.sources.push(source);
163 self
164 }
165
166 pub fn with_description(mut self, description: String) -> Self {
168 self.description = Some(description);
169 self
170 }
171}
172
/// Training hyperparameters.
///
/// Well-known settings are typed and optional. Anything else lives in `custom`
/// as string key/value pairs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hyperparameters {
    /// Learning rate, if recorded.
    pub learning_rate: Option<f32>,
    /// Batch size, if recorded.
    pub batch_size: Option<usize>,
    /// Number of training epochs, if recorded.
    pub epochs: Option<usize>,
    /// Optimizer name (e.g. `"Adam"`), if recorded.
    pub optimizer: Option<String>,
    /// Additional free-form parameters; values are stored as strings.
    pub custom: HashMap<String, String>,
}
187
188impl Hyperparameters {
189 pub fn new() -> Self {
191 Self {
192 learning_rate: None,
193 batch_size: None,
194 epochs: None,
195 optimizer: None,
196 custom: HashMap::new(),
197 }
198 }
199
200 pub fn with_learning_rate(mut self, lr: f32) -> Self {
202 self.learning_rate = Some(lr);
203 self
204 }
205
206 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
208 self.batch_size = Some(batch_size);
209 self
210 }
211
212 pub fn with_epochs(mut self, epochs: usize) -> Self {
214 self.epochs = Some(epochs);
215 self
216 }
217
218 pub fn with_optimizer(mut self, optimizer: String) -> Self {
220 self.optimizer = Some(optimizer);
221 self
222 }
223
224 pub fn add_param(mut self, key: String, value: String) -> Self {
226 self.custom.insert(key, value);
227 self
228 }
229}
230
231impl Default for Hyperparameters {
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
/// Provenance metadata for one training run that produced a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingProvenance {
    /// CID of the produced model. It is (de)serialized through the crate-level
    /// CID helpers (presumably as a string — confirm in crate root).
    #[serde(serialize_with = "crate::serialize_cid")]
    #[serde(deserialize_with = "crate::deserialize_cid")]
    pub model_cid: Cid,

    /// CID of the model this one was derived from (e.g. fine-tuned from), if
    /// any. It is serialized as an optional string.
    #[serde(serialize_with = "serialize_optional_cid")]
    #[serde(deserialize_with = "deserialize_optional_cid")]
    pub parent_model: Option<Cid>,

    /// CIDs of the datasets trained on, serialized as a list of strings.
    #[serde(serialize_with = "serialize_cid_vec")]
    #[serde(deserialize_with = "deserialize_cid_vec")]
    pub training_datasets: Vec<Cid>,

    /// CIDs of the datasets used for validation, serialized as a list of strings.
    #[serde(serialize_with = "serialize_cid_vec")]
    #[serde(deserialize_with = "deserialize_cid_vec")]
    pub validation_datasets: Vec<Cid>,

    /// Hyperparameters used for the run.
    pub hyperparameters: Hyperparameters,

    /// Named result metrics (e.g. `"accuracy"`).
    pub metrics: HashMap<String, f32>,

    /// People and organizations credited for this training run.
    pub attributions: Vec<Attribution>,

    /// License the produced model is distributed under.
    pub license: License,

    /// Unix timestamp in seconds; `TrainingProvenance::new` sets it to the construction time.
    pub started_at: i64,

    /// Unix timestamp in seconds set by `complete()`; `None` while the run is unfinished.
    pub completed_at: Option<i64>,

    /// Location of the training code (free-form, e.g. a repository URL).
    pub code_repository: Option<String>,

    /// Commit of `code_repository` that was used.
    pub code_commit: Option<String>,

    /// Free-form description of the hardware used.
    pub hardware: Option<String>,

    /// Free-form name of the ML framework used.
    pub framework: Option<String>,
}
290
291impl TrainingProvenance {
292 pub fn new(model_cid: Cid, training_datasets: Vec<Cid>, license: License) -> Self {
294 Self {
295 model_cid,
296 parent_model: None,
297 training_datasets,
298 validation_datasets: Vec::new(),
299 hyperparameters: Hyperparameters::new(),
300 metrics: HashMap::new(),
301 attributions: Vec::new(),
302 license,
303 started_at: chrono::Utc::now().timestamp(),
304 completed_at: None,
305 code_repository: None,
306 code_commit: None,
307 hardware: None,
308 framework: None,
309 }
310 }
311
312 pub fn with_parent(mut self, parent_cid: Cid) -> Self {
314 self.parent_model = Some(parent_cid);
315 self
316 }
317
318 pub fn add_validation_dataset(mut self, dataset_cid: Cid) -> Self {
320 self.validation_datasets.push(dataset_cid);
321 self
322 }
323
324 pub fn with_hyperparameters(mut self, hyperparameters: Hyperparameters) -> Self {
326 self.hyperparameters = hyperparameters;
327 self
328 }
329
330 pub fn add_metric(mut self, name: String, value: f32) -> Self {
332 self.metrics.insert(name, value);
333 self
334 }
335
336 pub fn add_attribution(mut self, attribution: Attribution) -> Self {
338 self.attributions.push(attribution);
339 self
340 }
341
342 pub fn complete(mut self) -> Self {
344 self.completed_at = Some(chrono::Utc::now().timestamp());
345 self
346 }
347
348 pub fn with_code_repository(mut self, repo: String, commit: String) -> Self {
350 self.code_repository = Some(repo);
351 self.code_commit = Some(commit);
352 self
353 }
354
355 pub fn with_hardware(mut self, hardware: String) -> Self {
357 self.hardware = Some(hardware);
358 self
359 }
360
361 pub fn with_framework(mut self, framework: String) -> Self {
363 self.framework = Some(framework);
364 self
365 }
366}
367
368#[derive(Debug, Clone)]
370pub struct ProvenanceGraph {
371 datasets: HashMap<String, DatasetProvenance>,
373
374 training_records: HashMap<String, TrainingProvenance>,
376}
377
378impl ProvenanceGraph {
379 pub fn new() -> Self {
381 Self {
382 datasets: HashMap::new(),
383 training_records: HashMap::new(),
384 }
385 }
386
387 pub fn add_dataset(&mut self, provenance: DatasetProvenance) {
389 self.datasets
390 .insert(provenance.dataset_cid.to_string(), provenance);
391 }
392
393 pub fn add_training(&mut self, provenance: TrainingProvenance) {
395 self.training_records
396 .insert(provenance.model_cid.to_string(), provenance);
397 }
398
399 pub fn get_dataset(&self, dataset_cid: &Cid) -> Option<&DatasetProvenance> {
401 self.datasets.get(&dataset_cid.to_string())
402 }
403
404 pub fn get_training(&self, model_cid: &Cid) -> Option<&TrainingProvenance> {
406 self.training_records.get(&model_cid.to_string())
407 }
408
409 pub fn trace_lineage(&self, model_cid: &Cid) -> Result<LineageTrace, ProvenanceError> {
411 let mut visited = HashSet::new();
412 let mut datasets = Vec::new();
413 let mut models = Vec::new();
414
415 self.trace_recursive(model_cid, &mut visited, &mut datasets, &mut models)?;
416
417 Ok(LineageTrace {
418 target_model: *model_cid,
419 datasets,
420 models,
421 })
422 }
423
424 fn trace_recursive(
426 &self,
427 model_cid: &Cid,
428 visited: &mut HashSet<Cid>,
429 datasets: &mut Vec<Cid>,
430 models: &mut Vec<Cid>,
431 ) -> Result<(), ProvenanceError> {
432 if visited.contains(model_cid) {
433 return Err(ProvenanceError::CircularDependency);
434 }
435
436 visited.insert(*model_cid);
437
438 let training = self
439 .get_training(model_cid)
440 .ok_or_else(|| ProvenanceError::RecordNotFound(model_cid.to_string()))?;
441
442 models.push(*model_cid);
443
444 for dataset_cid in &training.training_datasets {
446 if !datasets.contains(dataset_cid) {
447 datasets.push(*dataset_cid);
448 }
449 }
450
451 for dataset_cid in &training.validation_datasets {
452 if !datasets.contains(dataset_cid) {
453 datasets.push(*dataset_cid);
454 }
455 }
456
457 if let Some(parent_cid) = training.parent_model {
459 self.trace_recursive(&parent_cid, visited, datasets, models)?;
460 }
461
462 Ok(())
463 }
464
465 pub fn get_all_attributions(
467 &self,
468 model_cid: &Cid,
469 ) -> Result<Vec<Attribution>, ProvenanceError> {
470 let lineage = self.trace_lineage(model_cid)?;
471 let mut attributions = Vec::new();
472
473 if let Some(training) = self.get_training(model_cid) {
475 attributions.extend(training.attributions.clone());
476 }
477
478 for dataset_cid in &lineage.datasets {
480 if let Some(dataset) = self.get_dataset(dataset_cid) {
481 attributions.extend(dataset.attributions.clone());
482 }
483 }
484
485 Ok(attributions)
486 }
487
488 pub fn get_all_licenses(&self, model_cid: &Cid) -> Result<HashSet<License>, ProvenanceError> {
490 let lineage = self.trace_lineage(model_cid)?;
491 let mut licenses = HashSet::new();
492
493 for model in &lineage.models {
495 if let Some(training) = self.get_training(model) {
496 licenses.insert(training.license.clone());
497 }
498 }
499
500 for dataset_cid in &lineage.datasets {
502 if let Some(dataset) = self.get_dataset(dataset_cid) {
503 licenses.insert(dataset.license.clone());
504 }
505 }
506
507 Ok(licenses)
508 }
509
510 pub fn is_reproducible(&self, model_cid: &Cid) -> bool {
512 if let Some(training) = self.get_training(model_cid) {
513 training.code_repository.is_some()
515 && training.code_commit.is_some()
516 && training.hyperparameters.learning_rate.is_some()
517 && !training.training_datasets.is_empty()
518 } else {
519 false
520 }
521 }
522}
523
524impl Default for ProvenanceGraph {
525 fn default() -> Self {
526 Self::new()
527 }
528}
529
/// Result of tracing a model's lineage through its parent models.
#[derive(Debug, Clone)]
pub struct LineageTrace {
    /// The model whose lineage was traced.
    pub target_model: Cid,
    /// Every dataset (training and validation) used anywhere in the lineage,
    /// deduplicated and in first-seen order.
    pub datasets: Vec<Cid>,
    /// Models in the lineage: the target first, then each successive parent.
    pub models: Vec<Cid>,
}
540
541impl LineageTrace {
542 pub fn depth(&self) -> usize {
544 self.models.len()
545 }
546
547 pub fn dataset_count(&self) -> usize {
549 self.datasets.len()
550 }
551}
552
553fn serialize_cid_vec<S>(cids: &[Cid], serializer: S) -> Result<S::Ok, S::Error>
555where
556 S: serde::Serializer,
557{
558 use serde::Serialize;
559 let strings: Vec<String> = cids.iter().map(|c| c.to_string()).collect();
560 strings.serialize(serializer)
561}
562
563fn deserialize_cid_vec<'de, D>(deserializer: D) -> Result<Vec<Cid>, D::Error>
564where
565 D: serde::Deserializer<'de>,
566{
567 use serde::Deserialize;
568 let strings = Vec::<String>::deserialize(deserializer)?;
569 strings
570 .into_iter()
571 .map(|s| s.parse().map_err(serde::de::Error::custom))
572 .collect()
573}
574
575fn serialize_optional_cid<S>(cid: &Option<Cid>, serializer: S) -> Result<S::Ok, S::Error>
576where
577 S: serde::Serializer,
578{
579 use serde::Serialize;
580 match cid {
581 Some(c) => Some(c.to_string()).serialize(serializer),
582 None => None::<String>.serialize(serializer),
583 }
584}
585
586fn deserialize_optional_cid<'de, D>(deserializer: D) -> Result<Option<Cid>, D::Error>
587where
588 D: serde::Deserializer<'de>,
589{
590 use serde::Deserialize;
591 let opt = Option::<String>::deserialize(deserializer)?;
592 opt.map(|s| s.parse().map_err(serde::de::Error::custom))
593 .transpose()
594}
595
#[cfg(test)]
mod tests {
    //! Unit tests for the provenance builders and graph queries, including the
    //! error paths of lineage tracing.
    use super::*;

    /// Builder methods on `Attribution` populate the optional fields.
    #[test]
    fn test_attribution() {
        let attr = Attribution::new("John Doe".to_string(), "data provider".to_string())
            .with_contact("john@example.com".to_string())
            .with_organization("Example Corp".to_string());

        assert_eq!(attr.name, "John Doe");
        assert_eq!(attr.contact, Some("john@example.com".to_string()));
        assert_eq!(attr.organization, Some("Example Corp".to_string()));
    }

    /// Builder methods on `DatasetProvenance` accumulate metadata.
    #[test]
    fn test_dataset_provenance() {
        let dataset = DatasetProvenance::new(
            Cid::default(),
            "ImageNet".to_string(),
            "1.0".to_string(),
            License::CCBY,
        )
        .add_attribution(Attribution::new(
            "Stanford".to_string(),
            "creator".to_string(),
        ))
        .add_source("https://example.com/imagenet".to_string())
        .with_description("Large image dataset".to_string());

        assert_eq!(dataset.name, "ImageNet");
        assert_eq!(dataset.license, License::CCBY);
        assert_eq!(dataset.attributions.len(), 1);
        assert_eq!(dataset.sources.len(), 1);
        assert!(dataset.description.is_some());
    }

    /// Typed hyperparameters and custom entries are all recorded.
    #[test]
    fn test_hyperparameters() {
        let hparams = Hyperparameters::new()
            .with_learning_rate(0.001)
            .with_batch_size(32)
            .with_epochs(10)
            .with_optimizer("Adam".to_string())
            .add_param("weight_decay".to_string(), "0.0001".to_string());

        assert_eq!(hparams.learning_rate, Some(0.001));
        assert_eq!(hparams.batch_size, Some(32));
        assert_eq!(hparams.epochs, Some(10));
        assert_eq!(hparams.custom.len(), 1);
    }

    /// Training records accumulate metrics and are stamped on completion.
    #[test]
    fn test_training_provenance() {
        let training = TrainingProvenance::new(Cid::default(), vec![Cid::default()], License::MIT)
            .with_hyperparameters(
                Hyperparameters::new()
                    .with_learning_rate(0.001)
                    .with_batch_size(32),
            )
            .add_metric("accuracy".to_string(), 0.95)
            .add_attribution(Attribution::new(
                "Jane Doe".to_string(),
                "trainer".to_string(),
            ))
            .complete();

        assert_eq!(training.training_datasets.len(), 1);
        assert_eq!(training.metrics.len(), 1);
        assert!(training.completed_at.is_some());
    }

    /// Records added to the graph can be looked up by CID.
    #[test]
    fn test_provenance_graph() {
        let mut graph = ProvenanceGraph::new();

        let dataset_cid = Cid::default();
        let dataset = DatasetProvenance::new(
            dataset_cid,
            "TestDataset".to_string(),
            "1.0".to_string(),
            License::MIT,
        );

        graph.add_dataset(dataset);

        let model_cid = Cid::default();
        let training = TrainingProvenance::new(model_cid, vec![dataset_cid], License::MIT);

        graph.add_training(training);

        assert!(graph.get_dataset(&dataset_cid).is_some());
        assert!(graph.get_training(&model_cid).is_some());
    }

    /// A parentless model has a lineage of depth one covering its datasets.
    #[test]
    fn test_lineage_tracing() {
        let mut graph = ProvenanceGraph::new();

        let dataset_cid = Cid::default();
        let dataset = DatasetProvenance::new(
            dataset_cid,
            "TestDataset".to_string(),
            "1.0".to_string(),
            License::MIT,
        );
        graph.add_dataset(dataset);

        let model_cid = Cid::default();
        let training = TrainingProvenance::new(model_cid, vec![dataset_cid], License::MIT);
        graph.add_training(training);

        let lineage = graph.trace_lineage(&model_cid).unwrap();

        assert_eq!(lineage.depth(), 1);
        assert_eq!(lineage.dataset_count(), 1);
    }

    /// A model naming itself as its parent is reported as a cycle instead of
    /// being walked forever.
    #[test]
    fn test_circular_lineage_is_rejected() {
        let mut graph = ProvenanceGraph::new();
        let model_cid = Cid::default();
        graph.add_training(
            TrainingProvenance::new(model_cid, Vec::new(), License::MIT).with_parent(model_cid),
        );

        assert!(matches!(
            graph.trace_lineage(&model_cid),
            Err(ProvenanceError::CircularDependency)
        ));
    }

    /// Tracing an unknown model fails with `RecordNotFound`, and an unknown
    /// model is never considered reproducible.
    #[test]
    fn test_missing_record() {
        let graph = ProvenanceGraph::new();

        assert!(matches!(
            graph.trace_lineage(&Cid::default()),
            Err(ProvenanceError::RecordNotFound(_))
        ));
        assert!(!graph.is_reproducible(&Cid::default()));
    }

    /// Reproducibility requires code location, commit, learning rate and data.
    /// Licenses are gathered from both models and datasets.
    #[test]
    fn test_reproducibility_and_licenses() {
        let mut graph = ProvenanceGraph::new();

        let dataset_cid = Cid::default();
        graph.add_dataset(DatasetProvenance::new(
            dataset_cid,
            "TestDataset".to_string(),
            "1.0".to_string(),
            License::CCBY,
        ));

        let model_cid = Cid::default();
        graph.add_training(
            TrainingProvenance::new(model_cid, vec![dataset_cid], License::MIT)
                .with_hyperparameters(Hyperparameters::new().with_learning_rate(0.01))
                .with_code_repository(
                    "https://example.com/repo".to_string(),
                    "abc123".to_string(),
                ),
        );

        assert!(graph.is_reproducible(&model_cid));

        let licenses = graph.get_all_licenses(&model_cid).unwrap();
        assert_eq!(licenses.len(), 2);
        assert!(licenses.contains(&License::MIT));
        assert!(licenses.contains(&License::CCBY));
    }

    /// `Display` renders the expected short identifiers.
    #[test]
    fn test_license_display() {
        assert_eq!(License::MIT.to_string(), "MIT");
        assert_eq!(License::Apache2.to_string(), "Apache-2.0");
        assert_eq!(License::GPLv3.to_string(), "GPL-3.0");
        assert_eq!(License::BSD3Clause.to_string(), "BSD-3-Clause");
        assert_eq!(License::CCBYSA.to_string(), "CC-BY-SA");
        assert_eq!(
            License::Custom("Custom-1.0".to_string()).to_string(),
            "Custom: Custom-1.0"
        );
    }
}
720}