use crate::error::{RusTorchError, RusTorchResult};
use std::collections::HashMap;
use std::fmt;
/// Runs a configurable list of [`ConsistencyRule`]s over tensors and keeps
/// cumulative statistics about the violations it has seen.
#[derive(Debug)]
pub struct ConsistencyChecker {
    /// Registered rules; evaluated in the order they were added.
    rules: Vec<Box<dyn ConsistencyRule>>,
    /// Totals accumulated across every call to `check_consistency`.
    stats: ConsistencyStatistics,
}
/// A single consistency check that can be applied to `f32` and `f64` tensors.
///
/// Implementors must be `Debug + Send + Sync`.
pub trait ConsistencyRule: fmt::Debug + Send + Sync {
    /// Identifier for this rule. `ConsistencyChecker` uses it as the
    /// `rule_name` of the violation it records when a check returns `Err`.
    fn name(&self) -> &str;

    /// Checks an `f32` tensor and returns every violation found
    /// (an empty `Vec` when the tensor passes).
    fn check_f32(&self, tensor: &crate::tensor::Tensor<f32>)
        -> RusTorchResult<Vec<ConsistencyViolation>>;

    /// Checks an `f64` tensor and returns every violation found
    /// (an empty `Vec` when the tensor passes).
    fn check_f64(&self, tensor: &crate::tensor::Tensor<f64>)
        -> RusTorchResult<Vec<ConsistencyViolation>>;
}
/// Summary returned by `ConsistencyChecker::check_consistency`.
#[derive(Clone, Debug)]
pub struct ConsistencyResult {
    /// `true` when no violations were reported.
    pub is_consistent: bool,
    /// Every violation reported for the checked tensor.
    pub violations: Vec<ConsistencyViolation>,
    /// Score in `[0.0, 1.0]` as computed by the checker; higher means fewer
    /// violations.
    pub consistency_score: f64,
}
/// One problem reported by a [`ConsistencyRule`].
#[derive(Clone, Debug)]
pub struct ConsistencyViolation {
    /// Name of the rule that produced this violation.
    pub rule_name: String,
    /// How serious the problem is.
    pub severity: ViolationSeverity,
    /// Human-readable explanation of the problem.
    pub description: String,
    /// Where in the tensor the problem was found, if the rule reports it.
    pub location: Option<DataLocation>,
}
/// Severity of a [`ConsistencyViolation`].
///
/// Variants are declared from least to most severe, and the derived `Ord`
/// follows declaration order: `Minor < Moderate < Major < Critical`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ViolationSeverity {
    /// Lowest severity.
    Minor,
    /// Between `Minor` and `Major`.
    Moderate,
    /// Used by the built-in shape rule for an invalid shape.
    Major,
    /// Highest severity; the checker uses it when a rule itself fails.
    Critical,
}
/// Optional position information attached to a violation.
///
/// NOTE(review): nothing in this module populates it yet. `indices` is
/// presumably a multi-dimensional element index and `range` a
/// `(start, end)` span; confirm against the rules that fill it in.
#[derive(Clone, Debug)]
pub struct DataLocation {
    /// Index components identifying the offending element(s).
    pub indices: Vec<usize>,
    /// Optional pair of bounds describing an affected span.
    pub range: Option<(usize, usize)>,
}
/// Stateless helpers for basic tensor validity checks.
pub struct DataConsistency;

impl DataConsistency {
    /// Returns `true` when `tensor` has at least one dimension and every
    /// dimension has a strictly positive extent.
    ///
    /// A rank-0 shape and any zero-sized dimension are both rejected.
    pub fn check_shape_consistency<T>(tensor: &crate::tensor::Tensor<T>) -> bool
    where
        T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
    {
        let dims = tensor.shape();
        let has_rank = !dims.is_empty();
        has_rank && dims.iter().all(|extent| *extent > 0)
    }

    /// Placeholder value-range check.
    ///
    /// NOTE(review): the tensor, `_min` and `_max` are currently ignored, and
    /// this always returns `true`. No element values are inspected.
    pub fn check_value_range_consistency<T>(
        _tensor: &crate::tensor::Tensor<T>,
        _min: T,
        _max: T,
    ) -> bool
    where
        T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
    {
        // Not implemented yet: every tensor is reported as in range.
        true
    }
}
/// Stateless helper for cross-tensor integrity checks.
pub struct ReferentialIntegrity;

