use super::*;
/// `FeatureChange::new` keeps the feature index and both values as given, and a
/// 1.0 -> 2.0 change is expected to carry a `delta` of 1.0.
#[test]
fn test_feature_change_new() {
    let (idx, before, after) = (0, 1.0, 2.0);
    let change = FeatureChange::new(idx, before, after);

    assert_eq!(change.feature_idx, idx);
    assert_eq!(change.original_value, before);
    assert_eq!(change.counterfactual_value, after);
    assert_eq!(change.delta, 1.0);
}
/// `with_name` attaches the given label as an owned `Some(String)`.
#[test]
fn test_feature_change_with_name() {
    let named = FeatureChange::new(0, 1.0, 2.0).with_name("income");
    let expected = Some(String::from("income"));
    assert_eq!(named.feature_name, expected);
}
/// `abs_delta` reports the same magnitude whether the feature went up or down.
#[test]
fn test_feature_change_abs_delta() {
    // An increase (1 -> 2) and a decrease (2 -> 1) of equal size.
    let increase = FeatureChange::new(0, 1.0, 2.0);
    let decrease = FeatureChange::new(0, 2.0, 1.0);
    for change in [increase, decrease] {
        assert_eq!(change.abs_delta(), 1.0);
    }
}
/// Constructing a counterfactual records both decisions and counts the features
/// whose value differs between the two inputs.
#[test]
fn test_counterfactual_new() {
    let original = vec![1.0, 2.0, 3.0];
    // Features 0 and 2 differ from `original`; feature 1 is unchanged.
    let alternative = vec![1.5, 2.0, 4.0];
    let cf = Counterfactual::new(original, 0, 0.9, alternative, 1, 0.85);

    assert_eq!(cf.original_decision, 0);
    assert_eq!(cf.alternative_decision, 1);
    assert_eq!(cf.n_changes(), 2);
    assert!(cf.is_valid());
}
/// Moving from (0, 0) to (3, 4): the test expects `sparsity` to equal the summed
/// component changes (7) and `distance` the Euclidean length (5).
#[test]
fn test_counterfactual_metrics() {
    let cf = Counterfactual::new(vec![0.0, 0.0], 0, 0.9, vec![3.0, 4.0], 1, 0.85);

    let sparsity_err = (cf.sparsity - 7.0).abs();
    let distance_err = (cf.distance - 5.0).abs();
    assert!(sparsity_err < 1e-6);
    assert!(distance_err < 1e-6);
}
/// The textual explanation mentions both decisions and every named feature.
#[test]
fn test_counterfactual_explain() {
    let names = ["income".to_string(), "debt_ratio".to_string()];
    let cf = Counterfactual::new(vec![45000.0, 0.42], 0, 0.7, vec![52000.0, 0.35], 1, 0.8)
        .with_feature_names(&names);
    let text = cf.explain();

    for needle in [
        "Original decision: 0",
        "Alternative decision: 1",
        "income",
        "debt_ratio",
    ] {
        assert!(text.contains(needle));
    }
}
/// Serializing with `to_bytes` and reading back with `from_bytes` preserves the
/// decisions, the original confidence, the metrics, the original input values
/// and every recorded feature change.
#[test]
fn test_counterfactual_serialization_roundtrip() {
    let cf = Counterfactual::new(vec![1.0, 2.0, 3.0], 0, 0.9, vec![1.5, 2.0, 4.0], 1, 0.85)
        .with_feature_names(&[
            "feature_a".to_string(),
            "feature_b".to_string(),
            "feature_c".to_string(),
        ]);
    let bytes = cf.to_bytes();
    let restored = Counterfactual::from_bytes(&bytes).expect("Failed to deserialize");

    assert_eq!(cf.original_decision, restored.original_decision);
    assert_eq!(cf.alternative_decision, restored.alternative_decision);
    assert!((cf.original_confidence - restored.original_confidence).abs() < 1e-6);
    assert!((cf.sparsity - restored.sparsity).abs() < 1e-6);
    assert!((cf.distance - restored.distance).abs() < 1e-6);

    // Compare contents, not just lengths: a roundtrip that zeroed or reordered
    // values would otherwise still pass.
    assert_eq!(cf.original_input.len(), restored.original_input.len());
    for (before, after) in cf.original_input.iter().zip(&restored.original_input) {
        assert!((before - after).abs() < 1e-6);
    }

    assert_eq!(cf.changes.len(), restored.changes.len());
    for (before, after) in cf.changes.iter().zip(&restored.changes) {
        assert_eq!(before.feature_idx, after.feature_idx);
        assert!((before.original_value - after.original_value).abs() < 1e-6);
        assert!((before.counterfactual_value - after.counterfactual_value).abs() < 1e-6);
    }
}
/// Identical inputs with the same decision yield zero changes, and the result is
/// expected to be reported as not valid.
#[test]
fn test_counterfactual_no_changes() {
    let input = vec![1.0, 2.0, 3.0];
    let cf = Counterfactual::new(input.clone(), 0, 0.9, input, 0, 0.9);

    assert_eq!(cf.n_changes(), 0);
    assert!(!cf.is_valid());
}
/// The `Display` output of each error variant includes its identifying text.
#[test]
fn test_counterfactual_error_display() {
    let render = |err: CounterfactualError| err.to_string();

    let insufficient = render(CounterfactualError::InsufficientData { expected: 100, actual: 50 });
    assert!(insufficient.contains("expected 100"));

    let mismatch = render(CounterfactualError::VersionMismatch { expected: 1, actual: 2 });
    assert!(mismatch.contains("Version mismatch"));
}
/// A 10-byte buffer is rejected by `from_bytes` with `InsufficientData`.
#[test]
fn test_counterfactual_insufficient_data() {
    let too_short = [0u8; 10];
    match Counterfactual::from_bytes(&too_short) {
        Err(CounterfactualError::InsufficientData { .. }) => {}
        _ => panic!("expected an InsufficientData error"),
    }
}
/// A buffer whose first byte is 2, followed by zeroed fields, is rejected by
/// `from_bytes` with `VersionMismatch`.
#[test]
fn test_counterfactual_version_mismatch() {
    // Leading byte 2 (presumably the format version — confirm against
    // `to_bytes`), then alternating zeroed u32 / f32 little-endian fields,
    // 21 bytes in total.
    let bytes: Vec<u8> = std::iter::once(2u8)
        .chain(0u32.to_le_bytes())
        .chain(0.0f32.to_le_bytes())
        .chain(0u32.to_le_bytes())
        .chain(0.0f32.to_le_bytes())
        .chain(0u32.to_le_bytes())
        .collect();

    let result = Counterfactual::from_bytes(&bytes);
    assert!(matches!(result, Err(CounterfactualError::VersionMismatch { .. })));
}