use crate::error::{RusTorchError, RusTorchResult};
use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
/// Stateful collector of data-quality assessments.
///
/// Keeps the retained assessments, the cap on how many are retained, the
/// thresholds consulted while scoring, and running statistics over every
/// overall score recorded so far.
#[derive(Debug)]
pub struct QualityMetrics {
    /// Retained assessments, oldest first; `assess_quality` pushes at the back
    /// and drops the front entry once `max_history_size` is exceeded.
    history: VecDeque<DataQualityAssessment>,
    /// Maximum number of entries kept in `history`.
    max_history_size: usize,
    /// Thresholds consulted by the per-dimension checks.
    thresholds: MetricThresholds,
    /// Running statistics, updated after every assessment.
    aggregated_stats: AggregatedQualityStats,
}
/// Result of a single quality assessment run.
#[derive(Clone, Debug)]
pub struct DataQualityAssessment {
    /// Weighted average of the per-dimension scores.
    pub overall_score: f64,
    /// One score entry per assessed quality dimension.
    pub dimensions: HashMap<QualityDimension, QualityScore>,
    /// Wall-clock time captured when the assessment was built.
    pub timestamp: SystemTime,
    /// Summary of the data that was assessed.
    pub characteristics: DataCharacteristics,
    /// Trend information; `assess_quality` fills this only when more than two
    /// earlier assessments were retained at the time of the run.
    pub trends: Option<QualityTrend>,
}
/// Quality dimensions that an assessment scores independently.
///
/// Usable as a `HashMap` key (derives `Eq` and `Hash`).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum QualityDimension {
    /// Scored by `assess_completeness`.
    Completeness,
    /// Scored by `assess_accuracy`.
    Accuracy,
    /// Scored by `assess_consistency`.
    Consistency,
    /// Scored by `assess_validity`.
    Validity,
    /// Scored by `assess_uniqueness`.
    Uniqueness,
    /// Scored by `assess_timeliness`.
    Timeliness,
    /// Scored by `assess_integrity`.
    Integrity,
}
/// Score for one quality dimension plus the evidence behind it.
#[derive(Clone, Debug)]
pub struct QualityScore {
    /// Achieved score; this is the value fed into the overall weighted average.
    pub score: f64,
    /// Upper bound of `score`; every scorer in this file sets it to `1.0`.
    pub max_score: f64,
    /// Named numeric measurements recorded while scoring.
    pub metrics: HashMap<String, f64>,
    /// Problems detected while scoring this dimension.
    pub issues: Vec<QualityIssue>,
    /// Confidence attached to the score; the scorers in this file use fixed
    /// per-dimension constants.
    pub confidence: f64,
}
/// A single problem detected during an assessment.
#[derive(Clone, Debug)]
pub struct QualityIssue {
    /// Kind of problem.
    pub category: IssueCategory,
    /// How serious the problem is.
    pub severity: IssueSeverity,
    /// Human-readable explanation.
    pub description: String,
    /// Part of the data the issue applies to, when known.
    pub affected_range: Option<DataRange>,
    /// Suggested fix, when one is available.
    pub remediation: Option<String>,
    /// Numeric impact of the issue; the completeness check stores
    /// `1.0 - completeness_ratio` here.
    pub impact_score: f64,
}
/// Classification of a [`QualityIssue`].
///
/// NOTE(review): of these, only `MissingData` is constructed in the part of
/// the file reviewed here; the meaning of the other variants is inferred from
/// their names and should be confirmed against their producers.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum IssueCategory {
    /// Reported by the completeness check when the ratio falls below its threshold.
    MissingData,
    FormatError,
    RangeViolation,
    Duplication,
    Inconsistency,
    StalenessIssue,
    StatisticalAnomaly,
}
/// Severity of a [`QualityIssue`].
///
/// The derived ordering follows declaration order, so
/// `Info < Low < Medium < High < Critical`. Do not reorder the variants.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum IssueSeverity {
    /// Lowest severity.
    Info,
    Low,
    Medium,
    High,
    /// Highest severity.
    Critical,
}
/// Span of data that a [`QualityIssue`] refers to.
///
/// NOTE(review): no code in the reviewed part of the file constructs this
/// type, so whether `end` is inclusive or exclusive is not established here —
/// confirm with the producers.
#[derive(Clone, Debug)]
pub struct DataRange {
    /// First position of the span.
    pub start: usize,
    /// Last position of the span (inclusive/exclusive: TODO confirm).
    pub end: usize,
    /// Optional dimension index the span applies to.
    pub dimension: Option<usize>,
}
/// Summary of the data that was assessed.
#[derive(Clone, Debug)]
pub struct DataCharacteristics {
    /// Number of elements; computed as the product of `shape`.
    pub total_points: usize,
    /// Element type name as reported by `std::any::type_name`.
    pub data_type: String,
    /// Copy of the tensor shape.
    pub shape: Vec<usize>,
    /// Distribution summary of the values.
    pub distribution_stats: DistributionStats,
    /// `total_points` multiplied by the size in bytes of one element.
    pub memory_footprint: usize,
}
/// Summary statistics describing a value distribution.
///
/// NOTE(review): `collect_data_characteristics` currently fills this with
/// fixed placeholder values rather than values computed from the data.
#[derive(Clone, Debug)]
pub struct DistributionStats {
    /// Mean of the values.
    pub mean: f64,
    /// Standard deviation of the values.
    pub std_dev: f64,
    /// Smallest value.
    pub min: f64,
    /// Largest value.
    pub max: f64,
    /// Percentile values keyed by percentile number (e.g. 25, 50, 75, 95, 99).
    pub percentiles: HashMap<u8, f64>,
    /// Skewness of the distribution.
    pub skewness: f64,
    /// Kurtosis of the distribution.
    pub kurtosis: f64,
}
/// Trend information attached to an assessment.
///
/// NOTE(review): `analyze_trends` currently returns fixed values for every
/// field, so the units and ranges below are not established by the code.
#[derive(Clone, Debug)]
pub struct QualityTrend {
    /// Direction in which quality is moving.
    pub direction: TrendDirection,
    /// Strength of the trend (scale: TODO confirm).
    pub strength: f64,
    /// Rate of change (units: TODO confirm).
    pub change_rate: f64,
    /// Confidence in the trend estimate.
    pub confidence: f64,
    /// Optional predicted value; presumably a future overall score — confirm.
    pub prediction: Option<f64>,
}
/// Direction of a [`QualityTrend`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TrendDirection {
    Improving,
    Declining,
    /// The only direction `analyze_trends` currently reports.
    Stable,
    Volatile,
}
/// Acceptance thresholds used when judging quality.
#[derive(Clone, Debug)]
pub struct MetricThresholds {
    /// Minimum acceptable overall score (default `0.8`).
    ///
    /// NOTE(review): not consulted by the scoring code in the reviewed part of
    /// the file.
    pub min_overall_score: f64,
    /// Minimum acceptable score per dimension; the completeness check reads
    /// its entry from here.
    pub dimension_thresholds: HashMap<QualityDimension, f64>,
    /// Maximum tolerated number of issues per severity level.
    ///
    /// NOTE(review): not consulted by the scoring code in the reviewed part of
    /// the file.
    pub max_issues_by_severity: HashMap<IssueSeverity, usize>,
}
impl Default for MetricThresholds {
fn default() -> Self {
let mut dimension_thresholds = HashMap::new();
dimension_thresholds.insert(QualityDimension::Completeness, 0.95);
dimension_thresholds.insert(QualityDimension::Accuracy, 0.9);
dimension_thresholds.insert(QualityDimension::Consistency, 0.9);
dimension_thresholds.insert(QualityDimension::Validity, 0.95);
dimension_thresholds.insert(QualityDimension::Uniqueness, 0.98);
dimension_thresholds.insert(QualityDimension::Timeliness, 0.8);
dimension_thresholds.insert(QualityDimension::Integrity, 0.95);
let mut max_issues = HashMap::new();
max_issues.insert(IssueSeverity::Critical, 0);
max_issues.insert(IssueSeverity::High, 2);
max_issues.insert(IssueSeverity::Medium, 5);
max_issues.insert(IssueSeverity::Low, 10);
max_issues.insert(IssueSeverity::Info, 50);
Self {
min_overall_score: 0.8,
dimension_thresholds,
max_issues_by_severity: max_issues,
}
}
}
/// Running statistics over the overall scores of all recorded assessments.
///
/// All fields start at zero (`Default`) and are maintained by
/// `QualityMetrics::update_aggregated_stats`.
#[derive(Default, Debug)]
pub struct AggregatedQualityStats {
    /// Number of assessments folded into these statistics.
    pub total_assessments: usize,
    /// Running mean of the overall scores.
    pub average_overall_score: f64,
    /// Highest overall score seen.
    pub best_score: f64,
    /// Lowest overall score seen.
    pub worst_score: f64,
    /// Incrementally updated variance of the overall scores.
    pub score_variance: f64,
    /// `1.0` minus the square root of `score_variance`.
    pub stability_measure: f64,
}
impl QualityMetrics {
/// Creates an empty collector with default thresholds, zeroed aggregate
/// statistics, and room for up to 1000 retained assessments.
pub fn new() -> Self {
    let thresholds = MetricThresholds::default();
    let aggregated_stats = AggregatedQualityStats::default();
    Self {
        history: VecDeque::new(),
        max_history_size: 1000,
        thresholds,
        aggregated_stats,
    }
}
/// Runs every per-dimension check on `tensor`, combines the results into an
/// overall score, and records the assessment.
///
/// The assessment is appended to the history (dropping the oldest entry once
/// the history exceeds its cap), folded into the aggregate statistics, and a
/// one-line summary is printed to stdout.
///
/// # Errors
///
/// Returns the first error produced by any per-dimension check.
pub fn assess_quality<T>(
    &mut self,
    tensor: &crate::tensor::Tensor<T>,
) -> RusTorchResult<DataQualityAssessment>
where
    T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
{
    let timer = Instant::now();
    let characteristics = self.collect_data_characteristics(tensor);

    // Score each dimension, in the same order as before, stopping at the
    // first failure.
    let completeness = self.assess_completeness(tensor, &characteristics)?;
    let accuracy = self.assess_accuracy(tensor, &characteristics)?;
    let consistency = self.assess_consistency(tensor, &characteristics)?;
    let validity = self.assess_validity(tensor, &characteristics)?;
    let uniqueness = self.assess_uniqueness(tensor, &characteristics)?;
    let timeliness = self.assess_timeliness(&characteristics)?;
    let integrity = self.assess_integrity(tensor, &characteristics)?;

    let mut dimensions = HashMap::new();
    dimensions.insert(QualityDimension::Completeness, completeness);
    dimensions.insert(QualityDimension::Accuracy, accuracy);
    dimensions.insert(QualityDimension::Consistency, consistency);
    dimensions.insert(QualityDimension::Validity, validity);
    dimensions.insert(QualityDimension::Uniqueness, uniqueness);
    dimensions.insert(QualityDimension::Timeliness, timeliness);
    dimensions.insert(QualityDimension::Integrity, integrity);

    let overall_score = self.calculate_overall_score(&dimensions);

    // Trends are only reported once more than two earlier assessments exist.
    let trends = if self.history.len() <= 2 {
        None
    } else {
        Some(self.analyze_trends(&overall_score))
    };

    let assessment = DataQualityAssessment {
        overall_score,
        dimensions,
        timestamp: SystemTime::now(),
        characteristics,
        trends,
    };

    // Retain a copy, evicting the oldest entry when over the cap.
    self.history.push_back(assessment.clone());
    if self.history.len() > self.max_history_size {
        self.history.pop_front();
    }
    self.update_aggregated_stats(&assessment);

    let elapsed_ms = timer.elapsed().as_secs_f64() * 1000.0;
    println!(
        "📊 Quality assessment completed in {:.2}ms, score: {:.3}",
        elapsed_ms, overall_score
    );
    Ok(assessment)
}
/// Summarises `tensor`: shape, element count, element type name, memory
/// footprint, and distribution statistics.
///
/// NOTE(review): the distribution statistics are fixed placeholder values;
/// only the shape of `tensor` is read, not its contents.
fn collect_data_characteristics<T>(
    &self,
    tensor: &crate::tensor::Tensor<T>,
) -> DataCharacteristics
where
    T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
{
    let dims = tensor.shape();
    let total_points: usize = dims.iter().product();

    // Placeholder percentile table (percentile number -> value).
    let percentiles: HashMap<u8, f64> = [
        (25, -0.5),
        (50, 0.0),
        (75, 0.5),
        (95, 0.9),
        (99, 0.99),
    ]
    .iter()
    .cloned()
    .collect();

    let distribution_stats = DistributionStats {
        mean: 0.0,
        std_dev: 1.0,
        min: -1.0,
        max: 1.0,
        percentiles,
        skewness: 0.0,
        kurtosis: 3.0,
    };

    let data_type = std::any::type_name::<T>().to_string();
    let memory_footprint = total_points * std::mem::size_of::<T>();

    DataCharacteristics {
        total_points,
        data_type,
        shape: dims.to_vec(),
        distribution_stats,
        memory_footprint,
    }
}
/// Scores completeness as the fraction of non-missing data points.
///
/// The missing-value count is currently hard-coded to zero and the tensor
/// contents are not inspected, so the ratio is `1.0` for every input. An
/// empty tensor (`total_points == 0`) is also scored `1.0`.
///
/// A `MissingData` issue of `Medium` severity is recorded when the ratio is
/// below the configured completeness threshold. If no completeness threshold
/// is configured, no issue is recorded (the lookup no longer panics).
///
/// # Errors
///
/// Never fails at present; the `Result` return keeps the signature uniform
/// with the other per-dimension checks.
fn assess_completeness<T>(
    &self,
    _tensor: &crate::tensor::Tensor<T>,
    characteristics: &DataCharacteristics,
) -> RusTorchResult<QualityScore>
where
    T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
{
    // Placeholder: missing values are not counted yet.
    let nan_count: usize = 0;
    let total_points = characteristics.total_points;
    let completeness_ratio = if total_points > 0 {
        (total_points - nan_count) as f64 / total_points as f64
    } else {
        1.0
    };

    let mut metrics = HashMap::new();
    metrics.insert("completeness_ratio".to_string(), completeness_ratio);
    metrics.insert("missing_values".to_string(), nan_count as f64);

    let mut issues = Vec::new();
    // Look the threshold up once with `get`; indexing the map would panic on
    // a missing key.
    if let Some(&threshold) = self
        .thresholds
        .dimension_thresholds
        .get(&QualityDimension::Completeness)
    {
        if completeness_ratio < threshold {
            issues.push(QualityIssue {
                category: IssueCategory::MissingData,
                severity: IssueSeverity::Medium,
                description: format!(
                    "Completeness ratio {:.3} below threshold {:.3}",
                    completeness_ratio, threshold
                ),
                affected_range: None,
                remediation: Some("Consider imputation or data cleaning".to_string()),
                impact_score: 1.0 - completeness_ratio,
            });
        }
    }

    Ok(QualityScore {
        score: completeness_ratio,
        max_score: 1.0,
        metrics,
        issues,
        confidence: 0.95,
    })
}
/// Scores accuracy as `1.0` minus the fraction of out-of-range values.
///
/// The range-violation count is currently hard-coded to zero and the tensor
/// contents are not inspected. An empty tensor (`total_points == 0`) is
/// scored `1.0`, matching the completeness and uniqueness checks; without
/// that guard the score would be `0.0 / 0.0`, i.e. NaN, which would then
/// propagate into the overall score and the running statistics.
///
/// # Errors
///
/// Never fails at present; the `Result` return keeps the signature uniform
/// with the other per-dimension checks.
fn assess_accuracy<T>(
    &self,
    _tensor: &crate::tensor::Tensor<T>,
    characteristics: &DataCharacteristics,
) -> RusTorchResult<QualityScore>
where
    T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
{
    let stats = &characteristics.distribution_stats;
    let total_points = characteristics.total_points;

    // Placeholder: range violations are not counted yet.
    let range_violations: usize = 0;
    let accuracy_score = if total_points > 0 {
        1.0 - (range_violations as f64 / total_points as f64)
    } else {
        // Nothing to violate a range; avoid dividing zero by zero.
        1.0
    };

    let mut metrics = HashMap::new();
    metrics.insert("accuracy_score".to_string(), accuracy_score);
    metrics.insert("range_violations".to_string(), range_violations as f64);
    metrics.insert("value_range_width".to_string(), stats.max - stats.min);

    Ok(QualityScore {
        score: accuracy_score,
        max_score: 1.0,
        metrics,
        issues: Vec::new(),
        confidence: 0.9,
    })
}
/// Scores consistency.
///
/// Currently returns a fixed score of `0.95` with confidence `0.85`; neither
/// argument is inspected.
fn assess_consistency<T>(
    &self,
    _tensor: &crate::tensor::Tensor<T>,
    _characteristics: &DataCharacteristics,
) -> RusTorchResult<QualityScore>
where
    T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
{
    let fixed_score = 0.95;
    let metrics: HashMap<String, f64> =
        std::iter::once(("consistency_score".to_string(), fixed_score)).collect();

    let result = QualityScore {
        score: fixed_score,
        max_score: 1.0,
        metrics,
        issues: Vec::new(),
        confidence: 0.85,
    };
    Ok(result)
}
/// Scores validity: `1.0` when the recorded shape has at least one
/// dimension, `0.0` when the shape is empty.
///
/// The tensor contents are not inspected.
fn assess_validity<T>(
    &self,
    _tensor: &crate::tensor::Tensor<T>,
    characteristics: &DataCharacteristics,
) -> RusTorchResult<QualityScore>
where
    T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
{
    // The same 0/1 value serves as both the score and the recorded metric.
    let shape_flag = if characteristics.shape.is_empty() {
        0.0
    } else {
        1.0
    };

    let metrics: HashMap<String, f64> =
        std::iter::once(("valid_shape".to_string(), shape_flag)).collect();

    Ok(QualityScore {
        score: shape_flag,
        max_score: 1.0,
        metrics,
        issues: Vec::new(),
        confidence: 1.0,
    })
}
/// Scores uniqueness as `1.0` minus the fraction of duplicated points.
///
/// The duplicate count is currently hard-coded to zero and the tensor
/// contents are not inspected. An empty tensor is scored `1.0`.
fn assess_uniqueness<T>(
    &self,
    _tensor: &crate::tensor::Tensor<T>,
    characteristics: &DataCharacteristics,
) -> RusTorchResult<QualityScore>
where
    T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
{
    // Placeholder: duplicates are not counted yet.
    let duplicate_count = 0;
    let uniqueness_score = match characteristics.total_points {
        0 => 1.0,
        point_count => 1.0 - (duplicate_count as f64 / point_count as f64),
    };

    let metrics: HashMap<String, f64> = vec![
        ("uniqueness_score".to_string(), uniqueness_score),
        ("duplicate_count".to_string(), duplicate_count as f64),
    ]
    .into_iter()
    .collect();

    Ok(QualityScore {
        score: uniqueness_score,
        max_score: 1.0,
        metrics,
        issues: Vec::new(),
        confidence: 0.8,
    })
}
/// Scores timeliness.
///
/// Currently returns a fixed score of `0.9` with confidence `0.7`; the
/// argument is not inspected.
fn assess_timeliness(
    &self,
    _characteristics: &DataCharacteristics,
) -> RusTorchResult<QualityScore> {
    let fixed_score = 0.9;
    let metrics: HashMap<String, f64> =
        std::iter::once(("timeliness_score".to_string(), fixed_score)).collect();

    let result = QualityScore {
        score: fixed_score,
        max_score: 1.0,
        metrics,
        issues: Vec::new(),
        confidence: 0.7,
    };
    Ok(result)
}
/// Scores structural integrity: `1.0` when the recorded shape is non-empty
/// and there is at least one data point, `0.0` otherwise.
///
/// The tensor contents are not inspected.
fn assess_integrity<T>(
    &self,
    _tensor: &crate::tensor::Tensor<T>,
    characteristics: &DataCharacteristics,
) -> RusTorchResult<QualityScore>
where
    T: num_traits::Float + std::fmt::Debug + Clone + Send + Sync + 'static,
{
    let is_broken = characteristics.shape.is_empty() || characteristics.total_points == 0;
    // The same 0/1 value serves as both the score and the recorded metric.
    let integrity_flag = if is_broken { 0.0 } else { 1.0 };

    let metrics: HashMap<String, f64> =
        std::iter::once(("structural_integrity".to_string(), integrity_flag)).collect();

    Ok(QualityScore {
        score: integrity_flag,
        max_score: 1.0,
        metrics,
        issues: Vec::new(),
        confidence: 0.95,
    })
}
/// Combines per-dimension scores into one weighted average.
///
/// Each dimension present in `dimensions` contributes `score * weight`; the
/// sum is divided by the total weight of the dimensions that were present,
/// so a partially filled map is still normalised. Returns `0.0` when none of
/// the weighted dimensions is present.
///
/// The weights are walked in a fixed order instead of iterating the
/// `HashMap`, whose iteration order varies between runs. This keeps the
/// floating-point summation order, and therefore the result, reproducible,
/// and avoids building a weight map on every call.
fn calculate_overall_score(&self, dimensions: &HashMap<QualityDimension, QualityScore>) -> f64 {
    // Relative importance of each dimension; the weights sum to 1.0.
    const WEIGHTS: [(QualityDimension, f64); 7] = [
        (QualityDimension::Completeness, 0.2),
        (QualityDimension::Accuracy, 0.2),
        (QualityDimension::Consistency, 0.15),
        (QualityDimension::Validity, 0.15),
        (QualityDimension::Uniqueness, 0.1),
        (QualityDimension::Timeliness, 0.1),
        (QualityDimension::Integrity, 0.1),
    ];

    let mut weighted_sum = 0.0_f64;
    let mut total_weight = 0.0_f64;
    for (dimension, weight) in WEIGHTS.iter() {
        if let Some(quality) = dimensions.get(dimension) {
            weighted_sum += quality.score * *weight;
            total_weight += *weight;
        }
    }

    if total_weight > 0.0 {
        weighted_sum / total_weight
    } else {
        0.0
    }
}
/// Produces trend information for the current assessment.
///
/// Currently returns a fixed "stable" trend; neither the current score nor
/// the history is consulted.
fn analyze_trends(&self, _current_score: &f64) -> QualityTrend {
    let direction = TrendDirection::Stable;
    let prediction: Option<f64> = None;
    QualityTrend {
        direction,
        strength: 0.1,
        change_rate: 0.001,
        confidence: 0.7,
        prediction,
    }
}
/// Folds the overall score of `assessment` into the running statistics.
///
/// Updates the count, running mean, best and worst scores, the incrementally
/// maintained variance, and the stability measure (`1.0` minus the square
/// root of the variance).
fn update_aggregated_stats(&mut self, assessment: &DataQualityAssessment) {
    let stats = &mut self.aggregated_stats;
    let sample = assessment.overall_score;
    let previous_mean = stats.average_overall_score;

    stats.total_assessments += 1;
    let n = stats.total_assessments as f64;

    // Incremental mean update.
    let updated_mean = previous_mean + (sample - previous_mean) / n;
    stats.average_overall_score = updated_mean;

    // The first sample seeds both extremes; later samples widen them.
    let (best, worst) = if stats.total_assessments == 1 {
        (sample, sample)
    } else {
        (stats.best_score.max(sample), stats.worst_score.min(sample))
    };
    stats.best_score = best;
    stats.worst_score = worst;

    // Incremental variance update using the means before and after this sample.
    let squared_term = (sample - previous_mean) * (sample - updated_mean);
    stats.score_variance = (stats.score_variance * (n - 1.0) + squared_term) / n;
    stats.stability_measure = 1.0 - stats.score_variance.sqrt();
}
/// Returns the retained assessments, oldest first.
pub fn get_history(&self) -> &VecDeque<DataQualityAssessment> {
    let retained = &self.history;
    retained
}
/// Returns the running statistics over all recorded assessments.
pub fn get_aggregated_stats(&self) -> &AggregatedQualityStats {
    let running = &self.aggregated_stats;
    running
}
}
impl DataQualityAssessment {
    /// Maps the overall score to a letter grade.
    ///
    /// The grade is the first band whose lower bound the score reaches, from
    /// `"A+"` (>= 0.95) down to `"D"` (>= 0.4); anything lower, including a
    /// NaN score, is `"F"`.
    pub fn quality_grade(&self) -> &str {
        // Lower bound of each band, highest first.
        let bands: [(f64, &'static str); 10] = [
            (0.95, "A+"),
            (0.9, "A"),
            (0.85, "A-"),
            (0.8, "B+"),
            (0.75, "B"),
            (0.7, "B-"),
            (0.65, "C+"),
            (0.6, "C"),
            (0.5, "C-"),
            (0.4, "D"),
        ];

        let score = self.overall_score;
        for &(lower_bound, grade) in bands.iter() {
            if score >= lower_bound {
                return grade;
            }
        }
        "F"
    }
}
impl fmt::Display for DataQualityAssessment {
    /// Writes a multi-line summary: header, overall score with letter grade,
    /// total data points, number of assessed dimensions, and the timestamp in
    /// `Debug` form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let grade = self.quality_grade();
        let point_count = self.characteristics.total_points;
        let dimension_count = self.dimensions.len();
        // The trailing `\` continuations strip the following line's leading
        // whitespace, so the indentation below is not part of the output.
        write!(
            f,
            "📊 Data Quality Assessment\n\
             ==========================\n\
             Overall Score: {:.3} (Grade: {})\n\
             Total Points: {}\n\
             Dimensions Assessed: {}\n\
             Timestamp: {:?}",
            self.overall_score, grade, point_count, dimension_count, self.timestamp
        )
    }
}