use ipfrs_core::Cid;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ProvenanceError {
#[error("Provenance record not found: {0}")]
RecordNotFound(String),
#[error("Circular dependency detected")]
CircularDependency,
#[error("Invalid provenance chain")]
InvalidChain,
#[error("Missing required metadata: {0}")]
MissingMetadata(String),
}
/// License attached to a dataset or a trained model.
///
/// The `Display` implementation in this file renders each variant as the
/// short label noted below.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum License {
    /// Displayed as `MIT`.
    MIT,
    /// Displayed as `Apache-2.0`.
    Apache2,
    /// Displayed as `GPL-3.0`.
    GPLv3,
    /// Displayed as `BSD-3-Clause`.
    BSD3Clause,
    /// Displayed as `CC-BY`.
    CCBY,
    /// Displayed as `CC-BY-SA`.
    CCBYSA,
    /// Displayed as `Proprietary`.
    Proprietary,
    /// Any other license, identified by a free-form string. Displayed as
    /// `Custom: <string>`.
    Custom(String),
    /// License not known. Displayed as `Unknown`.
    Unknown,
}
impl std::fmt::Display for License {
    /// Writes the short textual label of the license.
    ///
    /// Fixed variants map to a constant label; `Custom` is written as
    /// `Custom: ` followed by the wrapped string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            // The only variant that carries data is formatted on the spot.
            License::Custom(text) => return write!(f, "Custom: {}", text),
            License::MIT => "MIT",
            License::Apache2 => "Apache-2.0",
            License::GPLv3 => "GPL-3.0",
            License::BSD3Clause => "BSD-3-Clause",
            License::CCBY => "CC-BY",
            License::CCBYSA => "CC-BY-SA",
            License::Proprietary => "Proprietary",
            License::Unknown => "Unknown",
        };
        f.write_str(label)
    }
}
/// Credit given to a person or entity that contributed to a dataset or model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Attribution {
    /// Name of the credited party.
    pub name: String,
    /// Optional contact string (the tests in this file use an e-mail address).
    pub contact: Option<String>,
    /// Optional organization the party belongs to.
    pub organization: Option<String>,
    /// Free-form role label (e.g. "data provider", "creator", "trainer" in
    /// this file's tests).
    pub role: String,
    /// Time the attribution was created; `Attribution::new` fills it from
    /// `chrono::Utc::now().timestamp()`.
    pub timestamp: i64,
}
impl Attribution {
    /// Builds an attribution for `name` acting as `role`.
    ///
    /// `contact` and `organization` start out as `None`; `timestamp` is
    /// captured from `chrono::Utc::now().timestamp()` at call time.
    pub fn new(name: String, role: String) -> Self {
        let timestamp = chrono::Utc::now().timestamp();
        Self {
            name,
            role,
            timestamp,
            contact: None,
            organization: None,
        }
    }

    /// Returns the attribution with `contact` set, replacing any earlier value.
    pub fn with_contact(self, contact: String) -> Self {
        Self {
            contact: Some(contact),
            ..self
        }
    }

    /// Returns the attribution with `organization` set, replacing any earlier
    /// value.
    pub fn with_organization(self, organization: String) -> Self {
        Self {
            organization: Some(organization),
            ..self
        }
    }
}
/// Provenance record describing a single dataset.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DatasetProvenance {
    /// CID of the dataset; (de)serialized through the crate-level CID helpers.
    #[serde(
        serialize_with = "crate::serialize_cid",
        deserialize_with = "crate::deserialize_cid"
    )]
    pub dataset_cid: Cid,
    /// Name of the dataset.
    pub name: String,
    /// Version label of the dataset.
    pub version: String,
    /// License the dataset is distributed under.
    pub license: License,
    /// Parties credited for the dataset.
    pub attributions: Vec<Attribution>,
    /// Free-form source strings (the tests in this file pass a URL).
    pub sources: Vec<String>,
    /// Optional description of the dataset.
    pub description: Option<String>,
    /// Time the record was created; `DatasetProvenance::new` fills it from
    /// `chrono::Utc::now().timestamp()`.
    pub created_at: i64,
}
impl DatasetProvenance {
    /// Creates a record for the dataset identified by `dataset_cid`.
    ///
    /// Attributions and sources start empty, the description is `None`, and
    /// `created_at` is captured from `chrono::Utc::now().timestamp()`.
    pub fn new(dataset_cid: Cid, name: String, version: String, license: License) -> Self {
        let created_at = chrono::Utc::now().timestamp();
        Self {
            dataset_cid,
            name,
            version,
            license,
            created_at,
            description: None,
            sources: Vec::new(),
            attributions: Vec::new(),
        }
    }

