use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use torsh_core::error::TorshError;
/// Kind of relationship stored on a [`LineageEdge`].
///
/// Each variant maps to a short lowercase phrase via
/// [`LineageRelation::description`]; that phrase is what `export_to_dot`
/// uses as the edge label.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LineageRelation {
/// Described as "derived from".
DerivedFrom,
/// Described as "trained from".
TrainedFrom,
/// Described as "quantized from".
QuantizedFrom,
/// Described as "pruned from".
PrunedFrom,
/// Described as "distilled from".
DistilledFrom,
/// Described as "converted from".
ConvertedFrom,
/// Described as "merged from".
MergedFrom,
/// Described as "checkpoint of".
CheckpointOf,
/// Described as "depends on".
DependsOn,
/// Caller-defined relation; the wrapped string is returned verbatim as the description.
Custom(String),
}
impl LineageRelation {
    /// Returns the human-readable phrase for this relation.
    ///
    /// Built-in variants yield a fixed lowercase phrase (for example
    /// `"derived from"`); [`LineageRelation::Custom`] yields a copy of the
    /// string it wraps.
    pub fn description(&self) -> String {
        let phrase: &str = match self {
            Self::Custom(text) => text.as_str(),
            Self::DerivedFrom => "derived from",
            Self::TrainedFrom => "trained from",
            Self::QuantizedFrom => "quantized from",
            Self::PrunedFrom => "pruned from",
            Self::DistilledFrom => "distilled from",
            Self::ConvertedFrom => "converted from",
            Self::MergedFrom => "merged from",
            Self::CheckpointOf => "checkpoint of",
            Self::DependsOn => "depends on",
        };
        phrase.to_owned()
    }
}
/// Origin record for a single package.
///
/// `LineageTracker::record_provenance` stores one of these per `package_id`;
/// recording again for the same id replaces the earlier record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProvenanceInfo {
/// Identifier of the package; used as the lookup key inside the tracker.
pub package_id: String,
/// Who created the package. Shown as the second label line in the DOT export.
pub creator: String,
/// Creation timestamp supplied by the caller (UTC).
pub creation_time: DateTime<Utc>,
/// Optional source location. Not interpreted by this file — presumably a repository URL.
pub source_url: Option<String>,
/// Optional source revision. Not interpreted by this file — presumably a VCS commit id.
pub source_commit: Option<String>,
/// Free-form key/value description of the build environment; not interpreted by this file.
pub build_environment: HashMap<String, String>,
/// Free-form description of the package.
pub description: String,
}
/// Directed relationship between two packages.
///
/// The tracker indexes an edge under `from` in its forward graph and under
/// `to` in its backward graph, so `from` is the ancestor side and `to` the
/// descendant side when lineage is queried.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LineageEdge {
/// Package id at the tail of the edge (reported as an ancestor of `to`).
pub from: String,
/// Package id at the head of the edge (reported as a descendant of `from`).
pub to: String,
/// Kind of relationship this edge represents.
pub relation: LineageRelation,
/// Set to the current UTC time when the edge is added through `add_lineage`.
pub timestamp: DateTime<Utc>,
/// Extra key/value data; empty unless supplied via `add_lineage_with_metadata`.
pub metadata: HashMap<String, String>,
/// Free-form description supplied by the caller.
pub description: String,
}
/// Log entry describing one operation applied to a package.
///
/// The tracker appends records per `package_id` in the order they are
/// recorded; none of the other fields are interpreted by this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformationRecord {
/// Package the operation was applied to; used as the grouping key.
pub package_id: String,
/// Name of the operation (the tests use e.g. "quantization").
pub operation: String,
/// When the operation happened, as supplied by the caller (UTC).
pub timestamp: DateTime<Utc>,
/// Who or what performed the operation.
pub performed_by: String,
/// Operation parameters as free-form key/value strings.
pub parameters: HashMap<String, String>,
/// Free-form outcome text (the tests use "success").
pub result: String,
/// Duration of the operation — presumably in seconds, judging by the name.
pub duration_secs: f64,
}
/// Compliance tier assigned to a package.
///
/// No ordering is derived, so levels can only be compared for equality.
/// Within this file only `Regulatory` is treated specially: the compliance
/// report raises an issue for a `Regulatory` package with no certifications.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ComplianceLevel {
None,
Internal,
Industry,
Regulatory,
CriticalSecurity,
}
/// Compliance attributes attached to a package.
///
/// `LineageTracker::set_compliance` stores one of these per `package_id`.
/// The compliance report reads `level`, `certifications`, `audit_required`
/// and `next_audit_due`; the remaining fields are carried but not inspected
/// in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceMetadata {
/// Package these attributes belong to; used as the lookup key.
pub package_id: String,
/// Compliance tier of the package.
pub level: ComplianceLevel,
/// Free-form tags (the tests use values such as "HIPAA" and "GDPR").
pub tags: Vec<String>,
/// Certifications held; an empty list on a `Regulatory` package is reported as an issue.
pub certifications: Vec<String>,
/// Free-form data classification label.
pub data_classification: String,
/// Optional retention period — presumably in days, judging by the name.
pub retention_days: Option<u32>,
/// Free-form access restriction labels.
pub access_restrictions: Vec<String>,
/// Whether an audit is required; with no `next_audit_due` the package is listed for review.
pub audit_required: bool,
/// Time of the most recent audit, if any. Not inspected in this file.
pub last_audit: Option<DateTime<Utc>>,
/// Deadline for the next audit; a deadline in the past is reported as an overdue audit.
pub next_audit_due: Option<DateTime<Utc>>,
}
/// Result of `LineageTracker::get_lineage` for one package.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LineageQueryResult {
/// The queried package together with all of its ancestors and descendants.
pub packages: Vec<String>,
/// Copies of the lineage edges selected for those packages.
pub edges: Vec<LineageEdge>,
/// Provenance for each listed package that has a record; packages without one are absent.
pub provenance: HashMap<String, ProvenanceInfo>,
}
/// Summary produced by `LineageTracker::generate_compliance_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceReport {
/// UTC time at which the report was generated.
pub generated_at: DateTime<Utc>,
/// Number of packages that have compliance metadata.
pub total_packages: usize,
/// Number of packages the report counted as compliant.
pub compliant_packages: usize,
/// Number of packages the report counted as non-compliant.
pub non_compliant_packages: usize,
/// Ids of packages that require an audit but have no next audit date set.
pub needs_review: Vec<String>,
/// Issues found, ordered from most to least severe.
pub issues: Vec<ComplianceIssue>,
/// Human-readable follow-up suggestions derived from the issues and tracker contents.
pub recommendations: Vec<String>,
}
/// A single finding in a [`ComplianceReport`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceIssue {
/// Package the finding applies to.
pub package_id: String,
/// How serious the finding is; the report sorts issues by this field.
pub severity: IssueSeverity,
/// What is wrong.
pub description: String,
/// Suggested action to resolve the finding.
pub remediation: String,
}
/// Severity of a [`ComplianceIssue`].
///
/// `Ord` is derived, so the declaration order below defines the ordering:
/// `Low < Medium < High < Critical`. The compliance report relies on this
/// when sorting issues, so variants must not be reordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum IssueSeverity {
Low,
Medium,
High,
Critical,
}
/// In-memory store of package provenance, lineage edges, transformation
/// history and compliance metadata.
///
/// Edges are kept in one vector; the two graph maps hold indices into that
/// vector so an edge can be reached from either of its endpoints.
pub struct LineageTracker {
/// Provenance records keyed by package id.
provenance: HashMap<String, ProvenanceInfo>,
/// Every lineage edge, in insertion order.
edges: Vec<LineageEdge>,
/// Package id -> indices into `edges` of the edges whose `from` is that package.
forward_graph: HashMap<String, Vec<usize>>,
/// Package id -> indices into `edges` of the edges whose `to` is that package.
backward_graph: HashMap<String, Vec<usize>>,
/// Transformation records per package id, in the order they were recorded.
transformations: HashMap<String, Vec<TransformationRecord>>,
/// Compliance metadata keyed by package id.
compliance: HashMap<String, ComplianceMetadata>,
}
impl LineageTracker {
/// Creates a tracker that holds no provenance, edges, transformations or
/// compliance metadata.
pub fn new() -> Self {
    let edges = Vec::new();
    let (forward_graph, backward_graph) = (HashMap::new(), HashMap::new());
    Self {
        compliance: HashMap::new(),
        transformations: HashMap::new(),
        backward_graph,
        forward_graph,
        edges,
        provenance: HashMap::new(),
    }
}
/// Stores `info` under its own `package_id`, replacing any provenance
/// previously recorded for that id.
pub fn record_provenance(&mut self, info: ProvenanceInfo) {
    // The key is cloned before `info` is moved into the map.
    self.provenance.insert(info.package_id.clone(), info);
}
/// Adds a directed edge `from -> to` with the given relation and description.
///
/// The edge is timestamped with the current UTC time and starts with empty
/// metadata. It is appended to the edge list and indexed under `from` in the
/// forward graph and under `to` in the backward graph.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` and leaves the tracker unchanged
/// when `would_create_cycle` reports that the new edge would close a cycle.
pub fn add_lineage(
    &mut self,
    from: String,
    to: String,
    relation: LineageRelation,
    description: String,
) -> Result<(), TorshError> {
    if self.would_create_cycle(&from, &to) {
        let message = format!(
            "Adding edge from {} to {} would create a cycle",
            from, to
        );
        return Err(TorshError::InvalidArgument(message));
    }

    // The new edge lands at the current end of the vector.
    let position = self.edges.len();
    self.edges.push(LineageEdge {
        from: from.clone(),
        to: to.clone(),
        relation,
        timestamp: Utc::now(),
        metadata: HashMap::new(),
        description,
    });

    self.forward_graph.entry(from).or_default().push(position);
    self.backward_graph.entry(to).or_default().push(position);
    Ok(())
}
/// Same as `add_lineage`, but attaches `metadata` to the newly created edge.
///
/// # Errors
///
/// Propagates the error from `add_lineage`; in that case no edge is added
/// and `metadata` is dropped.
pub fn add_lineage_with_metadata(
    &mut self,
    from: String,
    to: String,
    relation: LineageRelation,
    description: String,
    metadata: HashMap<String, String>,
) -> Result<(), TorshError> {
    let outcome = self.add_lineage(from, to, relation, description);
    if outcome.is_ok() {
        // `add_lineage` appended the edge, so it is the last element.
        if let Some(newest) = self.edges.last_mut() {
            newest.metadata = metadata;
        }
    }
    outcome
}
/// Appends `record` to the transformation history of its `package_id`,
/// creating the history on first use.
pub fn record_transformation(&mut self, record: TransformationRecord) {
    let history = self
        .transformations
        .entry(record.package_id.clone())
        .or_default();
    history.push(record);
}
/// Stores `metadata` under its own `package_id`, replacing any compliance
/// metadata previously set for that id.
pub fn set_compliance(&mut self, metadata: ComplianceMetadata) {
    // The key is cloned before `metadata` is moved into the map.
    self.compliance.insert(metadata.package_id.clone(), metadata);
}
pub fn get_provenance(&self, package_id: &str) -> Option<&ProvenanceInfo> {
self.provenance.get(package_id)
}
pub fn get_ancestors(&self, package_id: &str) -> Vec<String> {
let mut ancestors = Vec::new();
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back(package_id.to_string());
while let Some(current) = queue.pop_front() {
if !visited.insert(current.clone()) {
continue;
}
if let Some(edge_indices) = self.backward_graph.get(¤t) {
for &idx in edge_indices {
let edge = &self.edges[idx];
ancestors.push(edge.from.clone());
queue.push_back(edge.from.clone());
}
}
}
ancestors
}
pub fn get_descendants(&self, package_id: &str) -> Vec<String> {
let mut descendants = Vec::new();
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back(package_id.to_string());
while let Some(current) = queue.pop_front() {
if !visited.insert(current.clone()) {
continue;
}
if let Some(edge_indices) = self.forward_graph.get(¤t) {
for &idx in edge_indices {
let edge = &self.edges[idx];
descendants.push(edge.to.clone());
queue.push_back(edge.to.clone());
}
}
}
descendants
}
/// Collects the lineage neighbourhood of `package_id`.
///
/// The result contains:
/// - `packages`: the queried id plus all ancestors and descendants, sorted
///   so the output is deterministic;
/// - `edges`: copies of the edges whose two endpoints are both in
///   `packages`, in insertion order, so no edge refers to a package missing
///   from the result;
/// - `provenance`: the provenance record of each listed package that has one.
pub fn get_lineage(&self, package_id: &str) -> LineageQueryResult {
    let mut members: HashSet<String> = HashSet::new();
    members.insert(package_id.to_string());
    members.extend(self.get_ancestors(package_id));
    members.extend(self.get_descendants(package_id));

    // Require both endpoints: an edge from an ancestor to an unrelated
    // sibling is not part of this package's lineage.
    let edges: Vec<LineageEdge> = self
        .edges
        .iter()
        .filter(|edge| members.contains(&edge.from) && members.contains(&edge.to))
        .cloned()
        .collect();

    let provenance: HashMap<String, ProvenanceInfo> = members
        .iter()
        .filter_map(|id| self.provenance.get(id).map(|info| (id.clone(), info.clone())))
        .collect();

    // HashSet iteration order is arbitrary; sort for reproducible exports.
    let mut packages: Vec<String> = members.into_iter().collect();
    packages.sort_unstable();

    LineageQueryResult {
        packages,
        edges,
        provenance,
    }
}
/// Returns references to the transformations recorded for `package_id`, in
/// recording order.
///
/// An id with no recorded transformations yields an empty vector.
pub fn get_transformations(&self, package_id: &str) -> Vec<&TransformationRecord> {
    match self.transformations.get(package_id) {
        Some(history) => history.iter().collect(),
        None => Vec::new(),
    }
}
pub fn get_compliance(&self, package_id: &str) -> Option<&ComplianceMetadata> {
self.compliance.get(package_id)
}
/// Builds a compliance report over every package that has compliance
/// metadata.
///
/// Checks performed per package:
/// - `next_audit_due` in the past -> High issue ("Compliance audit overdue");
/// - `Regulatory` level with no certifications -> Critical issue;
/// - no provenance recorded -> Medium issue.
///
/// A package is counted as non-compliant when any check raised an issue.
/// A package with `audit_required` set but no `next_audit_due` is listed in
/// `needs_review`; it is not counted as compliant while that review is
/// pending. Every other package is counted as compliant.
///
/// Issues are ordered most severe first, then by package id; `needs_review`
/// is sorted, so the report does not depend on hash-map iteration order.
/// A single timestamp is taken up front and used both for the overdue check
/// and for `generated_at`.
pub fn generate_compliance_report(&self) -> ComplianceReport {
    let now = Utc::now();
    let mut issues = Vec::new();
    let mut needs_review = Vec::new();
    let mut compliant = 0;
    let mut non_compliant = 0;

    for (package_id, metadata) in &self.compliance {
        // Used below to tell whether this package raised any issue.
        let issues_before = issues.len();

        let awaiting_review = metadata.audit_required && metadata.next_audit_due.is_none();
        if awaiting_review {
            needs_review.push(package_id.clone());
        }

        if metadata.next_audit_due.map_or(false, |due| due < now) {
            issues.push(ComplianceIssue {
                package_id: package_id.clone(),
                severity: IssueSeverity::High,
                description: "Compliance audit overdue".to_string(),
                remediation: "Schedule and complete compliance audit".to_string(),
            });
        }

        if metadata.level == ComplianceLevel::Regulatory && metadata.certifications.is_empty() {
            issues.push(ComplianceIssue {
                package_id: package_id.clone(),
                severity: IssueSeverity::Critical,
                description: "Regulatory compliance required but no certifications found"
                    .to_string(),
                remediation: "Obtain required compliance certifications".to_string(),
            });
        }

        if !self.provenance.contains_key(package_id) {
            issues.push(ComplianceIssue {
                package_id: package_id.clone(),
                severity: IssueSeverity::Medium,
                description: "Missing provenance information".to_string(),
                remediation: "Record complete provenance information for the package"
                    .to_string(),
            });
        }

        if issues.len() > issues_before {
            non_compliant += 1;
        } else if !awaiting_review {
            compliant += 1;
        }
    }

    needs_review.sort_unstable();
    issues.sort_by(|a, b| {
        b.severity
            .cmp(&a.severity)
            .then_with(|| a.package_id.cmp(&b.package_id))
    });
    let recommendations = self.generate_recommendations(&issues);

    ComplianceReport {
        generated_at: now,
        total_packages: self.compliance.len(),
        compliant_packages: compliant,
        non_compliant_packages: non_compliant,
        needs_review,
        issues,
        recommendations,
    }
}
/// Renders the lineage of `package_id` as a Graphviz DOT digraph.
///
/// One node is emitted per package returned by `get_lineage`, in sorted id
/// order; a node with provenance is labelled `id\ncreator`, otherwise just
/// with its id. One edge statement is emitted per lineage edge, labelled
/// with the relation's description.
///
/// Double quotes and backslashes in ids, creator names and relation
/// descriptions are backslash-escaped so they cannot end a quoted DOT
/// string early.
pub fn export_to_dot(&self, package_id: &str) -> String {
    /// Backslash-escapes `"` and `\` for use inside a double-quoted DOT string.
    fn escape(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            if ch == '"' || ch == '\\' {
                escaped.push('\\');
            }
            escaped.push(ch);
        }
        escaped
    }

