use crate::error::{RusTorchError, RusTorchResult};
use std::collections::HashMap;
use std::fmt;
use std::time::{Duration, Instant};
/// Validates tensors with built-in checks plus any registered [`ValidationRule`]s,
/// accumulating [`ValidationStatistics`] across calls.
#[derive(Debug)]
pub struct ValidationEngine {
/// Mode flags and limits consulted while validating (e.g. strict mode, time budget).
config: ValidationConfig,
/// Custom rules, run in the order they were added via `add_rule`.
rules: Vec<Box<dyn ValidationRule>>,
/// Running counters updated by every `validate_tensor` call.
stats: ValidationStatistics,
}
/// Tunable settings for a [`ValidationEngine`].
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// When `true`, a result is valid only if no issue at all was found; otherwise
    /// only issues of `High` severity or above make a result invalid.
    pub strict_mode: bool,
    /// Maximum tolerated share of NaN elements.
    /// NOTE(review): not read by the engine code shown; the defaults suggest a
    /// fraction in `[0, 1]` rather than a percentage — confirm.
    pub max_nan_percentage: f64,
    /// Maximum tolerated share of infinite elements (same unit caveat as above).
    pub max_inf_percentage: f64,
    /// Minimum required share of finite elements (same unit caveat as above).
    pub min_finite_percentage: f64,
    /// Time budget for one validation call, in microseconds; exceeding it adds a
    /// `PerformanceIssue`.
    pub performance_budget_us: u64,
    /// Toggle for schema-based validation.
    /// NOTE(review): not read by the engine code shown — confirm where it is honoured.
    pub enable_schema_validation: bool,
}