    /// Appends `attribution` to the record and returns it.
    pub fn add_attribution(self, attribution: Attribution) -> Self {
        let mut attributions = self.attributions;
        attributions.push(attribution);
        Self {
            attributions,
            ..self
        }
    }

    /// Appends `source` to the list of sources and returns the record.
    pub fn add_source(self, source: String) -> Self {
        let mut sources = self.sources;
        sources.push(source);
        Self { sources, ..self }
    }

    /// Sets the description, replacing any earlier value.
    pub fn with_description(self, description: String) -> Self {
        Self {
            description: Some(description),
            ..self
        }
    }
}
/// Hyperparameters recorded for a training run.
///
/// Every well-known setting is optional; anything else goes into `custom`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Hyperparameters {
    /// Learning rate, if recorded.
    pub learning_rate: Option<f32>,
    /// Batch size, if recorded.
    pub batch_size: Option<usize>,
    /// Number of epochs, if recorded.
    pub epochs: Option<usize>,
    /// Optimizer name, if recorded (e.g. "Adam" in this file's tests).
    pub optimizer: Option<String>,
    /// Additional parameters as string key/value pairs.
    pub custom: HashMap<String, String>,
}
impl Hyperparameters {
    /// Creates an empty set of hyperparameters: every optional field is
    /// `None` and `custom` is an empty map.
    pub fn new() -> Self {
        Self {
            custom: HashMap::new(),
            optimizer: None,
            epochs: None,
            batch_size: None,
            learning_rate: None,
        }
    }

    /// Sets the learning rate.
    pub fn with_learning_rate(self, lr: f32) -> Self {
        Self {
            learning_rate: Some(lr),
            ..self
        }
    }