impl ReferentialIntegrity {
    /// Placeholder referential-integrity check between two tensors.
    ///
    /// NOTE(review): both arguments are currently ignored, and this always
    /// returns `true`.
    pub fn check_referential_integrity<T>(
        _primary: &crate::tensor::Tensor<T>,
        _foreign: &crate::tensor::Tensor<T>,
    ) -> bool
    where
        T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
    {
        // Not implemented yet: every pair is reported as consistent.
        true
    }
}
/// Built-in rule that flags tensors rejected by
/// [`DataConsistency::check_shape_consistency`].
#[derive(Debug)]
pub struct ShapeConsistencyRule;

impl ConsistencyRule for ShapeConsistencyRule {
    fn name(&self) -> &str {
        "shape_consistency"
    }

    /// Delegates to the element-type-agnostic implementation.
    fn check_f32(
        &self,
        tensor: &crate::tensor::Tensor<f32>,
    ) -> RusTorchResult<Vec<ConsistencyViolation>> {
        ShapeConsistencyRule::check_tensor_generic(self, tensor)
    }

    /// Delegates to the element-type-agnostic implementation.
    fn check_f64(
        &self,
        tensor: &crate::tensor::Tensor<f64>,
    ) -> RusTorchResult<Vec<ConsistencyViolation>> {
        ShapeConsistencyRule::check_tensor_generic(self, tensor)
    }
}
impl ShapeConsistencyRule {
    /// Shared body of `check_f32` and `check_f64`.
    ///
    /// Returns an empty list when the shape is valid. Otherwise it returns a
    /// single `Major` violation with no location. It never returns `Err`.
    fn check_tensor_generic<T>(
        &self,
        tensor: &crate::tensor::Tensor<T>,
    ) -> RusTorchResult<Vec<ConsistencyViolation>>
    where
        T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
    {
        if DataConsistency::check_shape_consistency(tensor) {
            return Ok(Vec::new());
        }

        let violation = ConsistencyViolation {
            rule_name: self.name().to_string(),
            severity: ViolationSeverity::Major,
            description: "Invalid tensor shape detected".to_string(),
            location: None,
        };
        Ok(vec![violation])
    }
}
/// Running totals kept by a `ConsistencyChecker` across all of its checks.
#[derive(Default, Debug)]
pub struct ConsistencyStatistics {
    /// Number of `check_consistency` calls made.
    pub total_checks: usize,
    /// Sum of violations reported over all checks.
    pub total_violations: usize,
    /// Violation counts broken down by severity. Absent keys mean zero.
    pub violations_by_severity: HashMap<ViolationSeverity, usize>,
}
impl ConsistencyChecker {
pub fn new() -> Self {
let mut checker = Self {
rules: Vec::new(),
stats: ConsistencyStatistics::default(),
};
checker.add_rule(Box::new(ShapeConsistencyRule));
checker
}
pub fn add_rule(&mut self, rule: Box<dyn ConsistencyRule>) {
self.rules.push(rule);
}
pub fn check_consistency<T>(
&mut self,
tensor: &crate::tensor::Tensor<T>,
) -> RusTorchResult<ConsistencyResult>
where
T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
{
let mut all_violations = Vec::new();
use std::any::{Any, TypeId};
let tensor_any = tensor as &dyn Any;
for rule in &self.rules {
let rule_result = if TypeId::of::<T>() == TypeId::of::<f32>() {
if let Some(f32_tensor) = tensor_any.downcast_ref::<crate::tensor::Tensor<f32>>() {
rule.check_f32(f32_tensor)
} else {
continue;
}
} else if TypeId::of::<T>() == TypeId::of::<f64>() {
if let Some(f64_tensor) = tensor_any.downcast_ref::<crate::tensor::Tensor<f64>>() {
rule.check_f64(f64_tensor)
} else {
continue;
}
} else {
continue;
};
match rule_result {
Ok(mut violations) => all_violations.append(&mut violations),
Err(e) => {
all_violations.push(ConsistencyViolation {
rule_name: rule.name().to_string(),
severity: ViolationSeverity::Critical,
description: format!("Rule execution failed: {}", e),
location: None,
});
}
}
}
self.stats.total_checks += 1;
self.stats.total_violations += all_violations.len();
for violation in &all_violations {
*self
.stats
.violations_by_severity
.entry(violation.severity.clone())
.or_insert(0) += 1;
}
let is_consistent = all_violations.is_empty();
let consistency_score = if is_consistent {
1.0
} else {
1.0 - (all_violations.len() as f64 / 10.0).min(1.0)
};
Ok(ConsistencyResult {
is_consistent,
violations: all_violations,
consistency_score,
})
}
pub fn get_violation_count(&self) -> usize {
self.stats.total_violations
}
}