impl Default for ValidationConfig {
    /// Lenient defaults: non-strict mode, a 500 µs budget and schema validation on.
    fn default() -> Self {
        let performance_budget_us = 500;
        ValidationConfig {
            performance_budget_us,
            enable_schema_validation: true,
            strict_mode: false,
            min_finite_percentage: 0.95,
            max_nan_percentage: 0.01,
            max_inf_percentage: 0.001,
        }
    }
}
/// Outcome of one [`ValidationEngine::validate_tensor`] call.
#[derive(Debug, Clone)]
pub struct ValidationResult {
/// Pass/fail verdict, judged according to `level`.
pub is_valid: bool,
/// The level the verdict was computed at.
pub level: ValidationLevel,
/// All issues collected from built-in checks, registered rules and the time budget.
pub issues: Vec<ValidationIssue>,
/// Element counts and throughput figures gathered during the run.
pub metrics: ValidationMetrics,
/// Wall-clock time of the run, measured with `Instant`.
pub validation_time: Duration,
}
/// How strictly collected issues are turned into a pass/fail verdict.
///
/// `ValidationEngine::validate_tensor` uses `Strict` when `strict_mode` is set and
/// `Standard` otherwise. `Strict` fails on any issue; every other level fails only
/// when some issue has severity `High` or above.
/// NOTE(review): `Basic` and `Comprehensive` are not produced by the engine code shown.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationLevel {
Basic,
Standard,
Comprehensive,
Strict,
}
/// A single problem found while validating a tensor.
#[derive(Debug, Clone)]
pub struct ValidationIssue {
/// Category of the problem.
pub issue_type: IssueType,
/// How serious the problem is; drives the pass/fail verdict.
pub severity: IssueSeverity,
/// Human-readable description.
pub message: String,
/// Extra key/value details (e.g. `"expected"`/`"actual"` for shape mismatches,
/// `"actual_time_us"`/`"budget_us"` for budget overruns); may be empty.
pub context: HashMap<String, String>,
}
/// Category of a [`ValidationIssue`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IssueType {
/// Shape is empty or does not match the expected shape.
InvalidShape,
/// NaN elements were found where they are not allowed.
NaNValues,
/// Infinite elements were found where they are not allowed.
InfiniteValues,
/// Values fell outside an allowed range.
OutOfRange,
/// Presumably an element-type mismatch — not emitted by the code shown; confirm.
TypeMismatch,
/// Presumably a schema rule violation not covered above — not emitted by the code shown.
SchemaViolation,
/// Validation exceeded the configured time budget.
PerformanceIssue,
/// Presumably a memory-limit violation — not emitted by the code shown; confirm.
MemoryConstraint,
/// A registered custom rule failed (returned an error).
CustomValidation,
}
/// Severity of a [`ValidationIssue`].
///
/// The derived `Ord` follows declaration order: `Low < Medium < High < Critical`.
/// At non-strict levels any issue of `High` or above makes a result invalid, so do
/// not reorder these variants.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum IssueSeverity {
Low,
Medium,
High,
Critical,
}
/// Element counts and cost figures recorded for one validation run.
#[derive(Debug, Clone)]
pub struct ValidationMetrics {
/// Number of elements: the product of the tensor's shape dimensions.
pub total_elements: usize,
/// Number of NaN elements reported for the run.
pub nan_count: usize,
/// Number of infinite elements reported for the run.
pub inf_count: usize,
/// Number of finite elements reported for the run.
pub finite_count: usize,
// `value_range`: `(min, max)` of the element values when reported, else `None`.
// `memory_usage_bytes`: `total_elements * size_of::<T>()`, i.e. element storage only.
pub value_range: Option<(f64, f64)>, pub memory_usage_bytes: usize,
/// Throughput figures for the run.
pub performance_metrics: PerformanceMetrics,
}
/// Throughput figures attached to [`ValidationMetrics`].
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
/// Elements divided by the elapsed seconds at the time the metrics are built
/// (elapsed time clamped to at least 1 ns to avoid division by zero).
pub elements_per_second: f64,
/// Memory throughput in MB/s; the engine currently always sets this to `0.0`.
pub memory_throughput_mb_per_sec: f64,
/// Cache hit rate; the engine currently always sets this to `0.0`.
pub cache_hit_rate: f64,
}
/// Description of what a valid tensor looks like; checked by [`SchemaValidation`].
#[derive(Debug, Clone)]
pub struct DataSchema {
/// Required shape, or `None` to accept any shape.
pub expected_shape: Option<Vec<usize>>,
/// Expected element type name.
/// NOTE(review): not read by the code shown — confirm the accepted spellings.
pub expected_dtype: String,
/// Per-element value constraints.
pub value_constraints: ValueConstraints,
/// Names of additional rules.
/// NOTE(review): not read by the code shown — confirm how these are resolved.
pub custom_rules: Vec<String>,
}
/// Constraints on individual element values within a [`DataSchema`].
#[derive(Debug, Clone)]
pub struct ValueConstraints {
/// Optional lower bound on element values.
pub min_value: Option<f64>,
/// Optional upper bound on element values.
pub max_value: Option<f64>,
/// Whether NaN elements are acceptable.
pub allow_nan: bool,
/// Whether infinite elements are acceptable.
pub allow_infinite: bool,
/// Optional distribution-level constraints.
pub statistical_constraints: Option<StatisticalConstraints>,
}
/// Distribution-level constraints for a tensor's values.
/// NOTE(review): none of these fields is read by the code shown — confirm enforcement.
#[derive(Debug, Clone)]
pub struct StatisticalConstraints {
/// Optional `(low, high)` range for the mean.
pub mean_range: Option<(f64, f64)>,
/// Optional `(low, high)` range for the standard deviation.
pub std_dev_range: Option<(f64, f64)>,
/// Optional expected distribution family.
pub distribution_type: Option<DistributionType>,
}
/// Expected distribution family used by [`StatisticalConstraints`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DistributionType {
Normal,
Uniform,
Exponential,
/// A distribution identified by a free-form name.
Custom(String),
}
/// A pluggable check that a [`ValidationEngine`] runs on every validated tensor.
///
/// The engine calls `validate_f32` or `validate_f64` depending on the tensor's
/// element type. Tensors of other element types are not passed to rules. Returning
/// `Err` does not abort validation: the engine records it as a `High` severity
/// `CustomValidation` issue whose message includes [`ValidationRule::name`].
pub trait ValidationRule: fmt::Debug + Send + Sync {
/// Identifier used in error messages about this rule.
fn name(&self) -> &str;
/// Checks an `f32` tensor and returns any issues found.
fn validate_f32(
&self,
tensor: &crate::tensor::Tensor<f32>,
) -> RusTorchResult<Vec<ValidationIssue>>;
/// Checks an `f64` tensor and returns any issues found.
fn validate_f64(
&self,
tensor: &crate::tensor::Tensor<f64>,
) -> RusTorchResult<Vec<ValidationIssue>>;
}
/// A [`ValidationRule`] that compares tensors against a fixed [`DataSchema`].
#[derive(Debug)]
pub struct SchemaValidation {
    /// The schema every validated tensor is checked against.
    schema: DataSchema,
}