    /// Sets the batch size.
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self {
            batch_size: Some(batch_size),
            ..self
        }
    }

    /// Sets the number of epochs.
    pub fn with_epochs(self, epochs: usize) -> Self {
        Self {
            epochs: Some(epochs),
            ..self
        }
    }

    /// Sets the optimizer name.
    pub fn with_optimizer(self, optimizer: String) -> Self {
        Self {
            optimizer: Some(optimizer),
            ..self
        }
    }

    /// Records a custom parameter; an existing entry under `key` is
    /// overwritten.
    pub fn add_param(self, key: String, value: String) -> Self {
        let mut custom = self.custom;
        custom.insert(key, value);
        Self { custom, ..self }
    }
}
impl Default for Hyperparameters {
fn default() -> Self {
Self::new()
}
}
/// Provenance record describing how a model was trained.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingProvenance {
    /// CID of the trained model; (de)serialized through the crate-level CID
    /// helpers.
    #[serde(serialize_with = "crate::serialize_cid")]
    #[serde(deserialize_with = "crate::deserialize_cid")]
    pub model_cid: Cid,
    /// CID of the model this one was derived from, if any.
    ///
    /// `deserialize_with` turns off serde's built-in "missing `Option` field
    /// becomes `None`" behaviour, so `default` is required for records that
    /// omit the field to deserialize (as they do for every other optional
    /// field of this struct).
    #[serde(default)]
    #[serde(serialize_with = "serialize_optional_cid")]
    #[serde(deserialize_with = "deserialize_optional_cid")]
    pub parent_model: Option<Cid>,
    /// CIDs of the datasets used for training.
    #[serde(serialize_with = "serialize_cid_vec")]
    #[serde(deserialize_with = "deserialize_cid_vec")]
    pub training_datasets: Vec<Cid>,
    /// CIDs of the datasets used for validation.
    #[serde(serialize_with = "serialize_cid_vec")]
    #[serde(deserialize_with = "deserialize_cid_vec")]
    pub validation_datasets: Vec<Cid>,
    /// Hyperparameters of the run.
    pub hyperparameters: Hyperparameters,
    /// Recorded metrics, keyed by metric name.
    pub metrics: HashMap<String, f32>,
    /// Parties credited for the training run.
    pub attributions: Vec<Attribution>,
    /// License of the resulting model.
    pub license: License,
    /// Time the record was created; `TrainingProvenance::new` fills it from
    /// `chrono::Utc::now().timestamp()`.
    pub started_at: i64,
    /// Completion time, set by `TrainingProvenance::complete`; `None` until
    /// then.
    pub completed_at: Option<i64>,
    /// Repository holding the training code, if recorded.
    pub code_repository: Option<String>,
    /// Commit of the training code, if recorded.
    pub code_commit: Option<String>,
    /// Free-form description of the hardware used, if recorded.
    pub hardware: Option<String>,
    /// Free-form name of the framework used, if recorded.
    pub framework: Option<String>,
}
impl TrainingProvenance {
    /// Creates a training record for `model_cid`.
    ///
    /// Only the training datasets and the license are taken from the caller.
    /// Every other collection starts empty, every optional field is `None`,
    /// the hyperparameters are `Hyperparameters::new()`, and `started_at` is
    /// captured from `chrono::Utc::now().timestamp()`.
    pub fn new(model_cid: Cid, training_datasets: Vec<Cid>, license: License) -> Self {
        let started_at = chrono::Utc::now().timestamp();
        Self {
            model_cid,
            training_datasets,
            license,
            started_at,
            parent_model: None,
            completed_at: None,
            code_repository: None,
            code_commit: None,
            hardware: None,
            framework: None,
            validation_datasets: Vec::new(),
            attributions: Vec::new(),
            metrics: HashMap::new(),
            hyperparameters: Hyperparameters::new(),
        }
    }

    /// Records `parent_cid` as the model this one was derived from.
    pub fn with_parent(self, parent_cid: Cid) -> Self {
        Self {
            parent_model: Some(parent_cid),
            ..self
        }
    }

    /// Appends a validation dataset.
    pub fn add_validation_dataset(self, dataset_cid: Cid) -> Self {
        let mut validation_datasets = self.validation_datasets;
        validation_datasets.push(dataset_cid);
        Self {
            validation_datasets,
            ..self
        }
    }

    /// Replaces the hyperparameters wholesale.
    pub fn with_hyperparameters(self, hyperparameters: Hyperparameters) -> Self {
        Self {
            hyperparameters,
            ..self
        }
    }

    /// Records a metric; an existing entry under `name` is overwritten.
    pub fn add_metric(self, name: String, value: f32) -> Self {
        let mut metrics = self.metrics;
        metrics.insert(name, value);
        Self { metrics, ..self }
    }

    /// Appends an attribution.
    pub fn add_attribution(self, attribution: Attribution) -> Self {
        let mut attributions = self.attributions;
        attributions.push(attribution);
        Self {
            attributions,
            ..self
        }
    }

    /// Marks the run as finished by stamping `completed_at` with
    /// `chrono::Utc::now().timestamp()`.
    pub fn complete(self) -> Self {
        let finished = chrono::Utc::now().timestamp();
        Self {
            completed_at: Some(finished),
            ..self
        }
    }

    /// Records the code repository together with the commit that was used.
    pub fn with_code_repository(self, repo: String, commit: String) -> Self {
        Self {
            code_repository: Some(repo),
            code_commit: Some(commit),
            ..self
        }
    }

    /// Records the hardware description.
    pub fn with_hardware(self, hardware: String) -> Self {
        Self {
            hardware: Some(hardware),
            ..self
        }
    }

