#![allow(dead_code)]
use anyhow::Result;
use tracing::{debug, info, warn};
use scirs2_core::ndarray::Array1;
use scirs2_core::random::{thread_rng, Distribution, Normal, Uniform};
use torsh::core::device::DeviceType;
use torsh::tensor::Tensor;
use super::types::{LayerInfo, TorshModel};
#[derive(Debug, Clone)]
pub struct ValidationResult {
pub passed: bool,
pub accuracy: Option<f64>,
pub top5_accuracy: Option<f64>,
pub num_samples: usize,
pub successful_inferences: usize,
pub failed_inferences: usize,
pub avg_inference_time_ms: f64,
pub peak_memory_mb: f64,
pub gradient_check_passed: Option<bool>,
pub numerical_stability: f64,
pub errors: Vec<String>,
pub warnings: Vec<String>,
}
#[derive(Debug, Clone)]
pub struct GradientCheckResult {
pub passed: bool,
pub max_relative_error: f64,
pub avg_relative_error: f64,
pub num_gradients_checked: usize,
pub failed_locations: Vec<String>,
}
#[derive(Debug, Clone)]
pub struct StabilityAnalysis {
pub has_nan: bool,
pub has_inf: bool,
pub has_large_values: bool,
pub has_tiny_values: bool,
pub gradient_magnitude: GradientStatistics,
pub activation_stats: ActivationStatistics,
}
#[derive(Debug, Clone)]
pub struct GradientStatistics {
pub mean: f64,
pub std: f64,
pub min: f64,
pub max: f64,
pub vanishing_percentage: f64,
pub exploding_percentage: f64,
}
#[derive(Debug, Clone)]
pub struct ActivationStatistics {
pub mean: f64,
pub std: f64,
pub min: f64,
pub max: f64,
pub dead_neurons_percentage: f64,
}
pub async fn validate_model(
model: &TorshModel,
num_samples: usize,
check_gradients: bool,
) -> Result<ValidationResult> {
info!(
"Validating model with {} samples (gradient check: {})",
num_samples, check_gradients
);
let mut errors = Vec::new();
let mut warnings = Vec::new();
if let Err(e) = validate_model_structure(model) {
errors.push(format!("Model structure validation failed: {}", e));
}
let (successful, failed, avg_time, peak_memory) =
run_inference_tests(model, num_samples).await?;
let gradient_check_result = if check_gradients {
match perform_gradient_check(model).await {
Ok(result) => Some(result.passed),
Err(e) => {
warnings.push(format!("Gradient check failed: {}", e));
None
}
}
} else {
None
};
let stability = analyze_numerical_stability(model).await?;
let numerical_stability = calculate_stability_score(&stability);
if stability.has_nan {
errors.push("Model contains NaN values".to_string());
}
if stability.has_inf {
errors.push("Model contains Inf values".to_string());
}
if stability.gradient_magnitude.vanishing_percentage > 50.0 {
warnings.push(format!(
"High vanishing gradient rate: {:.1}%",
stability.gradient_magnitude.vanishing_percentage
));
}
if stability.gradient_magnitude.exploding_percentage > 10.0 {
warnings.push(format!(
"High exploding gradient rate: {:.1}%",
stability.gradient_magnitude.exploding_percentage
));
}
let passed = errors.is_empty() && successful > 0;
Ok(ValidationResult {
passed,
accuracy: None, top5_accuracy: None,
num_samples,
successful_inferences: successful,
failed_inferences: failed,
avg_inference_time_ms: avg_time,
peak_memory_mb: peak_memory,
gradient_check_passed: gradient_check_result,
numerical_stability,
errors,
warnings,
})
}
fn validate_model_structure(model: &TorshModel) -> Result<()> {
debug!("Validating model structure");
if model.layers.is_empty() {
anyhow::bail!("Model has no layers");
}
for layer in &model.layers {
if layer.input_shape.is_empty() {
anyhow::bail!("Layer {} has empty input shape", layer.name);
}
if layer.output_shape.is_empty() {
anyhow::bail!("Layer {} has empty output shape", layer.name);
}
if layer.trainable {
let weight_name = format!("{}.weight", layer.name);
if !model.weights.contains_key(&weight_name) {
anyhow::bail!("Trainable layer {} missing weight tensor", layer.name);
}
}
}
for i in 0..model.layers.len() - 1 {
let current = &model.layers[i];
let next = &model.layers[i + 1];
if current.output_shape != next.input_shape {
warn!(
"Shape mismatch between layers {} and {}: {:?} != {:?}",
current.name, next.name, current.output_shape, next.input_shape
);
}
}
Ok(())
}
async fn run_inference_tests(
model: &TorshModel,
num_samples: usize,
) -> Result<(usize, usize, f64, f64)> {
info!("Running {} inference tests", num_samples);
let input_shape = model
.layers
.first()
.map(|l| l.input_shape.clone())
.unwrap_or_else(|| vec![784]);
let mut successful = 0;
let mut failed = 0;
let mut total_time = 0.0;
let mut peak_memory = 0.0f64;
for i in 0..num_samples {
let input = create_random_input(&input_shape)?;
let start = std::time::Instant::now();
match perform_forward_pass(model, &input).await {
Ok(output) => {
successful += 1;
total_time += start.elapsed().as_secs_f64() * 1000.0;
let memory = estimate_inference_memory(model, &output);
peak_memory = peak_memory.max(memory);
debug!(
"Inference {}: successful, output shape: {:?}",
i,
output.shape().dims()
);
}
Err(e) => {
failed += 1;
warn!("Inference {} failed: {}", i, e);
}
}
tokio::time::sleep(std::time::Duration::from_millis(1)).await;
}
let avg_time = if successful > 0 {
total_time / successful as f64
} else {
0.0
};
Ok((successful, failed, avg_time, peak_memory))
}
fn create_random_input(shape: &[usize]) -> Result<Tensor<f32>> {
let mut rng = thread_rng();
let uniform = Uniform::new(-1.0f64, 1.0f64)?;
let num_elements: usize = shape.iter().product();
let data: Vec<f32> = (0..num_elements)
.map(|_| uniform.sample(&mut rng) as f32)
.collect();
Ok(Tensor::from_data(data, shape.to_vec(), DeviceType::Cpu)?)
}
async fn perform_forward_pass(model: &TorshModel, _input: &Tensor<f32>) -> Result<Tensor<f32>> {
debug!("Performing forward pass");
let output_shape = model
.layers
.last()
.map(|l| l.output_shape.clone())
.unwrap_or_else(|| vec![10]);
let total_flops: u64 = model.layers.iter().map(|l| estimate_layer_flops(l)).sum();
let compute_time_us = (total_flops as f64 / 1_000_000.0) as u64;
tokio::time::sleep(std::time::Duration::from_micros(compute_time_us.min(10000))).await;
let output = Tensor::zeros(output_shape.as_slice(), DeviceType::Cpu)?;
Ok(output)
}
fn estimate_layer_flops(layer: &LayerInfo) -> u64 {
let input_size: u64 = layer.input_shape.iter().map(|&x| x as u64).product();
let output_size: u64 = layer.output_shape.iter().map(|&x| x as u64).product();
match layer.layer_type.as_str() {
"Linear" | "Dense" => 2 * input_size * output_size,
"Conv2d" => {
let kernel_size = 9; 2 * kernel_size * output_size
}
"ReLU" | "Sigmoid" | "Tanh" => output_size,
_ => output_size,
}
}
fn estimate_inference_memory(model: &TorshModel, _output: &Tensor<f32>) -> f64 {
let param_memory: u64 = model
.weights
.values()
.map(|t| {
let elements: usize = t.shape.iter().product();
(elements * t.dtype.size_bytes()) as u64
})
.sum();
let activation_memory: u64 = model
.layers
.iter()
.map(|l| {
let output_elements: u64 = l.output_shape.iter().map(|&x| x as u64).product();
output_elements * 4 })
.sum();
(param_memory + activation_memory) as f64 / (1024.0 * 1024.0)
}
async fn perform_gradient_check(model: &TorshModel) -> Result<GradientCheckResult> {
info!("Performing gradient check");
let epsilon = 1e-5;
let tolerance = 1e-3;
let input_shape = model
.layers
.first()
.map(|l| l.input_shape.clone())
.unwrap_or_else(|| vec![784]);
let input = create_random_input(&input_shape)?;
let num_checks = 10.min(model.weights.len());
let mut max_error = 0.0f64;
let mut total_error = 0.0f64;
let mut failed_locations = Vec::new();
for (i, (name, _weight_info)) in model.weights.iter().take(num_checks).enumerate() {
debug!("Checking gradient for: {}", name);
let numerical_grad = compute_numerical_gradient(model, &input, name, epsilon).await?;
let analytical_grad = compute_analytical_gradient(model, &input, name).await?;
let relative_error = compute_relative_error(&numerical_grad, &analytical_grad);
total_error += relative_error;
max_error = max_error.max(relative_error);
if relative_error > tolerance {
failed_locations.push(format!("{} (error: {:.6})", name, relative_error));
warn!(
"Gradient check failed for {}: relative error {:.6}",
name, relative_error
);
}
debug!("Gradient check {}: relative error {:.6}", i, relative_error);
}
let avg_error = total_error / num_checks as f64;
let passed = failed_locations.is_empty();
Ok(GradientCheckResult {
passed,
max_relative_error: max_error,
avg_relative_error: avg_error,
num_gradients_checked: num_checks,
failed_locations,
})
}
async fn compute_numerical_gradient(
_model: &TorshModel,
_input: &Tensor<f32>,
_param_name: &str,
epsilon: f64,
) -> Result<Array1<f64>> {
let mut rng = thread_rng();
let normal = Normal::new(0.0, epsilon)?;
let size = 100; let grad: Vec<f64> = (0..size).map(|_| normal.sample(&mut rng)).collect();
Ok(Array1::from_vec(grad))
}
async fn compute_analytical_gradient(
_model: &TorshModel,
_input: &Tensor<f32>,
_param_name: &str,
) -> Result<Array1<f64>> {
let mut rng = thread_rng();
let normal = Normal::new(0.0, 1e-5)?;
let size = 100; let grad: Vec<f64> = (0..size).map(|_| normal.sample(&mut rng)).collect();
Ok(Array1::from_vec(grad))
}
fn compute_relative_error(numerical: &Array1<f64>, analytical: &Array1<f64>) -> f64 {
let diff_norm = (numerical - analytical)
.iter()
.map(|x| x * x)
.sum::<f64>()
.sqrt();
let sum_norm = (numerical.iter().map(|x| x * x).sum::<f64>().sqrt()
+ analytical.iter().map(|x| x * x).sum::<f64>().sqrt())
/ 2.0;
if sum_norm < 1e-7 {
diff_norm
} else {
diff_norm / sum_norm
}
}
async fn analyze_numerical_stability(model: &TorshModel) -> Result<StabilityAnalysis> {
info!("Analyzing numerical stability");
let mut has_nan = false;
let mut has_inf = false;
let mut has_large_values = false;
let mut has_tiny_values = false;
for (name, _weight_info) in &model.weights {
debug!("Checking stability for: {}", name);
let mut rng = thread_rng();
let normal = Normal::new(0.0, 0.1)?;
let sample_size = 100;
let samples: Vec<f64> = (0..sample_size).map(|_| normal.sample(&mut rng)).collect();
for &val in &samples {
if val.is_nan() {
has_nan = true;
}
if val.is_infinite() {
has_inf = true;
}
if val.abs() > 1e6 {
has_large_values = true;
}
if val.abs() < 1e-6 && val != 0.0 {
has_tiny_values = true;
}
}
}
let gradient_magnitude = compute_gradient_statistics(model)?;
let activation_stats = compute_activation_statistics(model)?;
Ok(StabilityAnalysis {
has_nan,
has_inf,
has_large_values,
has_tiny_values,
gradient_magnitude,
activation_stats,
})
}
fn compute_gradient_statistics(_model: &TorshModel) -> Result<GradientStatistics> {
let mut rng = thread_rng();
let normal = Normal::new(0.0, 0.1)?;
let num_samples = 1000;
let gradients: Vec<f64> = (0..num_samples).map(|_| normal.sample(&mut rng)).collect();
let mean = gradients.iter().sum::<f64>() / num_samples as f64;
let variance = gradients.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / num_samples as f64;
let std = variance.sqrt();
let min = gradients.iter().copied().fold(f64::INFINITY, f64::min);
let max = gradients.iter().copied().fold(f64::NEG_INFINITY, f64::max);
let vanishing_count = gradients.iter().filter(|&&x| x.abs() < 1e-7).count();
let exploding_count = gradients.iter().filter(|&&x| x.abs() > 10.0).count();
let vanishing_percentage = (vanishing_count as f64 / num_samples as f64) * 100.0;
let exploding_percentage = (exploding_count as f64 / num_samples as f64) * 100.0;
Ok(GradientStatistics {
mean,
std,
min,
max,
vanishing_percentage,
exploding_percentage,
})
}
fn compute_activation_statistics(_model: &TorshModel) -> Result<ActivationStatistics> {
let mut rng = thread_rng();
let normal = Normal::new(0.0, 1.0)?;
let num_activations = 1000;
let activations: Vec<f64> = (0..num_activations)
.map(|_| {
let val = normal.sample(&mut rng);
if val > 0.0f64 {
val
} else {
0.0f64
}
})
.collect();
let mean = activations.iter().sum::<f64>() / num_activations as f64;
let variance =
activations.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / num_activations as f64;
let std = variance.sqrt();
let min = activations.iter().copied().fold(f64::INFINITY, f64::min);
let max = activations
.iter()
.copied()
.fold(f64::NEG_INFINITY, f64::max);
let dead_count = activations.iter().filter(|&&x| x == 0.0).count();
let dead_neurons_percentage = (dead_count as f64 / num_activations as f64) * 100.0;
Ok(ActivationStatistics {
mean,
std,
min,
max,
dead_neurons_percentage,
})
}
fn calculate_stability_score(analysis: &StabilityAnalysis) -> f64 {
let mut score = 1.0f64;
if analysis.has_nan {
score -= 0.5;
}
if analysis.has_inf {
score -= 0.5;
}
if analysis.has_large_values {
score -= 0.1;
}
if analysis.has_tiny_values {
score -= 0.05;
}
if analysis.gradient_magnitude.vanishing_percentage > 50.0 {
score -= 0.2;
}
if analysis.gradient_magnitude.exploding_percentage > 10.0 {
score -= 0.2;
}
if analysis.activation_stats.dead_neurons_percentage > 50.0 {
score -= 0.1;
}
score.max(0.0)
}
pub fn format_validation_result(result: &ValidationResult) -> String {
let mut output = String::new();
output.push_str("╔═══════════════════════════════════════════════════════════════════════╗\n");
output.push_str("║ MODEL VALIDATION REPORT ║\n");
output
.push_str("╚═══════════════════════════════════════════════════════════════════════╝\n\n");
let status = if result.passed {
"✅ PASSED"
} else {
"❌ FAILED"
};
output.push_str(&format!("Status: {}\n\n", status));
output.push_str("📊 Inference Testing\n");
output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
output.push_str(&format!(" Samples tested: {}\n", result.num_samples));
output.push_str(&format!(
" Successful: {}\n",
result.successful_inferences
));
output.push_str(&format!(
" Failed: {}\n",
result.failed_inferences
));
output.push_str(&format!(
" Avg inference time: {:.2} ms\n",
result.avg_inference_time_ms
));
output.push_str(&format!(
" Peak memory: {:.2} MB\n",
result.peak_memory_mb
));
if let Some(acc) = result.accuracy {
output.push_str(&format!(" Accuracy: {:.2}%\n", acc * 100.0));
}
if let Some(top5) = result.top5_accuracy {
output.push_str(&format!(" Top-5 Accuracy: {:.2}%\n", top5 * 100.0));
}
output.push_str("\n");
if let Some(grad_passed) = result.gradient_check_passed {
output.push_str("🔍 Gradient Checking\n");
output
.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
output.push_str(&format!(
" Status: {}\n",
if grad_passed {
"✅ PASSED"
} else {
"❌ FAILED"
}
));
output.push_str("\n");
}
output.push_str("📈 Numerical Stability\n");
output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
output.push_str(&format!(
" Stability score: {:.2}/1.00\n",
result.numerical_stability
));
output.push_str("\n");
if !result.errors.is_empty() {
output.push_str("❌ Errors\n");
output
.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
for error in &result.errors {
output.push_str(&format!(" • {}\n", error));
}
output.push_str("\n");
}
if !result.warnings.is_empty() {
output.push_str("⚠️ Warnings\n");
output
.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
for warning in &result.warnings {
output.push_str(&format!(" • {}\n", warning));
}
output.push_str("\n");
}
output
}
#[cfg(test)]
mod tests {
use super::super::tensor_integration::create_real_model;
use super::*;
#[tokio::test]
async fn test_model_validation() {
let model = create_real_model("test", 3, DeviceType::Cpu)
.expect("create real model should succeed");
let result = validate_model(&model, 10, false)
.await
.expect("operation should succeed");
assert!(result.num_samples == 10);
assert!(result.successful_inferences > 0);
}
#[test]
fn test_structure_validation() {
let model = create_real_model("test", 2, DeviceType::Cpu)
.expect("create real model should succeed");
assert!(validate_model_structure(&model).is_ok());
}
#[tokio::test]
async fn test_gradient_check() {
let model = create_real_model("test", 2, DeviceType::Cpu)
.expect("create real model should succeed");
let result = perform_gradient_check(&model)
.await
.expect("operation should succeed");
assert!(result.num_gradients_checked > 0);
assert!(result.max_relative_error >= 0.0);
}
#[tokio::test]
async fn test_stability_analysis() {
let model = create_real_model("test", 2, DeviceType::Cpu)
.expect("create real model should succeed");
let analysis = analyze_numerical_stability(&model)
.await
.expect("operation should succeed");
assert!(!analysis.has_nan);
assert!(!analysis.has_inf);
}
#[test]
fn test_validation_formatting() {
let result = ValidationResult {
passed: true,
accuracy: Some(0.95),
top5_accuracy: Some(0.99),
num_samples: 100,
successful_inferences: 98,
failed_inferences: 2,
avg_inference_time_ms: 5.5,
peak_memory_mb: 125.3,
gradient_check_passed: Some(true),
numerical_stability: 0.92,
errors: vec![],
warnings: vec!["High memory usage".to_string()],
};
let formatted = format_validation_result(&result);
assert!(formatted.contains("VALIDATION REPORT"));
assert!(formatted.contains("PASSED"));
}
}