impl SchemaValidation {
    /// Wraps `schema` in a rule ready to be registered with a [`ValidationEngine`].
    pub fn new(schema: DataSchema) -> Self {
        SchemaValidation { schema }
    }
}
impl ValidationRule for SchemaValidation {
    /// Returns the fixed identifier `"schema_validation"`.
    fn name(&self) -> &str {
        const RULE_NAME: &str = "schema_validation";
        RULE_NAME
    }

    /// Checks an `f32` tensor by forwarding to `validate_tensor_generic`.
    fn validate_f32(
        &self,
        tensor: &crate::tensor::Tensor<f32>,
    ) -> RusTorchResult<Vec<ValidationIssue>> {
        Self::validate_tensor_generic(self, tensor)
    }

    /// Checks an `f64` tensor by forwarding to `validate_tensor_generic`.
    fn validate_f64(
        &self,
        tensor: &crate::tensor::Tensor<f64>,
    ) -> RusTorchResult<Vec<ValidationIssue>> {
        Self::validate_tensor_generic(self, tensor)
    }
}
impl SchemaValidation {
    /// Checks `tensor` against the stored [`DataSchema`] for any float element type.
    ///
    /// Issues reported:
    /// - [`IssueType::InvalidShape`] (`High`) when `expected_shape` is set and differs
    ///   from the tensor's shape.
    /// - [`IssueType::NaNValues`] (`Medium`) when `allow_nan` is `false` and NaN
    ///   elements are present.
    /// - [`IssueType::InfiniteValues`] (`Medium`) when `allow_infinite` is `false` and
    ///   ±infinity elements are present.
    /// - [`IssueType::OutOfRange`] (`Medium`) when finite elements lie below
    ///   `min_value` or above `max_value` (both bounds inclusive).
    ///
    /// Element values are scanned (once) only if at least one value constraint is active.
    ///
    /// # Errors
    ///
    /// Never returns `Err` at present; the `Result` mirrors [`ValidationRule`]'s signature.
    fn validate_tensor_generic<T>(
        &self,
        tensor: &crate::tensor::Tensor<T>,
    ) -> RusTorchResult<Vec<ValidationIssue>>
    where
        T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
    {
        let mut issues = Vec::new();

        if let Some(ref expected_shape) = self.schema.expected_shape {
            let actual_shape = tensor.shape();
            if &actual_shape != expected_shape {
                issues.push(ValidationIssue {
                    issue_type: IssueType::InvalidShape,
                    severity: IssueSeverity::High,
                    message: format!(
                        "Shape mismatch: expected {:?}, got {:?}",
                        expected_shape, actual_shape
                    ),
                    context: {
                        let mut ctx = HashMap::new();
                        ctx.insert("expected".to_string(), format!("{:?}", expected_shape));
                        ctx.insert("actual".to_string(), format!("{:?}", actual_shape));
                        ctx
                    },
                });
            }
        }

        let constraints = &self.schema.value_constraints;
        let needs_scan = !constraints.allow_nan
            || !constraints.allow_infinite
            || constraints.min_value.is_some()
            || constraints.max_value.is_some();
        if !needs_scan {
            return Ok(issues);
        }

        // Single pass over the elements. NaN and ±inf are classified first so the
        // range test only ever sees finite values.
        let (mut nan_count, mut inf_count, mut out_of_range_count) = (0usize, 0usize, 0usize);
        for &value in tensor.data.iter() {
            if value.is_nan() {
                nan_count += 1;
            } else if value.is_infinite() {
                inf_count += 1;
            } else if let Some(v) = num_traits::ToPrimitive::to_f64(&value) {
                let below = constraints.min_value.is_some_and(|min| v < min);
                let above = constraints.max_value.is_some_and(|max| v > max);
                if below || above {
                    out_of_range_count += 1;
                }
            }
        }

        if !constraints.allow_nan && nan_count > 0 {
            issues.push(ValidationIssue {
                issue_type: IssueType::NaNValues,
                severity: IssueSeverity::Medium,
                message: format!("Found {} NaN values (not allowed by schema)", nan_count),
                context: HashMap::from([("nan_count".to_string(), nan_count.to_string())]),
            });
        }

        if !constraints.allow_infinite && inf_count > 0 {
            issues.push(ValidationIssue {
                issue_type: IssueType::InfiniteValues,
                severity: IssueSeverity::Medium,
                message: format!(
                    "Found {} infinite values (not allowed by schema)",
                    inf_count
                ),
                context: HashMap::from([("inf_count".to_string(), inf_count.to_string())]),
            });
        }

        if out_of_range_count > 0 {
            let lower = constraints
                .min_value
                .map_or_else(|| "-inf".to_string(), |min| min.to_string());
            let upper = constraints
                .max_value
                .map_or_else(|| "inf".to_string(), |max| max.to_string());
            issues.push(ValidationIssue {
                issue_type: IssueType::OutOfRange,
                severity: IssueSeverity::Medium,
                message: format!(
                    "Found {} values outside the schema range [{}, {}]",
                    out_of_range_count, lower, upper
                ),
                context: HashMap::from([
                    ("out_of_range_count".to_string(), out_of_range_count.to_string()),
                    ("min_value".to_string(), lower),
                    ("max_value".to_string(), upper),
                ]),
            });
        }

        Ok(issues)
    }
}
/// Running totals kept by a [`ValidationEngine`]; cleared by `reset_statistics`.
#[derive(Debug, Default)]
pub struct ValidationStatistics {
/// Number of `validate_tensor` calls recorded.
pub total_validations: usize,
/// Calls whose result was valid.
pub successful_validations: usize,
/// Calls whose result was invalid.
pub failed_validations: usize,
/// `total_validation_time / total_validations`, recomputed after every call.
pub average_validation_time: Duration,
/// Sum of all recorded validation times.
pub total_validation_time: Duration,
}
impl ValidationEngine {
pub fn new(config: ValidationConfig) -> RusTorchResult<Self> {
Ok(Self {
config,
rules: Vec::new(),
stats: ValidationStatistics::default(),
})
}
pub fn add_rule(&mut self, rule: Box<dyn ValidationRule>) {
self.rules.push(rule);
}
pub fn validate_tensor<T>(
&mut self,
tensor: &crate::tensor::Tensor<T>,
) -> RusTorchResult<ValidationResult>
where
T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
{
let start_time = Instant::now();
let mut issues = Vec::new();
let level = if self.config.strict_mode {
ValidationLevel::Strict
} else {
ValidationLevel::Standard
};
let shape = tensor.shape();
if shape.is_empty() {
issues.push(ValidationIssue {
issue_type: IssueType::InvalidShape,
severity: IssueSeverity::High,
message: "Empty tensor shape detected".to_string(),
context: HashMap::new(),
});
}
let total_elements = shape.iter().product();
let metrics = ValidationMetrics {
total_elements,
nan_count: 0, inf_count: 0, finite_count: total_elements, value_range: Some((0.0, 1.0)), memory_usage_bytes: total_elements * std::mem::size_of::<T>(),
performance_metrics: PerformanceMetrics {
elements_per_second: total_elements as f64
/ start_time.elapsed().as_secs_f64().max(1e-9),
memory_throughput_mb_per_sec: 0.0, cache_hit_rate: 0.0, },
};
use std::any::{Any, TypeId};
let tensor_any = tensor as &dyn Any;
for rule in &self.rules {
let rule_result = if TypeId::of::<T>() == TypeId::of::<f32>() {
if let Some(f32_tensor) = tensor_any.downcast_ref::<crate::tensor::Tensor<f32>>() {
rule.validate_f32(f32_tensor)
} else {
continue;
}
} else if TypeId::of::<T>() == TypeId::of::<f64>() {
if let Some(f64_tensor) = tensor_any.downcast_ref::<crate::tensor::Tensor<f64>>() {
rule.validate_f64(f64_tensor)
} else {
continue;
}
} else {
continue;
};
match rule_result {
Ok(mut rule_issues) => issues.append(&mut rule_issues),
Err(e) => {
issues.push(ValidationIssue {
issue_type: IssueType::CustomValidation,
severity: IssueSeverity::High,
message: format!("Custom validation rule '{}' failed: {}", rule.name(), e),
context: HashMap::new(),
});
}
}
}
let validation_time = start_time.elapsed();
if validation_time.as_micros() as u64 > self.config.performance_budget_us {
issues.push(ValidationIssue {
issue_type: IssueType::PerformanceIssue,
severity: IssueSeverity::Medium,
message: format!(
"Validation exceeded performance budget: {}μs > {}μs",
validation_time.as_micros(),
self.config.performance_budget_us
),
context: {
let mut ctx = HashMap::new();
ctx.insert(
"actual_time_us".to_string(),
validation_time.as_micros().to_string(),
);
ctx.insert(
"budget_us".to_string(),
self.config.performance_budget_us.to_string(),
);
ctx
},
});
}
let is_valid = match level {
ValidationLevel::Strict => issues.is_empty(),
_ => !issues
.iter()
.any(|issue| issue.severity >= IssueSeverity::High),
};
self.stats.total_validations += 1;
if is_valid {
self.stats.successful_validations += 1;
} else {
self.stats.failed_validations += 1;
}
self.stats.total_validation_time += validation_time;
self.stats.average_validation_time =
self.stats.total_validation_time / self.stats.total_validations as u32;
Ok(ValidationResult {
is_valid,
level,
issues,
metrics,
validation_time,
})
}
pub fn get_statistics(&self) -> &ValidationStatistics {
&self.stats
}
pub fn reset_statistics(&mut self) {
self.stats = ValidationStatistics::default();
}
}
impl fmt::Display for ValidationResult {
    /// Renders a short multi-line summary: status, level, issue count, element count
    /// and elapsed time in milliseconds with three decimals (no trailing newline).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let status = match self.is_valid {
            true => "✅ VALID",
            false => "❌ INVALID",
        };
        let elapsed_ms = self.validation_time.as_secs_f64() * 1000.0;

        writeln!(f, "🔍 Validation Result")?;
        writeln!(f, "===================")?;
        writeln!(f, "Status: {}", status)?;
        writeln!(f, "Level: {:?}", self.level)?;
        writeln!(f, "Issues: {}", self.issues.len())?;
        writeln!(f, "Elements: {}", self.metrics.total_elements)?;
        write!(f, "Time: {:.3}ms", elapsed_ms)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is lenient and uses a 500 µs budget.
    #[test]
    fn test_validation_config_default() {
        let defaults = ValidationConfig::default();
        assert!(!defaults.strict_mode);
        assert_eq!(defaults.max_nan_percentage, 0.01);
        assert_eq!(defaults.performance_budget_us, 500);
    }

    /// Building an engine from the default configuration succeeds.
    #[test]
    fn test_validation_engine_creation() {
        assert!(ValidationEngine::new(ValidationConfig::default()).is_ok());
    }

    /// Issue fields round-trip exactly as constructed.
    #[test]
    fn test_validation_issue_creation() {
        let (kind, severity) = (IssueType::NaNValues, IssueSeverity::Medium);
        let issue = ValidationIssue {
            issue_type: kind.clone(),
            severity: severity.clone(),
            message: String::from("Test issue"),
            context: HashMap::new(),
        };
        assert_eq!(issue.issue_type, kind);
        assert_eq!(issue.severity, severity);
    }
}