    /// Records the framework name.
    pub fn with_framework(self, framework: String) -> Self {
        Self {
            framework: Some(framework),
            ..self
        }
    }
}
/// In-memory store of dataset and training provenance records.
///
/// Both maps are keyed by the string form (`to_string()`) of the record's CID.
#[derive(Clone, Debug)]
pub struct ProvenanceGraph {
    /// Dataset records, keyed by the dataset CID string.
    datasets: HashMap<String, DatasetProvenance>,
    /// Training records, keyed by the model CID string.
    training_records: HashMap<String, TrainingProvenance>,
}
impl ProvenanceGraph {
    /// Creates an empty graph with no dataset or training records.
    pub fn new() -> Self {
        Self {
            datasets: HashMap::new(),
            training_records: HashMap::new(),
        }
    }

    /// Stores a dataset record, keyed by the string form of its CID.
    ///
    /// A record already stored under the same CID is replaced.
    pub fn add_dataset(&mut self, provenance: DatasetProvenance) {
        self.datasets
            .insert(provenance.dataset_cid.to_string(), provenance);
    }

    /// Stores a training record, keyed by the string form of its model CID.
    ///
    /// A record already stored under the same CID is replaced.
    pub fn add_training(&mut self, provenance: TrainingProvenance) {
        self.training_records
            .insert(provenance.model_cid.to_string(), provenance);
    }

    /// Looks up the dataset record for `dataset_cid`.
    pub fn get_dataset(&self, dataset_cid: &Cid) -> Option<&DatasetProvenance> {
        self.datasets.get(&dataset_cid.to_string())
    }

    /// Looks up the training record for `model_cid`.
    pub fn get_training(&self, model_cid: &Cid) -> Option<&TrainingProvenance> {
        self.training_records.get(&model_cid.to_string())
    }

    /// Walks the parent chain of `model_cid` and collects its lineage.
    ///
    /// The returned trace lists the visited models (target first, then each
    /// ancestor in order) and every distinct training/validation dataset
    /// referenced by those models, in first-seen order.
    ///
    /// # Errors
    ///
    /// * [`ProvenanceError::RecordNotFound`] if the target or any ancestor has
    ///   no training record in the graph.
    /// * [`ProvenanceError::CircularDependency`] if a model is reached twice.
    pub fn trace_lineage(&self, model_cid: &Cid) -> Result<LineageTrace, ProvenanceError> {
        let mut visited = HashSet::new();
        let mut datasets = Vec::new();
        let mut models = Vec::new();
        self.trace_recursive(model_cid, &mut visited, &mut datasets, &mut models)?;
        Ok(LineageTrace {
            target_model: *model_cid,
            datasets,
            models,
        })
    }

    /// Recursive step of [`Self::trace_lineage`]: records `model_cid`, its
    /// datasets, and then recurses into its parent (if any).
    fn trace_recursive(
        &self,
        model_cid: &Cid,
        visited: &mut HashSet<Cid>,
        datasets: &mut Vec<Cid>,
        models: &mut Vec<Cid>,
    ) -> Result<(), ProvenanceError> {
        // `insert` returns false when the CID was already present, i.e. the
        // parent chain has looped back onto itself.
        if !visited.insert(*model_cid) {
            return Err(ProvenanceError::CircularDependency);
        }

        let training = self
            .get_training(model_cid)
            .ok_or_else(|| ProvenanceError::RecordNotFound(model_cid.to_string()))?;
        models.push(*model_cid);

        // Training datasets first, then validation datasets; duplicates are
        // skipped so each dataset appears once in first-seen order.
        let referenced = training
            .training_datasets
            .iter()
            .chain(&training.validation_datasets);
        for dataset_cid in referenced {
            if !datasets.contains(dataset_cid) {
                datasets.push(*dataset_cid);
            }
        }

        if let Some(parent_cid) = training.parent_model {
            self.trace_recursive(&parent_cid, visited, datasets, models)?;
        }
        Ok(())
    }

