use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationConfig {
pub numerical_tolerance: f32,
pub run_performance_tests: bool,
pub compare_with_reference: bool,
pub max_inference_time_ms: u64,
pub max_memory_usage_mb: u64,
pub test_inputs: Vec<TestInputConfig>,
pub test_data_types: Vec<TestDataType>,
}
/// Description of a single test input used during validation.
///
/// Derives `PartialEq` so configurations can be compared, consistent with its
/// `TestDataType` field, which already implements `PartialEq`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TestInputConfig {
    /// Identifier for this input (e.g. `"small_batch"` in the defaults).
    pub name: String,
    /// Dimensions of the input, e.g. `[1, 128]`.
    /// NOTE(review): presumably the input tensor shape — confirm against consumers.
    pub dimensions: Vec<usize>,
    /// Data type of the input's elements.
    pub data_type: TestDataType,
    /// Whether this input is marked as required; how non-required inputs are
    /// treated is up to the consumer of this config.
    pub required: bool,
}
/// Data types that test inputs and validation runs may use.
///
/// This is a fieldless enum, so it derives `Copy` (pass by value, no `.clone()`
/// needed), plus `Eq` and `Hash` so it can be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TestDataType {
    F32,
    F16,
    I32,
    I64,
}
impl Default for ValidationConfig {
fn default() -> Self {
Self {
numerical_tolerance: 1e-4,
run_performance_tests: true,
compare_with_reference: false,
max_inference_time_ms: 10000,
max_memory_usage_mb: 16384,
test_inputs: vec![
TestInputConfig {
name: "small_batch".to_string(),
dimensions: vec![1, 128],
data_type: TestDataType::I32,
required: true,
},
TestInputConfig {
name: "medium_batch".to_string(),
dimensions: vec![4, 256],
data_type: TestDataType::I32,
required: true,
},
TestInputConfig {
name: "large_batch".to_string(),
dimensions: vec![16, 512],
data_type: TestDataType::I32,
required: false,
},
],
test_data_types: vec![TestDataType::F32, TestDataType::F16],
}
}
}