diffai_core/
lib.rs

1use serde::Serialize;
2use serde_json::Value;
3use regex::Regex;
4use std::collections::HashMap;
5use std::path::Path;
6// use ini::Ini;
7use anyhow::{Result, anyhow};
8use quick_xml::de::from_str;
9use csv::ReaderBuilder;
10// AI/ML dependencies
11use candle_core::Device;
12use safetensors::SafeTensors;
13
14#[derive(Debug, PartialEq, Serialize)]
15pub enum DiffResult {
16    Added(String, Value),
17    Removed(String, Value),
18    Modified(String, Value, Value),
19    TypeChanged(String, Value, Value),
20    // AI/ML specific diff results
21    TensorShapeChanged(String, Vec<usize>, Vec<usize>),
22    TensorStatsChanged(String, TensorStats, TensorStats),
23    ModelArchitectureChanged(String, ModelInfo, ModelInfo),
24}
25
26#[derive(Debug, Clone, PartialEq, Serialize)]
27pub struct TensorStats {
28    pub mean: f64,
29    pub std: f64,
30    pub min: f64,
31    pub max: f64,
32    pub shape: Vec<usize>,
33    pub dtype: String,
34    pub total_params: usize,
35}
36
37#[derive(Debug, Clone, PartialEq, Serialize)]
38pub struct ModelInfo {
39    pub total_parameters: usize,
40    pub layer_count: usize,
41    pub layer_types: HashMap<String, usize>,
42    pub model_size_bytes: usize,
43}
44
45pub fn diff(
46    v1: &Value,
47    v2: &Value,
48    ignore_keys_regex: Option<&Regex>,
49    epsilon: Option<f64>,
50    array_id_key: Option<&str>,
51) -> Vec<DiffResult> {
52    let mut results = Vec::new();
53
54    // Handle root level type or value change first
55    if !values_are_equal(v1, v2, epsilon) {
56        let type_match = match (v1, v2) {
57            (Value::Null, Value::Null) => true,
58            (Value::Bool(_), Value::Bool(_)) => true,
59            (Value::Number(_), Value::Number(_)) => true,
60            (Value::String(_), Value::String(_)) => true,
61            (Value::Array(_), Value::Array(_)) => true,
62            (Value::Object(_), Value::Object(_)) => true,
63            _ => false,
64        };
65
66        if !type_match {
67            results.push(DiffResult::TypeChanged("".to_string(), v1.clone(), v2.clone()));
68            return results; // If root type changed, no further diffing needed
69        } else if v1.is_object() && v2.is_object() {
70            diff_objects("", v1.as_object().unwrap(), v2.as_object().unwrap(), &mut results, ignore_keys_regex, epsilon, array_id_key);
71        } else if v1.is_array() && v2.is_array() {
72            diff_arrays("", v1.as_array().unwrap(), v2.as_array().unwrap(), &mut results, ignore_keys_regex, epsilon, array_id_key);
73        } else {
74            // Simple value modification at root
75            results.push(DiffResult::Modified("".to_string(), v1.clone(), v2.clone()));
76            return results;
77        }
78    }
79
80    results
81}
82
83fn diff_recursive(
84    path: &str,
85    v1: &Value,
86    v2: &Value,
87    results: &mut Vec<DiffResult>,
88    ignore_keys_regex: Option<&Regex>,
89    epsilon: Option<f64>,
90    array_id_key: Option<&str>,
91) {
92    match (v1, v2) {
93        (Value::Object(map1), Value::Object(map2)) => {
94            diff_objects(path, map1, map2, results, ignore_keys_regex, epsilon, array_id_key);
95        }
96        (Value::Array(arr1), Value::Array(arr2)) => {
97            diff_arrays(path, arr1, arr2, results, ignore_keys_regex, epsilon, array_id_key);
98        }
99        _ => { /* Should not happen if called correctly from diff_objects/diff_arrays */ }
100    }
101}
102
103fn diff_objects(
104    path: &str,
105    map1: &serde_json::Map<String, Value>,
106    map2: &serde_json::Map<String, Value>,
107    results: &mut Vec<DiffResult>,
108    ignore_keys_regex: Option<&Regex>,
109    epsilon: Option<f64>,
110    array_id_key: Option<&str>,
111) {
112    // Check for modified or removed keys
113    for (key, value1) in map1 {
114        let current_path = if path.is_empty() { key.clone() } else { format!("{}.{}", path, key) };
115        if let Some(regex) = ignore_keys_regex {
116            if regex.is_match(key) {
117                continue;
118            }
119        }
120        match map2.get(key) {
121            Some(value2) => {
122                // Recurse for nested objects/arrays
123                if value1.is_object() && value2.is_object() || value1.is_array() && value2.is_array() {
124                    diff_recursive(&current_path, value1, value2, results, ignore_keys_regex, epsilon, array_id_key);
125                } else if !values_are_equal(value1, value2, epsilon) {
126                    let type_match = match (value1, value2) {
127                        (Value::Null, Value::Null) => true,
128                        (Value::Bool(_), Value::Bool(_)) => true,
129                        (Value::Number(_), Value::Number(_)) => true,
130                        (Value::String(_), Value::String(_)) => true,
131                        (Value::Array(_), Value::Array(_)) => true,
132                        (Value::Object(_), Value::Object(_)) => true,
133                        _ => false,
134                    };
135
136                    if !type_match {
137                        results.push(DiffResult::TypeChanged(current_path, value1.clone(), value2.clone()));
138                    } else {
139                        results.push(DiffResult::Modified(current_path, value1.clone(), value2.clone()));
140                    }
141                }
142            }
143            None => {
144                results.push(DiffResult::Removed(current_path, value1.clone()));
145            }
146        }
147    }
148
149    // Check for added keys
150    for (key, value2) in map2 {
151        if !map1.contains_key(key) {
152            let current_path = if path.is_empty() { key.clone() } else { format!("{}.{}", path, key) };
153            results.push(DiffResult::Added(current_path, value2.clone()));
154        }
155    }
156}
157
158fn diff_arrays(
159    path: &str,
160    arr1: &Vec<Value>,
161    arr2: &Vec<Value>,
162    results: &mut Vec<DiffResult>,
163    ignore_keys_regex: Option<&Regex>,
164    epsilon: Option<f64>,
165    array_id_key: Option<&str>,
166) {
167    if let Some(id_key) = array_id_key {
168        let mut map1: HashMap<Value, &Value> = HashMap::new();
169        let mut no_id_elements1: Vec<(usize, &Value)> = Vec::new();
170        for (i, val) in arr1.iter().enumerate() {
171            if let Some(id_val) = val.get(id_key) {
172                map1.insert(id_val.clone(), val);
173            } else {
174                no_id_elements1.push((i, val));
175            }
176        }
177
178        let mut map2: HashMap<Value, &Value> = HashMap::new();
179        let mut no_id_elements2: Vec<(usize, &Value)> = Vec::new();
180        for (i, val) in arr2.iter().enumerate() {
181            if let Some(id_val) = val.get(id_key) {
182                map2.insert(id_val.clone(), val);
183            } else {
184                no_id_elements2.push((i, val));
185            }
186        }
187
188        // Check for modified or removed elements
189        for (id_val, val1) in &map1 {
190            let current_path = format!("{}[{}={}]", path, id_key, id_val);
191            match map2.get(&id_val) {
192                Some(val2) => {
193                    // Recurse for nested objects/arrays
194                    if val1.is_object() && val2.is_object() || val1.is_array() && val2.is_array() {
195                        diff_recursive(&current_path, val1, val2, results, ignore_keys_regex, epsilon, array_id_key);
196                    } else if !values_are_equal(val1, val2, epsilon) {
197                        let type_match = match (val1, val2) {
198                            (Value::Null, Value::Null) => true,
199                            (Value::Bool(_), Value::Bool(_)) => true,
200                            (Value::Number(_), Value::Number(_)) => true,
201                            (Value::String(_), Value::String(_)) => true,
202                            (Value::Array(_), Value::Array(_)) => true,
203                            (Value::Object(_), Value::Object(_)) => true,
204                            _ => false,
205                        };
206
207                        if !type_match {
208                            results.push(DiffResult::TypeChanged(current_path, (*val1).clone(), (*val2).clone()));
209                        } else {
210                            results.push(DiffResult::Modified(current_path, (*val1).clone(), (*val2).clone()));
211                        }
212                    }
213                }
214                None => {
215                    results.push(DiffResult::Removed(current_path, (*val1).clone()));
216                }
217            }
218        }
219
220        // Check for added elements with ID
221        for (id_val, val2) in map2 {
222            if !map1.contains_key(&id_val) {
223                let current_path = format!("{}[{}={}]", path, id_key, id_val);
224                results.push(DiffResult::Added(current_path, val2.clone()));
225            }
226        }
227
228        // Handle elements without ID using index-based comparison
229        let max_len = no_id_elements1.len().max(no_id_elements2.len());
230        for i in 0..max_len {
231            match (no_id_elements1.get(i), no_id_elements2.get(i)) {
232                (Some((idx1, val1)), Some((_idx2, val2))) => {
233                    let current_path = format!("{}[{}]", path, idx1);
234                    if val1.is_object() && val2.is_object() || val1.is_array() && val2.is_array() {
235                        diff_recursive(&current_path, val1, val2, results, ignore_keys_regex, epsilon, array_id_key);
236                    } else if !values_are_equal(val1, val2, epsilon) {
237                        let type_match = match (val1, val2) {
238                            (Value::Null, Value::Null) => true,
239                            (Value::Bool(_), Value::Bool(_)) => true,
240                            (Value::Number(_), Value::Number(_)) => true,
241                            (Value::String(_), Value::String(_)) => true,
242                            (Value::Array(_), Value::Array(_)) => true,
243                            (Value::Object(_), Value::Object(_)) => true,
244                            _ => false,
245                        };
246
247                        if !type_match {
248                            results.push(DiffResult::TypeChanged(current_path, (*val1).clone(), (*val2).clone()));
249                        } else {
250                            results.push(DiffResult::Modified(current_path, (*val1).clone(), (*val2).clone()));
251                        }
252                    }
253                }
254                (Some((idx1, val1)), None) => {
255                    let current_path = format!("{}[{}]", path, idx1);
256                    results.push(DiffResult::Removed(current_path, (*val1).clone()));
257                }
258                (None, Some((idx2, val2))) => {
259                    let current_path = format!("{}[{}]", path, idx2);
260                    results.push(DiffResult::Added(current_path, (*val2).clone()));
261                }
262                (None, None) => break,
263            }
264        }
265    } else {
266        // Fallback to index-based comparison if no id_key is provided
267        let max_len = arr1.len().max(arr2.len());
268        for i in 0..max_len {
269            let current_path = format!("{}[{}]", path, i);
270            match (arr1.get(i), arr2.get(i)) {
271                (Some(val1), Some(val2)) => {
272                    // Recurse for nested objects/arrays within arrays
273                    if val1.is_object() && val2.is_object() || val1.is_array() && val2.is_array() {
274                        diff_recursive(&current_path, val1, val2, results, ignore_keys_regex, epsilon, array_id_key);
275                    } else if !values_are_equal(val1, val2, epsilon) {
276                        let type_match = match (val1, val2) {
277                            (Value::Null, Value::Null) => true,
278                            (Value::Bool(_), Value::Bool(_)) => true,
279                            (Value::Number(_), Value::Number(_)) => true,
280                            (Value::String(_), Value::String(_)) => true,
281                            (Value::Array(_), Value::Array(_)) => true,
282                            (Value::Object(_), Value::Object(_)) => true,
283                            _ => false,
284                        };
285
286                        if !type_match {
287                            results.push(DiffResult::TypeChanged(current_path, val1.clone(), val2.clone()));
288                        } else {
289                            results.push(DiffResult::Modified(current_path, val1.clone(), val2.clone()));
290                        }
291                    }
292                }
293                (Some(val1), None) => {
294                    results.push(DiffResult::Removed(current_path, val1.clone()));
295                }
296                (None, Some(val2)) => {
297                    results.push(DiffResult::Added(current_path, val2.clone()));
298                }
299                (None, None) => { /* Should not happen */ }
300            }
301        }
302    }
303}
304
305fn values_are_equal(v1: &Value, v2: &Value, epsilon: Option<f64>) -> bool {
306    if let (Some(e), Value::Number(n1), Value::Number(n2)) = (epsilon, v1, v2) {
307        if let (Some(f1), Some(f2)) = (n1.as_f64(), n2.as_f64()) {
308            return (f1 - f2).abs() < e;
309        }
310    }
311    v1 == v2
312}
313
314pub fn value_type_name(value: &Value) -> &str {
315    match value {
316        Value::Null => "Null",
317        Value::Bool(_) => "Boolean",
318        Value::Number(_) => "Number",
319        Value::String(_) => "String",
320        Value::Array(_) => "Array",
321        Value::Object(_) => "Object",
322    }
323}
324
325pub fn parse_ini(content: &str) -> Result<Value> {
326    use configparser::ini::Ini;
327    
328    let mut ini = Ini::new();
329    ini.read(content.to_string())
330        .map_err(|e| anyhow!("Failed to parse INI: {}", e))?;
331    
332    let mut root_map = serde_json::Map::new();
333
334    for section_name in ini.sections() {
335        let mut section_map = serde_json::Map::new();
336        
337        if let Some(section) = ini.get_map_ref().get(&section_name) {
338            for (key, value) in section {
339                if let Some(v) = value {
340                    section_map.insert(key.clone(), Value::String(v.clone()));
341                } else {
342                    section_map.insert(key.clone(), Value::Null);
343                }
344            }
345        }
346        
347        root_map.insert(section_name, Value::Object(section_map));
348    }
349
350    Ok(Value::Object(root_map))
351}
352
353pub fn parse_xml(content: &str) -> Result<Value> {
354    let value: Value = from_str(content)?;
355    Ok(value)
356}
357
358pub fn parse_csv(content: &str) -> Result<Value> {
359    let mut reader = ReaderBuilder::new().from_reader(content.as_bytes());
360    let mut records = Vec::new();
361
362    let headers = reader.headers()?.clone();
363    let has_headers = !headers.is_empty();
364
365    for result in reader.into_records() {
366        let record = result?;
367        if has_headers {
368            let mut obj = serde_json::Map::new();
369            for (i, header) in headers.iter().enumerate() {
370                if let Some(value) = record.get(i) {
371                    obj.insert(header.to_string(), Value::String(value.to_string()));
372                }
373            }
374            records.push(Value::Object(obj));
375        } else {
376            let mut arr = Vec::new();
377            for field in record.iter() {
378                arr.push(Value::String(field.to_string()));
379            }
380            records.push(Value::Array(arr));
381        }
382    }
383    Ok(Value::Array(records))
384}
385
386// ============================================================================
387// AI/ML File Format Support
388// ============================================================================
389
390/// Parse a PyTorch model file (.pth, .pt) and extract tensor information
391pub fn parse_pytorch_model(file_path: &Path) -> Result<HashMap<String, TensorStats>> {
392    let _device = Device::Cpu;
393    let mut model_tensors = HashMap::new();
394    
395    // Try to load as safetensors first (more efficient)
396    if let Ok(data) = std::fs::read(file_path) {
397        if let Ok(safetensors) = SafeTensors::deserialize(&data) {
398            for (name, tensor_view) in safetensors.tensors() {
399                let shape: Vec<usize> = tensor_view.shape().to_vec();
400                let dtype = match tensor_view.dtype() {
401                    safetensors::Dtype::F32 => "f32".to_string(),
402                    safetensors::Dtype::F64 => "f64".to_string(),
403                    safetensors::Dtype::I32 => "i32".to_string(),
404                    safetensors::Dtype::I64 => "i64".to_string(),
405                    _ => "unknown".to_string(),
406                };
407                
408                // Calculate basic statistics
409                let total_params = shape.iter().product();
410                let stats = TensorStats {
411                    mean: 0.0, // TODO: Calculate actual mean from tensor data
412                    std: 0.0,  // TODO: Calculate actual std from tensor data  
413                    min: 0.0,  // TODO: Calculate actual min from tensor data
414                    max: 0.0,  // TODO: Calculate actual max from tensor data
415                    shape,
416                    dtype,
417                    total_params,
418                };
419                
420                model_tensors.insert(name.to_string(), stats);
421            }
422            return Ok(model_tensors);
423        }
424    }
425    
426    // If safetensors parsing fails, try PyTorch pickle format
427    // Note: This is a simplified implementation
428    // In practice, you'd need to use candle's PyTorch loading capabilities
429    Err(anyhow!("Failed to parse PyTorch model file: {}", file_path.display()))
430}
431
432/// Parse a Safetensors file (.safetensors) and extract tensor information  
433pub fn parse_safetensors_model(file_path: &Path) -> Result<HashMap<String, TensorStats>> {
434    let data = std::fs::read(file_path)?;
435    let safetensors = SafeTensors::deserialize(&data)?;
436    let mut model_tensors = HashMap::new();
437    
438    for (name, tensor_view) in safetensors.tensors() {
439        let shape: Vec<usize> = tensor_view.shape().to_vec();
440        let dtype = match tensor_view.dtype() {
441            safetensors::Dtype::F32 => "f32".to_string(),
442            safetensors::Dtype::F64 => "f64".to_string(),
443            safetensors::Dtype::I32 => "i32".to_string(),
444            safetensors::Dtype::I64 => "i64".to_string(),
445            _ => "unknown".to_string(),
446        };
447        
448        let total_params = shape.iter().product();
449        
450        // Extract raw data and calculate statistics
451        let data_slice = tensor_view.data();
452        let (mean, std, min, max) = match tensor_view.dtype() {
453            safetensors::Dtype::F32 => {
454                let float_data: &[f32] = bytemuck::cast_slice(data_slice);
455                calculate_f32_stats(float_data)
456            },
457            safetensors::Dtype::F64 => {
458                let float_data: &[f64] = bytemuck::cast_slice(data_slice);
459                calculate_f64_stats(float_data)
460            },
461            _ => (0.0, 0.0, 0.0, 0.0), // Skip non-float types for now
462        };
463        
464        let stats = TensorStats {
465            mean,
466            std,
467            min,
468            max,
469            shape,
470            dtype,
471            total_params,
472        };
473        
474        model_tensors.insert(name.to_string(), stats);
475    }
476    
477    Ok(model_tensors)
478}
479
480/// Compare two PyTorch/Safetensors models and return differences
481pub fn diff_ml_models(
482    model1_path: &Path,
483    model2_path: &Path,
484    epsilon: Option<f64>,
485) -> Result<Vec<DiffResult>> {
486    let model1_tensors = if model1_path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
487        parse_safetensors_model(model1_path)?
488    } else {
489        parse_pytorch_model(model1_path)?
490    };
491    
492    let model2_tensors = if model2_path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
493        parse_safetensors_model(model2_path)?
494    } else {
495        parse_pytorch_model(model2_path)?
496    };
497    
498    let mut results = Vec::new();
499    let eps = epsilon.unwrap_or(1e-6);
500    
501    // Check for added tensors
502    for (name, stats) in &model2_tensors {
503        if !model1_tensors.contains_key(name) {
504            results.push(DiffResult::Added(
505                format!("tensor.{}", name),
506                serde_json::to_value(stats)?,
507            ));
508        }
509    }
510    
511    // Check for removed tensors
512    for (name, stats) in &model1_tensors {
513        if !model2_tensors.contains_key(name) {
514            results.push(DiffResult::Removed(
515                format!("tensor.{}", name),
516                serde_json::to_value(stats)?,
517            ));
518        }
519    }
520    
521    // Check for modified tensors
522    for (name, stats1) in &model1_tensors {
523        if let Some(stats2) = model2_tensors.get(name) {
524            // Check shape changes
525            if stats1.shape != stats2.shape {
526                results.push(DiffResult::TensorShapeChanged(
527                    format!("tensor.{}", name),
528                    stats1.shape.clone(),
529                    stats2.shape.clone(),
530                ));
531            }
532            
533            // Check statistical changes (with epsilon tolerance)
534            if (stats1.mean - stats2.mean).abs() > eps ||
535               (stats1.std - stats2.std).abs() > eps ||
536               (stats1.min - stats2.min).abs() > eps ||
537               (stats1.max - stats2.max).abs() > eps {
538                results.push(DiffResult::TensorStatsChanged(
539                    format!("tensor.{}", name),
540                    stats1.clone(),
541                    stats2.clone(),
542                ));
543            }
544        }
545    }
546    
547    Ok(results)
548}
549
550/// Enhanced ML model comparison with additional analysis features
551pub fn diff_ml_models_enhanced(
552    model1_path: &Path,
553    model2_path: &Path,
554    epsilon: Option<f64>,
555    show_layer_impact: bool,
556    quantization_analysis: bool,
557    detailed_stats: bool,
558) -> Result<Vec<DiffResult>> {
559    let model1_tensors = if model1_path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
560        parse_safetensors_model(model1_path)?
561    } else {
562        parse_pytorch_model(model1_path)?
563    };
564    
565    let model2_tensors = if model2_path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
566        parse_safetensors_model(model2_path)?
567    } else {
568        parse_pytorch_model(model2_path)?
569    };
570    
571    let mut results = Vec::new();
572    let eps = epsilon.unwrap_or(1e-6);
573    
574    // Enhanced model-level analysis
575    if detailed_stats {
576        let model1_info = calculate_model_info(&model1_tensors);
577        let model2_info = calculate_model_info(&model2_tensors);
578        
579        if model1_info.total_parameters != model2_info.total_parameters ||
580           model1_info.layer_count != model2_info.layer_count {
581            results.push(DiffResult::ModelArchitectureChanged(
582                "model".to_string(),
583                model1_info,
584                model2_info,
585            ));
586        }
587    }
588    
589    // Check for added tensors
590    for (name, stats) in &model2_tensors {
591        if !model1_tensors.contains_key(name) {
592            results.push(DiffResult::Added(
593                format!("tensor.{}", name),
594                serde_json::to_value(stats)?,
595            ));
596        }
597    }
598    
599    // Check for removed tensors
600    for (name, stats) in &model1_tensors {
601        if !model2_tensors.contains_key(name) {
602            results.push(DiffResult::Removed(
603                format!("tensor.{}", name),
604                serde_json::to_value(stats)?,
605            ));
606        }
607    }
608    
609    // Check for modified tensors with enhanced analysis
610    for (name, stats1) in &model1_tensors {
611        if let Some(stats2) = model2_tensors.get(name) {
612            // Check shape changes
613            if stats1.shape != stats2.shape {
614                results.push(DiffResult::TensorShapeChanged(
615                    format!("tensor.{}", name),
616                    stats1.shape.clone(),
617                    stats2.shape.clone(),
618                ));
619            }
620            
621            // Enhanced statistical changes analysis
622            let mean_change = (stats1.mean - stats2.mean).abs();
623            let std_change = (stats1.std - stats2.std).abs();
624            let min_change = (stats1.min - stats2.min).abs();
625            let max_change = (stats1.max - stats2.max).abs();
626            
627            let stats_changed = mean_change > eps || std_change > eps || 
628                               min_change > eps || max_change > eps;
629            
630            if stats_changed {
631                if show_layer_impact {
632                    // Add layer impact information to the key
633                    let impact_score = calculate_layer_impact(stats1, stats2);
634                    let enhanced_key = format!("tensor.{} [impact: {:.4}]", name, impact_score);
635                    results.push(DiffResult::TensorStatsChanged(
636                        enhanced_key,
637                        stats1.clone(),
638                        stats2.clone(),
639                    ));
640                } else {
641                    results.push(DiffResult::TensorStatsChanged(
642                        format!("tensor.{}", name),
643                        stats1.clone(),
644                        stats2.clone(),
645                    ));
646                }
647            }
648            
649            // Quantization analysis
650            if quantization_analysis {
651                let quantization_info = analyze_quantization_impact(stats1, stats2);
652                if !quantization_info.is_empty() {
653                    results.push(DiffResult::Modified(
654                        format!("quantization.{}", name),
655                        serde_json::to_value(&quantization_info)?,
656                        serde_json::Value::Null,
657                    ));
658                }
659            }
660        }
661    }
662    
663    Ok(results)
664}
665
666/// Calculate layer impact score based on parameter changes
667fn calculate_layer_impact(stats1: &TensorStats, stats2: &TensorStats) -> f64 {
668    let mean_change = (stats1.mean - stats2.mean).abs();
669    let std_change = (stats1.std - stats2.std).abs();
670    let param_ratio = stats1.total_params as f64;
671    
672    // Weighted impact score considering parameter count and statistical changes
673    (mean_change + std_change) * param_ratio.log10().max(1.0)
674}
675
676/// Analyze quantization impact between two tensor versions
677fn analyze_quantization_impact(stats1: &TensorStats, stats2: &TensorStats) -> HashMap<String, f64> {
678    let mut analysis = HashMap::new();
679    
680    // Check if precision loss indicates quantization
681    let precision_loss = (stats1.max - stats1.min) / (stats2.max - stats2.min);
682    if precision_loss > 1.5 {
683        analysis.insert("precision_loss_ratio".to_string(), precision_loss);
684    }
685    
686    // Check for typical quantization patterns
687    let range_compression = ((stats1.max - stats1.min) - (stats2.max - stats2.min)).abs();
688    if range_compression > 0.1 {
689        analysis.insert("range_compression".to_string(), range_compression);
690    }
691    
692    analysis
693}
694
695/// Calculate overall model information from tensors
696fn calculate_model_info(tensors: &HashMap<String, TensorStats>) -> ModelInfo {
697    let total_parameters: usize = tensors.values().map(|stats| stats.total_params).sum();
698    let layer_count = tensors.len();
699    
700    let mut layer_types = HashMap::new();
701    for name in tensors.keys() {
702        let layer_type = extract_layer_type(name);
703        *layer_types.entry(layer_type).or_insert(0) += 1;
704    }
705    
706    // Estimate model size in bytes (assuming f32 = 4 bytes per parameter)
707    let model_size_bytes = total_parameters * 4;
708    
709    ModelInfo {
710        total_parameters,
711        layer_count,
712        layer_types,
713        model_size_bytes,
714    }
715}
716
/// Classifies a tensor by substrings of its name.
///
/// Rules are checked in order and the first group with any matching substring
/// wins, so e.g. "conv_bn" is "conv". Names matching nothing are "other".
fn extract_layer_type(tensor_name: &str) -> String {
    const RULES: [(&[&str], &str); 5] = [
        (&["conv", "Conv"], "conv"),
        (&["linear", "Linear", "fc"], "linear"),
        (&["norm", "Norm", "bn"], "norm"),
        (&["attention", "attn"], "attention"),
        (&["embedding", "embed"], "embedding"),
    ];

    RULES
        .iter()
        .find(|(needles, _)| needles.iter().any(|needle| tensor_name.contains(*needle)))
        .map_or("other", |&(_, label)| label)
        .to_string()
}
733
734// Helper functions for statistical calculations
/// Computes `(mean, population std, min, max)` of `data`, widened to `f64`.
///
/// Returns all zeros for an empty slice. Min and max ignore NaN elements and
/// are NaN only if every element is NaN. Mean and std follow IEEE arithmetic,
/// so any NaN element makes them NaN. The previous
/// `partial_cmp(..).unwrap()` panicked as soon as a tensor contained a NaN.
fn calculate_f32_stats(data: &[f32]) -> (f64, f64, f64, f64) {
    if data.is_empty() {
        return (0.0, 0.0, 0.0, 0.0);
    }

    let count = data.len() as f64;
    let mean = data.iter().map(|&x| f64::from(x)).sum::<f64>() / count;

    let variance = data
        .iter()
        .map(|&x| {
            let diff = f64::from(x) - mean;
            diff * diff
        })
        .sum::<f64>()
        / count;

    // f64::min/max return the non-NaN operand, so seeding with NaN yields the
    // extreme over non-NaN elements (or NaN if there are none).
    let min = data.iter().map(|&x| f64::from(x)).fold(f64::NAN, f64::min);
    let max = data.iter().map(|&x| f64::from(x)).fold(f64::NAN, f64::max);

    (mean, variance.sqrt(), min, max)
}
756
/// Computes `(mean, population std, min, max)` of `data`.
///
/// Returns all zeros for an empty slice. Min and max ignore NaN elements and
/// are NaN only if every element is NaN. Mean and std follow IEEE arithmetic,
/// so any NaN element makes them NaN. The previous
/// `partial_cmp(..).unwrap()` panicked as soon as a tensor contained a NaN.
fn calculate_f64_stats(data: &[f64]) -> (f64, f64, f64, f64) {
    if data.is_empty() {
        return (0.0, 0.0, 0.0, 0.0);
    }

    let count = data.len() as f64;
    let mean = data.iter().sum::<f64>() / count;

    let variance = data
        .iter()
        .map(|&x| {
            let diff = x - mean;
            diff * diff
        })
        .sum::<f64>()
        / count;

    // f64::min/max return the non-NaN operand, so seeding with NaN yields the
    // extreme over non-NaN elements (or NaN if there are none).
    let min = data.iter().copied().fold(f64::NAN, f64::min);
    let max = data.iter().copied().fold(f64::NAN, f64::max);

    (mean, variance.sqrt(), min, max)
}
778
#[cfg(test)]
mod tests {
    use super::*;