    /// Collects the attributions of every model and dataset in the lineage of
    /// `model_cid`.
    ///
    /// Model attributions come first (target model, then its ancestors),
    /// followed by the attributions of each dataset in the lineage. Datasets
    /// that have no record in the graph contribute nothing.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::trace_lineage`].
    pub fn get_all_attributions(
        &self,
        model_cid: &Cid,
    ) -> Result<Vec<Attribution>, ProvenanceError> {
        let lineage = self.trace_lineage(model_cid)?;
        let mut attributions = Vec::new();

        // Credit every model in the chain, not just the target: ancestors
        // contributed to the final weights just like their datasets did.
        for ancestor_cid in &lineage.models {
            if let Some(training) = self.get_training(ancestor_cid) {
                attributions.extend(training.attributions.iter().cloned());
            }
        }
        for dataset_cid in &lineage.datasets {
            if let Some(dataset) = self.get_dataset(dataset_cid) {
                attributions.extend(dataset.attributions.iter().cloned());
            }
        }
        Ok(attributions)
    }

    /// Collects the distinct licenses of every model and dataset in the
    /// lineage of `model_cid`.
    ///
    /// Datasets that have no record in the graph contribute nothing.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::trace_lineage`].
    pub fn get_all_licenses(&self, model_cid: &Cid) -> Result<HashSet<License>, ProvenanceError> {
        let lineage = self.trace_lineage(model_cid)?;
        let mut licenses = HashSet::new();

        for ancestor_cid in &lineage.models {
            if let Some(training) = self.get_training(ancestor_cid) {
                licenses.insert(training.license.clone());
            }
        }
        for dataset_cid in &lineage.datasets {
            if let Some(dataset) = self.get_dataset(dataset_cid) {
                licenses.insert(dataset.license.clone());
            }
        }
        Ok(licenses)
    }

    /// Reports whether the training record for `model_cid` carries enough
    /// information to be considered reproducible.
    ///
    /// That requires a code repository, a code commit, a learning rate, and at
    /// least one training dataset. Returns `false` when the model has no
    /// training record.
    pub fn is_reproducible(&self, model_cid: &Cid) -> bool {
        match self.get_training(model_cid) {
            Some(training) => {
                training.code_repository.is_some()
                    && training.code_commit.is_some()
                    && training.hyperparameters.learning_rate.is_some()
                    && !training.training_datasets.is_empty()
            }
            None => false,
        }
    }
}
impl Default for ProvenanceGraph {
fn default() -> Self {
Self::new()
}
}
/// Result of tracing a model's lineage through a `ProvenanceGraph`.
#[derive(Clone, Debug)]
pub struct LineageTrace {
    /// The model the trace was started from.
    pub target_model: Cid,
    /// Distinct datasets referenced by the traced models, in first-seen order.
    pub datasets: Vec<Cid>,
    /// Models visited by the trace: the target first, followed by its
    /// ancestors.
    pub models: Vec<Cid>,
}
impl LineageTrace {
    /// Number of models in the trace (the target model counts as one).
    pub fn depth(&self) -> usize {
        let Self { models, .. } = self;
        models.len()
    }

