use crate::provenance::{
OnnxNodeInfo, PROVENANCE_TRACKER, PassName, ProvenanceEvent, ProvenanceTracker, SourceLocation,
get_relative_location,
};
use crate::uop::UOp;
use std::f32::consts::PI;
use std::panic::Location;
/// A freshly constructed UOp must have provenance recorded, consisting of exactly
/// one `Created` event whose source file is non-empty.
#[test]
fn test_basic_provenance_capture() {
    let constant = UOp::native_const(42i32);
    PROVENANCE_TRACKER.with(|cell| {
        let guard = cell.borrow();
        let Some(recorded) = guard.get_events(constant.id) else {
            panic!("Provenance should be captured for new UOp");
        };
        assert_eq!(recorded.len(), 1, "Should have one Created event");
        let ProvenanceEvent::Created { location } = &recorded[0] else {
            panic!("Expected Created event");
        };
        assert!(!location.file.is_empty());
    });
}
/// Both operands of an addition and the resulting UOp must all have provenance
/// entries in the thread-local tracker.
#[test]
fn test_transformation_tracking() {
    let lhs = UOp::native_const(1i32);
    let rhs = UOp::native_const(2i32);
    let sum = lhs.try_add(&rhs).unwrap();
    PROVENANCE_TRACKER.with(|cell| {
        let guard = cell.borrow();
        for operand in [&lhs, &rhs] {
            assert!(guard.get_events(operand.id).is_some());
        }
        assert!(guard.get_events(sum.id).is_some(), "Result should have provenance");
    });
}
/// Substituting a UOp with a replacement returns the replacement itself (same id),
/// and that UOp's provenance then contains a `Transformed` event.
#[test]
fn test_substitute_transformation() {
    use crate::UOpKey;
    use std::collections::HashMap;

    let source = UOp::native_const(10i32);
    let target = UOp::native_const(20i32);

    // clippy's `mutable_key_type` lint fires on `UOpKey` keys (presumably the wrapped
    // UOp has interior mutability — TODO confirm); keys are never mutated in this test.
    #[allow(clippy::mutable_key_type)]
    let subst_map = HashMap::from([(UOpKey(source.clone()), target.clone())]);

    let substituted = source.substitute(&subst_map);
    assert_eq!(substituted.id, target.id);

    PROVENANCE_TRACKER.with(|cell| {
        let guard = cell.borrow();
        let transformed = guard
            .get_events(substituted.id)
            .unwrap()
            .iter()
            .any(|event| matches!(event, ProvenanceEvent::Transformed { .. }));
        assert!(transformed, "Should have Transformed event");
    });
}
/// After two successive substitutions, the final UOp's provenance chain is
/// non-empty and includes at least one `Transformed` event.
#[test]
fn test_provenance_chain() {
    use crate::UOpKey;
    use std::collections::HashMap;

    let first = UOp::native_const(1i32);
    let second = UOp::native_const(2i32);
    // clippy's `mutable_key_type` lint fires on `UOpKey` keys; keys are never mutated here.
    #[allow(clippy::mutable_key_type)]
    let step_one = HashMap::from([(UOpKey(first.clone()), second.clone())]);
    let after_one = first.substitute(&step_one);

    let third = UOp::native_const(3i32);
    #[allow(clippy::mutable_key_type)]
    let step_two = HashMap::from([(UOpKey(after_one.clone()), third.clone())]);
    let after_two = after_one.substitute(&step_two);

    PROVENANCE_TRACKER.with(|cell| {
        let chain = cell.borrow().get_chain(after_two.id);
        assert!(!chain.is_empty(), "Should have provenance chain");
        assert!(
            chain.iter().any(|event| matches!(event, ProvenanceEvent::Transformed { .. })),
            "Chain should include transformations"
        );
    });
}
/// Attaching ONNX node metadata to a UOp must record a `FromOnnx` event that
/// preserves the node's op type and name.
#[test]
fn test_onnx_node_attachment() {
    let uop = UOp::native_const(PI);
    let onnx_node = OnnxNodeInfo {
        name: Some("conv1".to_string()),
        op_type: "Conv".to_string(),
        domain: "ai.onnx".to_string(),
        version: 11,
    };

    // `onnx_node` is not needed after this point, so it is moved rather than cloned.
    PROVENANCE_TRACKER.with(|tracker| {
        tracker.borrow_mut().attach_onnx_node(uop.id, onnx_node);
    });

    PROVENANCE_TRACKER.with(|tracker| {
        let binding = tracker.borrow();
        let events = binding.get_events(uop.id).expect("UOp should have provenance events");
        // Compare the name via `as_deref` to avoid allocating a `String` per event.
        let has_onnx = events.iter().any(|e| {
            matches!(
                e,
                ProvenanceEvent::FromOnnx { node, .. }
                    if node.op_type == "Conv" && node.name.as_deref() == Some("conv1")
            )
        });
        assert!(has_onnx, "Should have ONNX event");
    });
}
/// The `Display` rendering of a `SourceLocation` must mention its file path,
/// line number and column number.
#[test]
fn test_source_location_display() {
    let rendered = SourceLocation::new("tensor/src/ops.rs", 42, 10).to_string();
    for fragment in ["tensor/src/ops.rs", "42", "10"] {
        assert!(rendered.contains(fragment));
    }
}
/// The `Display` rendering of `OnnxNodeInfo` must include the op type, the name
/// (when set) and the version. An unnamed node must still show its op type.
#[test]
fn test_onnx_node_info_display() {
    let named = OnnxNodeInfo {
        name: Some("layer1".to_string()),
        op_type: "Add".to_string(),
        domain: "ai.onnx".to_string(),
        version: 13,
    };
    let named_text = named.to_string();
    for needle in ["Add", "layer1", "13"] {
        assert!(named_text.contains(needle));
    }

    let unnamed = OnnxNodeInfo {
        name: None,
        op_type: "Mul".to_string(),
        domain: "ai.onnx".to_string(),
        version: 13,
    };
    let unnamed_text = unnamed.to_string();
    assert!(unnamed_text.contains("Mul"));
    assert!(!unnamed_text.contains("layer"));
}
/// The `Display` rendering of `ProvenanceEvent` is expected to name the event.
/// A `Created` event must mention its kind and source file. A `Transformed`
/// event must mention the pass name (`substitute`) and its source UOp (`UOp 1`).
#[test]
fn test_provenance_event_display() {
    let loc = SourceLocation::new("test.rs", 10, 5);
    let created = ProvenanceEvent::Created { location: loc };
    let display = created.to_string();
    assert!(display.contains("Created"));
    assert!(display.contains("test.rs"));

    let transformed = ProvenanceEvent::Transformed { from_id: 1, pass_name: PassName::Substitute };
    let display = transformed.to_string();
    assert!(display.contains("substitute"));
    assert!(display.contains("UOp 1"));
}
/// `cleanup_with_live_set` must keep exactly the ids present in the live set and
/// drop all others. `clear` must then leave the tracker empty.
#[test]
fn test_tracker_cleanup() {
    use std::collections::HashSet;

    let mut tracker = ProvenanceTracker::default();
    let here = Location::caller();
    for id in 0..100 {
        tracker.capture(id, here);
    }
    assert_eq!(tracker.len(), 100);

    // Keep only the even ids alive.
    let live_set: HashSet<u64> = (0..100).step_by(2).collect();
    assert_eq!(live_set.len(), 50);
    tracker.cleanup_with_live_set(&live_set);
    assert_eq!(tracker.len(), 50);

    for id in 0..100 {
        let present = tracker.get_events(id).is_some();
        if id % 2 == 0 {
            assert!(present, "Even ID {} should still exist", id);
        } else {
            assert!(!present, "Odd ID {} should be removed", id);
        }
    }

    tracker.clear();
    assert_eq!(tracker.len(), 0);
    assert!(tracker.is_empty());
}
/// Substituting one operand of an addition must yield a UOp that has provenance,
/// and that provenance must include a `Transformed` event.
#[test]
fn test_multiple_parents() {
    use crate::UOpKey;
    use std::collections::HashMap;

    let a = UOp::native_const(1i32);
    let b = UOp::native_const(2i32);
    let sum = a.try_add(&b).unwrap();
    let a_new = UOp::native_const(10i32);
    // clippy's `mutable_key_type` lint fires on `UOpKey` keys; keys are never mutated here.
    #[allow(clippy::mutable_key_type)]
    let subst_map = HashMap::from([(UOpKey(a.clone()), a_new.clone())]);
    let rebuilt = sum.substitute(&subst_map);

    PROVENANCE_TRACKER.with(|cell| {
        let guard = cell.borrow();
        let Some(events) = guard.get_events(rebuilt.id) else {
            panic!("Result should have provenance");
        };
        assert!(
            events.iter().any(|e| matches!(e, ProvenanceEvent::Transformed { .. })),
            "Should have transformations"
        );
    });
}
/// `format_chain` must render a two-step chain as non-empty text that names the
/// substitute pass. The chain is: id 1 captured, then id 2 derived from it via
/// `PassName::Substitute`.
#[test]
fn test_format_chain() {
    use crate::provenance::format_chain;

    let mut tracker = ProvenanceTracker::default();
    tracker.capture(1, Location::caller());
    tracker.record_transform(2, 1, PassName::Substitute);

    let rendered = format_chain(&tracker.get_chain(2));
    assert!(!rendered.is_empty());
    assert!(rendered.contains("substitute"));
}
/// A `Created` event must survive a JSON round trip unchanged. The serialized
/// form must contain the file name and the line number.
#[test]
#[cfg(feature = "serde")]
fn test_provenance_serialization() {
    let event = ProvenanceEvent::Created { location: SourceLocation::new("test.rs", 100, 20) };
    let json = serde_json::to_string(&event).expect("Serialization should succeed");
    assert!(json.contains("test.rs"));
    assert!(json.contains("100"));
    let round_tripped: ProvenanceEvent =
        serde_json::from_str(&json).expect("Deserialization should succeed");
    assert_eq!(event, round_tripped);
}
/// A `FromOnnx` event carrying full node metadata must survive a JSON round trip
/// unchanged.
#[test]
#[cfg(feature = "serde")]
fn test_onnx_node_serialization() {
    let event = ProvenanceEvent::FromOnnx {
        node: OnnxNodeInfo {
            name: Some("test_node".to_string()),
            op_type: "Conv".to_string(),
            domain: "ai.onnx".to_string(),
            version: 11,
        },
    };
    let json = serde_json::to_string(&event).expect("Serialization should succeed");
    let round_tripped: ProvenanceEvent =
        serde_json::from_str(&json).expect("Deserialization should succeed");
    assert_eq!(event, round_tripped);
}
/// A `Transformed` event must survive a JSON round trip unchanged.
#[test]
#[cfg(feature = "serde")]
fn test_transformed_event_serialization() {
    let original = ProvenanceEvent::Transformed { from_id: 1, pass_name: PassName::Substitute };
    let json = serde_json::to_string(&original).expect("Serialization should succeed");
    let restored: ProvenanceEvent =
        serde_json::from_str(&json).expect("Deserialization should succeed");
    assert_eq!(original, restored);
}
/// Smoke test: `log_provenance` must not panic. It is called both for an id that
/// has provenance and for id 99999, which is presumably never allocated here.
#[test]
fn test_error_provenance_logging() {
    use crate::error::{Error, log_provenance};

    let tracked = UOp::native_const(42i32);
    let err = Error::DivisionByZero;
    for id in [tracked.id, 99999] {
        log_provenance(id, &err);
    }
}
/// `get_relative_location` must turn this call site's location into a path
/// relative to the workspace. The expectations assume this file lives under
/// `ir/` and that its path contains `provenance.rs`.
#[test]
fn test_get_relative_location() {
    let relative = get_relative_location(Location::caller());
    assert!(
        relative.starts_with("ir/"),
        "Expected relative path starting with 'ir/', got: {}",
        relative
    );
    assert!(
        relative.contains("provenance.rs"),
        "Expected path to contain 'provenance.rs', got: {}",
        relative
    );
}