    // Builds an f32 TensorStats fixture; total_params is derived from the shape.
    fn f32_stats(mean: f64, std: f64, min: f64, max: f64, shape: Vec<usize>) -> TensorStats {
        let total_params = shape.iter().product();
        TensorStats {
            mean,
            std,
            min,
            max,
            shape,
            dtype: "f32".to_string(),
            total_params,
        }
    }

    #[test]
    fn test_tensor_stats_creation() {
        let stats = f32_stats(0.5, 1.0, -2.0, 3.0, vec![10, 20]);

        assert_eq!(stats.mean, 0.5);
        assert_eq!(stats.total_params, 200);
        assert_eq!(stats.shape, vec![10, 20]);
    }

    #[test]
    fn test_diff_result_variants() {
        let before = f32_stats(0.0, 1.0, -2.0, 2.0, vec![128, 64]);
        let after = f32_stats(0.1, 1.1, -1.9, 2.1, vec![128, 64]);
        let diff = DiffResult::TensorStatsChanged("linear1.weight".to_string(), before, after);

        if let DiffResult::TensorStatsChanged(name, s1, s2) = diff {
            assert_eq!(name, "linear1.weight");
            assert_eq!(s1.mean, 0.0);
            assert_eq!(s2.mean, 0.1);
            assert_eq!(s1.total_params, 8192);
        } else {
            panic!("Expected TensorStatsChanged variant");
        }
    }