    /// Number of datasets in the trace.
    pub fn dataset_count(&self) -> usize {
        let Self { datasets, .. } = self;
        datasets.len()
    }
}
/// Serde `serialize_with` helper: writes a list of CIDs as a sequence of
/// their string forms.
fn serialize_cid_vec<S>(cids: &[Cid], serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    // `collect_seq` is what `Vec<String>::serialize` delegates to, so the
    // emitted sequence is the same as serializing a vector of the strings.
    serializer.collect_seq(cids.iter().map(|cid| cid.to_string()))
}
fn deserialize_cid_vec<'de, D>(deserializer: D) -> Result<Vec<Cid>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
let strings = Vec::<String>::deserialize(deserializer)?;
strings
.into_iter()
.map(|s| s.parse().map_err(serde::de::Error::custom))
.collect()
}
/// Serde `serialize_with` helper: writes an optional CID as an optional
/// string (`Some(cid.to_string())` or `None`).
fn serialize_optional_cid<S>(cid: &Option<Cid>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    // Same calls `Option<String>::serialize` would make for each case.
    match cid {
        None => serializer.serialize_none(),
        Some(value) => serializer.serialize_some(&value.to_string()),
    }
}
fn deserialize_optional_cid<'de, D>(deserializer: D) -> Result<Option<Cid>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
let opt = Option::<String>::deserialize(deserializer)?;
opt.map(|s| s.parse().map_err(serde::de::Error::custom))
.transpose()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods on `Attribution` fill in the optional fields.
    #[test]
    fn test_attribution() {
        let person = Attribution::new(String::from("John Doe"), String::from("data provider"))
            .with_contact(String::from("john@example.com"))
            .with_organization(String::from("Example Corp"));

        assert_eq!(person.name, "John Doe");
        assert_eq!(person.contact, Some(String::from("john@example.com")));
        assert_eq!(person.organization, Some(String::from("Example Corp")));
    }

    /// A dataset record keeps its name, license and added attributions.
    #[test]
    fn test_dataset_provenance() {
        let creator = Attribution::new(String::from("Stanford"), String::from("creator"));
        let record = DatasetProvenance::new(
            Cid::default(),
            String::from("ImageNet"),
            String::from("1.0"),
            License::CCBY,
        )
        .add_attribution(creator)
        .add_source(String::from("https://example.com/imagenet"))
        .with_description(String::from("Large image dataset"));

        assert_eq!(record.name, "ImageNet");
        assert_eq!(record.license, License::CCBY);
        assert_eq!(record.attributions.len(), 1);
    }

    /// Builder methods on `Hyperparameters` set the matching fields.
    #[test]
    fn test_hyperparameters() {
        let params = Hyperparameters::new()
            .with_learning_rate(0.001)
            .with_batch_size(32)
            .with_epochs(10)
            .with_optimizer(String::from("Adam"))
            .add_param(String::from("weight_decay"), String::from("0.0001"));

        assert_eq!(params.learning_rate, Some(0.001));
        assert_eq!(params.batch_size, Some(32));
        assert_eq!(params.epochs, Some(10));
    }

    /// A training record accumulates datasets and metrics, and `complete`
    /// stamps the completion time.
    #[test]
    fn test_training_provenance() {
        let params = Hyperparameters::new()
            .with_learning_rate(0.001)
            .with_batch_size(32);
        let trainer = Attribution::new(String::from("Jane Doe"), String::from("trainer"));

        let record = TrainingProvenance::new(Cid::default(), vec![Cid::default()], License::MIT)
            .with_hyperparameters(params)
            .add_metric(String::from("accuracy"), 0.95)
            .add_attribution(trainer)
            .complete();

        assert_eq!(record.training_datasets.len(), 1);
        assert_eq!(record.metrics.len(), 1);
        assert!(record.completed_at.is_some());
    }

    /// Records added to the graph can be looked up again by CID.
    #[test]
    fn test_provenance_graph() {
        let mut graph = ProvenanceGraph::new();

        let dataset_cid = Cid::default();
        graph.add_dataset(DatasetProvenance::new(
            dataset_cid,
            String::from("TestDataset"),
            String::from("1.0"),
            License::MIT,
        ));

        let model_cid = Cid::default();
        graph.add_training(TrainingProvenance::new(
            model_cid,
            vec![dataset_cid],
            License::MIT,
        ));

        assert!(graph.get_dataset(&dataset_cid).is_some());
        assert!(graph.get_training(&model_cid).is_some());
    }

    /// Tracing a model without a parent yields one model and its one dataset.
    #[test]
    fn test_lineage_tracing() {
        let mut graph = ProvenanceGraph::new();

        let dataset_cid = Cid::default();
        graph.add_dataset(DatasetProvenance::new(
            dataset_cid,
            String::from("TestDataset"),
            String::from("1.0"),
            License::MIT,
        ));

        let model_cid = Cid::default();
        graph.add_training(TrainingProvenance::new(
            model_cid,
            vec![dataset_cid],
            License::MIT,
        ));

        let trace = graph.trace_lineage(&model_cid).unwrap();
        assert_eq!(trace.depth(), 1);
        assert_eq!(trace.dataset_count(), 1);
    }

    /// `Display` renders the expected short labels.
    #[test]
    fn test_license_display() {
        assert_eq!(License::MIT.to_string(), "MIT");
        assert_eq!(License::Apache2.to_string(), "Apache-2.0");

        let custom = License::Custom(String::from("Custom-1.0"));
        assert_eq!(custom.to_string(), "Custom: Custom-1.0");
    }
}