pub mod config;
pub mod fairness;
#[cfg(test)]
mod fairness_tests;
pub mod model_test_suite;
#[cfg(test)]
mod model_test_suite_tests;
pub mod performance;
pub mod reference_comparison;
pub mod reporting;
pub mod types;
pub use config::{TestDataType, TestInputConfig, ValidationConfig};
pub use fairness::{
BiasMetric, BiasmitigationStrategy, FairnessAssessment, FairnessConfig, FairnessMetricType,
FairnessResult, FairnessTestData, FairnessViolation, GroupData, StatisticalTest,
};
pub use model_test_suite::ModelTestSuite;
pub use performance::PerformanceProfiler;
pub use reference_comparison::ReferenceComparator;
pub use reporting::{generate_test_report, save_report_to_file};
pub use types::{
LayerPerformance, MemoryAnalysis, NumericalDifferences, NumericalParityResults,
OverallPerformance, PerformanceResults, TestResult, TestStatistics, ThroughputMeasurements,
TimingInfo,
};
// Unit tests for the types re-exported by this module.
#[cfg(test)]
mod tests {
    use super::*;

    // The default validation config: 1e-4 tolerance, performance tests enabled,
    // reference comparison disabled, and three test inputs.
    #[test]
    fn test_validation_config_default() {
        let config = ValidationConfig::default();
        assert_eq!(config.numerical_tolerance, 1e-4);
        assert!(config.run_performance_tests);
        assert!(!config.compare_with_reference);
        assert_eq!(config.test_inputs.len(), 3);
    }

    // Smoke test: constructing a suite from a model name must not panic.
    #[test]
    fn test_model_test_suite_creation() {
        let _test_suite = ModelTestSuite::new("test-model");
    }

    // Smoke test: constructing a profiler must not panic.
    #[test]
    fn test_performance_profiler_creation() {
        let _profiler = PerformanceProfiler::new();
    }

    // The tolerance handed to the constructor must be the one the getter reports;
    // merely calling the getter and discarding the result would never catch a
    // regression here.
    #[test]
    fn test_reference_comparator() {
        let comparator = ReferenceComparator::new(1e-3);
        assert_eq!(comparator.tolerance(), 1e-3);
    }

    // Differences well inside a 1e-3 tolerance validate; differences that exceed it
    // (max abs diff 1e-2, only 90% within tolerance) are rejected.
    #[test]
    fn test_numerical_differences_validation() {
        let comparator = ReferenceComparator::new(1e-3);

        let good_diffs = NumericalDifferences {
            max_abs_diff: 1e-4,
            mean_abs_diff: 1e-5,
            rms_diff: 1e-5,
            within_tolerance_percent: 99.0,
        };
        assert!(comparator.validate_differences(&good_diffs));

        let bad_diffs = NumericalDifferences {
            max_abs_diff: 1e-2,
            mean_abs_diff: 1e-3,
            rms_diff: 1e-3,
            within_tolerance_percent: 90.0,
        };
        assert!(!comparator.validate_differences(&bad_diffs));
    }

    // A hand-built statistics record keeps its pass rate, and its totals are
    // internally consistent (passed + failed == total).
    #[test]
    fn test_test_statistics_calculation() {
        let stats = TestStatistics {
            total_tests: 10,
            passed_tests: 8,
            failed_tests: 2,
            pass_rate: 80.0,
        };
        assert_eq!(stats.pass_rate, 80.0);
        assert_eq!(stats.total_tests, stats.passed_tests + stats.failed_tests);
    }
}