    #[test]
    fn test_tensor_shape_changed() {
        let diff = DiffResult::TensorShapeChanged(
            "linear2.weight".to_string(),
            vec![256, 128],
            vec![512, 128],
        );

        if let DiffResult::TensorShapeChanged(name, old_shape, new_shape) = diff {
            assert_eq!(name, "linear2.weight");
            assert_eq!(old_shape, vec![256, 128]);
            assert_eq!(new_shape, vec![512, 128]);
        } else {
            panic!("Expected TensorShapeChanged variant");
        }
    }

    #[test]
    fn test_calculate_f32_stats() {
        let (mean, std, min, max) = calculate_f32_stats(&[1.0, 2.0, 3.0, 4.0, 5.0]);

        assert_eq!((mean, min, max), (3.0, 1.0, 5.0));
        // Population std of [1, 2, 3, 4, 5] is sqrt(2).
        assert!((std - 2.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_calculate_f64_stats() {
        let (mean, std, min, max) = calculate_f64_stats(&[0.0, 1.0, 2.0]);

        assert_eq!((mean, min, max), (1.0, 0.0, 2.0));
        assert!((std - (2.0_f64 / 3.0).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_error_handling_nonexistent_files() {
        // Missing inputs must surface as an error rather than a panic.
        let outcome = diff_ml_models(
            Path::new("nonexistent1.safetensors"),
            Path::new("nonexistent2.safetensors"),
            None,
        );
        assert!(outcome.is_err());
    }
}