    let lineage = self.get_lineage(package_id);

    // Sort locally so node order does not depend on how the list was built.
    let mut node_ids: Vec<&String> = lineage.packages.iter().collect();
    node_ids.sort();

    let mut dot = String::from("digraph PackageLineage {\n");
    dot.push_str(" rankdir=LR;\n");
    dot.push_str(" node [shape=box];\n\n");

    for pkg_id in node_ids {
        let id = escape(pkg_id);
        let label = match lineage.provenance.get(pkg_id) {
            // `\\n` writes a literal backslash-n, which DOT renders as a line break.
            Some(prov) => format!("{}\\n{}", id, escape(&prov.creator)),
            None => id.clone(),
        };
        dot.push_str(&format!(" \"{}\" [label=\"{}\"];\n", id, label));
    }

    dot.push('\n');

    for edge in &lineage.edges {
        dot.push_str(&format!(
            " \"{}\" -> \"{}\" [label=\"{}\"];\n",
            escape(&edge.from),
            escape(&edge.to),
            escape(&edge.relation.description())
        ));
    }

    dot.push_str("}\n");
    dot
}
/// Serializes the lineage of `package_id` (as returned by `get_lineage`) to
/// pretty-printed JSON.
///
/// # Errors
///
/// Returns `TorshError::SerializationError` carrying the serializer's
/// message when serialization fails.
pub fn export_to_json(&self, package_id: &str) -> Result<String, TorshError> {
    let lineage = self.get_lineage(package_id);
    match serde_json::to_string_pretty(&lineage) {
        Ok(json) => Ok(json),
        Err(err) => Err(TorshError::SerializationError(err.to_string())),
    }
}
/// Computes summary counts for the tracker.
///
/// Package-based figures consider only packages with recorded provenance:
/// a root is such a package with no incoming edge, a leaf one with no
/// outgoing edge. `average_depth` is the mean of `calculate_depth` over the
/// roots, or `0.0` when there are none.
pub fn get_statistics(&self) -> LineageStatistics {
    let mut root_ids: Vec<&String> = Vec::new();
    let mut leaf_count: usize = 0;
    for id in self.provenance.keys() {
        if !self.backward_graph.contains_key(id) {
            root_ids.push(id);
        }
        if !self.forward_graph.contains_key(id) {
            leaf_count += 1;
        }
    }

    let average_depth = if root_ids.is_empty() {
        0.0
    } else {
        let depth_sum: usize = root_ids
            .iter()
            .map(|id| self.calculate_depth(id.as_str()))
            .sum();
        depth_sum as f64 / root_ids.len() as f64
    };

    let mut total_transformations: usize = 0;
    for history in self.transformations.values() {
        total_transformations += history.len();
    }

    LineageStatistics {
        total_packages: self.provenance.len(),
        total_edges: self.edges.len(),
        total_transformations,
        packages_with_compliance: self.compliance.len(),
        root_packages: root_ids.len(),
        leaf_packages: leaf_count,
        average_depth,
    }
}
fn would_create_cycle(&self, from: &str, to: &str) -> bool {
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back(to.to_string());
while let Some(current) = queue.pop_front() {
if current == from {
return true;
}
if !visited.insert(current.clone()) {
continue;
}
if let Some(edge_indices) = self.forward_graph.get(¤t) {
for &idx in edge_indices {
let edge = &self.edges[idx];
queue.push_back(edge.to.clone());
}
}
}
false
}
fn calculate_depth(&self, package_id: &str) -> usize {
let mut max_depth = 0;
let mut visited = HashSet::new();
let mut queue = VecDeque::new();
queue.push_back((package_id.to_string(), 0));
while let Some((current, depth)) = queue.pop_front() {
if !visited.insert(current.clone()) {
continue;
}
max_depth = max_depth.max(depth);
if let Some(edge_indices) = self.forward_graph.get(¤t) {
for &idx in edge_indices {
let edge = &self.edges[idx];
queue.push_back((edge.to.clone(), depth + 1));
}
}
}
max_depth
}
fn generate_recommendations(&self, issues: &[ComplianceIssue]) -> Vec<String> {
let mut recommendations = Vec::new();
let critical_count = issues
.iter()
.filter(|i| i.severity == IssueSeverity::Critical)
.count();
let high_count = issues
.iter()
.filter(|i| i.severity == IssueSeverity::High)
.count();
if critical_count > 0 {
recommendations.push(format!(
"Address {} critical compliance issues immediately",
critical_count
));
}
if high_count > 0 {
recommendations.push(format!(
"Schedule remediation for {} high-severity issues within 30 days",
high_count
));
}
if self.compliance.len() < self.provenance.len() {
recommendations
.push("Define compliance metadata for all packages in the lineage".to_string());
}
let missing_provenance = self.provenance.len();
let total_packages = self.forward_graph.len().max(self.backward_graph.len());
if missing_provenance < total_packages {
recommendations.push(
"Complete provenance information for all packages to ensure full traceability"
.to_string(),
);
}
if recommendations.is_empty() {
recommendations
.push("All packages are compliant. Continue regular audits.".to_string());
}
recommendations
}
}
impl Default for LineageTracker {
fn default() -> Self {
Self::new()
}
}
/// Summary counts returned by `LineageTracker::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LineageStatistics {
/// Number of packages with recorded provenance.
pub total_packages: usize,
/// Number of lineage edges.
pub total_edges: usize,
/// Number of transformation records across all packages.
pub total_transformations: usize,
/// Number of packages with compliance metadata.
pub packages_with_compliance: usize,
/// Packages with provenance that have no incoming edge.
pub root_packages: usize,
/// Packages with provenance that have no outgoing edge.
pub leaf_packages: usize,
/// Mean depth over the root packages; `0.0` when there are no roots.
pub average_depth: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds provenance with no source URL or commit and an empty build environment.
    fn provenance(id: &str, creator: &str, description: &str) -> ProvenanceInfo {
        ProvenanceInfo {
            package_id: id.to_string(),
            creator: creator.to_string(),
            creation_time: Utc::now(),
            source_url: None,
            source_commit: None,
            build_environment: HashMap::new(),
            description: description.to_string(),
        }
    }

    /// Tries to add a `DerivedFrom` edge `from -> to` and returns the tracker's verdict.
    fn link(
        tracker: &mut LineageTracker,
        from: &str,
        to: &str,
        description: &str,
    ) -> Result<(), TorshError> {
        tracker.add_lineage(
            from.to_string(),
            to.to_string(),
            LineageRelation::DerivedFrom,
            description.to_string(),
        )
    }

    #[test]
    fn test_lineage_tracker_creation() {
        let stats = LineageTracker::new().get_statistics();
        assert_eq!(stats.total_packages, 0);
        assert_eq!(stats.total_edges, 0);
        assert_eq!(stats.total_transformations, 0);
        assert_eq!(stats.packages_with_compliance, 0);
        assert_eq!(stats.root_packages, 0);
        assert_eq!(stats.leaf_packages, 0);
        assert_eq!(stats.average_depth, 0.0);
    }

    #[test]
    fn test_record_provenance() {
        let mut tracker = LineageTracker::new();
        tracker.record_provenance(ProvenanceInfo {
            source_url: Some("https://github.com/test/repo".to_string()),
            source_commit: Some("abc123".to_string()),
            ..provenance("test-pkg", "alice@example.com", "Test package")
        });

        let retrieved = tracker
            .get_provenance("test-pkg")
            .expect("provenance was just recorded");
        assert_eq!(retrieved.creator, "alice@example.com");
        assert_eq!(retrieved.source_commit.as_deref(), Some("abc123"));
        assert!(tracker.get_provenance("unknown-pkg").is_none());
    }

    #[test]
    fn test_add_lineage() {
        let mut tracker = LineageTracker::new();
        tracker.record_provenance(provenance("base", "alice", "Base model"));
        tracker.record_provenance(provenance("derived", "bob", "Derived model"));

        assert!(link(&mut tracker, "base", "derived", "Fine-tuned version").is_ok());

        let stats = tracker.get_statistics();
        assert_eq!(stats.total_packages, 2);
        assert_eq!(stats.total_edges, 1);
        assert_eq!(stats.root_packages, 1);
        assert_eq!(stats.leaf_packages, 1);
        assert_eq!(tracker.get_descendants("base"), vec!["derived"]);
        assert_eq!(tracker.get_ancestors("derived"), vec!["base"]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut tracker = LineageTracker::new();
        link(&mut tracker, "A", "B", "A to B").unwrap();
        link(&mut tracker, "B", "C", "B to C").unwrap();

        // Closing the loop and a self-loop must both be rejected...
        assert!(link(&mut tracker, "C", "A", "C to A").is_err());
        assert!(link(&mut tracker, "A", "A", "A to A").is_err());

        // ...without leaving a partial edge behind.
        assert_eq!(tracker.get_statistics().total_edges, 2);
        assert!(tracker.get_descendants("C").is_empty());
    }

    #[test]
    fn test_get_ancestors() {
        let mut tracker = LineageTracker::new();
        link(&mut tracker, "A", "B", "").unwrap();
        link(&mut tracker, "B", "C", "").unwrap();

        // Breadth-first: the direct parent comes before the grandparent.
        assert_eq!(tracker.get_ancestors("C"), vec!["B", "A"]);
        assert!(tracker.get_ancestors("A").is_empty());
    }

    #[test]
    fn test_get_descendants() {
        let mut tracker = LineageTracker::new();
        link(&mut tracker, "A", "B", "").unwrap();
        link(&mut tracker, "A", "C", "").unwrap();

        // Children are reported in the order their edges were added.
        assert_eq!(tracker.get_descendants("A"), vec!["B", "C"]);
        assert!(tracker.get_descendants("B").is_empty());
    }

    #[test]
    fn test_transformation_recording() {
        let mut tracker = LineageTracker::new();
        tracker.record_transformation(TransformationRecord {
            package_id: "test".to_string(),
            operation: "quantization".to_string(),
            timestamp: Utc::now(),
            performed_by: "alice".to_string(),
            parameters: [("bits".to_string(), "8".to_string())]
                .iter()
                .cloned()
                .collect(),
            result: "success".to_string(),
            duration_secs: 120.5,
        });

        let transformations = tracker.get_transformations("test");
        assert_eq!(transformations.len(), 1);
        assert_eq!(transformations[0].operation, "quantization");
        assert!(tracker.get_transformations("missing").is_empty());
        assert_eq!(tracker.get_statistics().total_transformations, 1);
    }

    #[test]
    fn test_compliance_metadata() {
        let mut tracker = LineageTracker::new();
        tracker.set_compliance(ComplianceMetadata {
            package_id: "test".to_string(),
            level: ComplianceLevel::Regulatory,
            tags: vec!["HIPAA".to_string(), "SOC2".to_string()],
            certifications: vec!["ISO27001".to_string()],
            data_classification: "Confidential".to_string(),
            retention_days: Some(2555),
            access_restrictions: vec!["internal-only".to_string()],
            audit_required: true,
            last_audit: Some(Utc::now()),
            next_audit_due: Some(Utc::now() + chrono::Duration::days(90)),
        });

        let retrieved = tracker
            .get_compliance("test")
            .expect("compliance metadata was just set");
        assert_eq!(retrieved.level, ComplianceLevel::Regulatory);
        assert_eq!(retrieved.tags.len(), 2);
        assert!(tracker.get_compliance("missing").is_none());
    }

    #[test]
    fn test_compliance_report() {
        let mut tracker = LineageTracker::new();
        tracker.set_compliance(ComplianceMetadata {
            package_id: "overdue".to_string(),
            level: ComplianceLevel::Regulatory,
            tags: vec!["GDPR".to_string()],
            certifications: vec![],
            data_classification: "Restricted".to_string(),
            retention_days: None,
            access_restrictions: vec![],
            audit_required: true,
            last_audit: Some(Utc::now() - chrono::Duration::days(180)),
            next_audit_due: Some(Utc::now() - chrono::Duration::days(1)),
        });

        let report = tracker.generate_compliance_report();
        assert_eq!(report.total_packages, 1);
        assert_eq!(report.non_compliant_packages, 1);
        assert_eq!(report.compliant_packages, 0);
        assert!(report.needs_review.is_empty());

        // Overdue audit, missing certifications and missing provenance,
        // with the most severe finding first.
        assert_eq!(report.issues.len(), 3);
        assert_eq!(report.issues[0].severity, IssueSeverity::Critical);
        assert_eq!(report.issues[1].severity, IssueSeverity::High);
        assert_eq!(report.issues[2].severity, IssueSeverity::Medium);
        assert!(!report.recommendations.is_empty());
    }

    #[test]
    fn test_export_to_dot() {
        let mut tracker = LineageTracker::new();
        link(&mut tracker, "A", "B", "").unwrap();

        let dot = tracker.export_to_dot("A");
        assert!(dot.starts_with("digraph PackageLineage {\n"));
        assert!(dot.ends_with("}\n"));
        assert!(dot.contains("\"A\" [label=\"A\"];"));
        assert!(dot.contains("\"B\" [label=\"B\"];"));
        assert!(dot.contains("\"A\" -> \"B\" [label=\"derived from\"];"));
    }

    #[test]
    fn test_export_to_json() {
        let mut tracker = LineageTracker::new();
        tracker.record_provenance(provenance("A", "alice", "Test"));
        link(&mut tracker, "A", "B", "").unwrap();

        let json_str = tracker
            .export_to_json("A")
            .expect("lineage should serialize");
        assert!(json_str.contains("\"packages\""));
        assert!(json_str.contains("\"edges\""));
        assert!(json_str.contains("\"provenance\""));
        assert!(json_str.contains("\"alice\""));
    }

    #[test]
    fn test_lineage_statistics() {
        let mut tracker = LineageTracker::new();
        for i in 0..5 {
            tracker.record_provenance(provenance(
                &format!("pkg-{}", i),
                "alice",
                &format!("Package {}", i),
            ));
        }
        link(&mut tracker, "pkg-0", "pkg-1", "").unwrap();
        link(&mut tracker, "pkg-1", "pkg-2", "").unwrap();

        let stats = tracker.get_statistics();
        assert_eq!(stats.total_packages, 5);
        assert_eq!(stats.total_edges, 2);
        // Roots: pkg-0 plus the two isolated packages. Leaves: pkg-2 plus the same two.
        assert_eq!(stats.root_packages, 3);
        assert_eq!(stats.leaf_packages, 3);
        // Depths of the roots are 2, 0 and 0.
        assert!((stats.average_depth - 2.0 / 3.0).abs() < 1e-9);
    }
}