1use serde::Serialize;
2use serde_json::Value;
3use regex::Regex;
4use std::collections::HashMap;
5use std::path::Path;
6use anyhow::{Result, anyhow};
8use quick_xml::de::from_str;
9use csv::ReaderBuilder;
10use candle_core::Device;
12use safetensors::SafeTensors;
13
/// A single difference reported by the diff functions in this file.
///
/// The first field of every variant is the path (for structured data) or the
/// name (for model tensors) of the item the entry refers to.
#[derive(Debug, PartialEq, Serialize)]
pub enum DiffResult {
    /// The item exists only in the second input: `(path, value)`.
    Added(String, Value),
    /// The item exists only in the first input: `(path, value)`.
    Removed(String, Value),
    /// As produced by `diff`: both values have the same JSON type but differ,
    /// `(path, first value, second value)`.
    Modified(String, Value, Value),
    /// The two values have different JSON types: `(path, first value, second value)`.
    TypeChanged(String, Value, Value),
    /// A tensor present in both models has a different shape:
    /// `(name, first shape, second shape)`.
    TensorShapeChanged(String, Vec<usize>, Vec<usize>),
    /// A tensor present in both models has different summary statistics:
    /// `(name, first stats, second stats)`.
    TensorStatsChanged(String, TensorStats, TensorStats),
    /// The aggregate model summaries differ: `(label, first info, second info)`.
    ModelArchitectureChanged(String, ModelInfo, ModelInfo),
}
25
/// Summary of one tensor read from a model file.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct TensorStats {
    /// Arithmetic mean of the tensor's elements.
    pub mean: f64,
    /// Standard deviation of the tensor's elements.
    pub std: f64,
    /// Smallest element.
    pub min: f64,
    /// Largest element.
    pub max: f64,
    /// Dimensions of the tensor.
    pub shape: Vec<usize>,
    /// Element type label; the parsers in this file emit `"f32"`, `"f64"`,
    /// `"i32"`, `"i64"` or `"unknown"`.
    pub dtype: String,
    /// Number of elements; the parsers in this file set it to the product of `shape`.
    pub total_params: usize,
}
36
/// Aggregate summary of a whole model, built by `calculate_model_info`.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ModelInfo {
    /// Sum of `total_params` over all tensors.
    pub total_parameters: usize,
    /// Number of tensors in the model.
    pub layer_count: usize,
    /// Number of tensors per category, keyed by the label returned from
    /// `extract_layer_type`.
    pub layer_types: HashMap<String, usize>,
    /// Estimated in-memory size of the parameters, in bytes.
    pub model_size_bytes: usize,
}
44
45pub fn diff(
46 v1: &Value,
47 v2: &Value,
48 ignore_keys_regex: Option<&Regex>,
49 epsilon: Option<f64>,
50 array_id_key: Option<&str>,
51) -> Vec<DiffResult> {
52 let mut results = Vec::new();
53
54 if !values_are_equal(v1, v2, epsilon) {
56 let type_match = match (v1, v2) {
57 (Value::Null, Value::Null) => true,
58 (Value::Bool(_), Value::Bool(_)) => true,
59 (Value::Number(_), Value::Number(_)) => true,
60 (Value::String(_), Value::String(_)) => true,
61 (Value::Array(_), Value::Array(_)) => true,
62 (Value::Object(_), Value::Object(_)) => true,
63 _ => false,
64 };
65
66 if !type_match {
67 results.push(DiffResult::TypeChanged("".to_string(), v1.clone(), v2.clone()));
68 return results; } else if v1.is_object() && v2.is_object() {
70 diff_objects("", v1.as_object().unwrap(), v2.as_object().unwrap(), &mut results, ignore_keys_regex, epsilon, array_id_key);
71 } else if v1.is_array() && v2.is_array() {
72 diff_arrays("", v1.as_array().unwrap(), v2.as_array().unwrap(), &mut results, ignore_keys_regex, epsilon, array_id_key);
73 } else {
74 results.push(DiffResult::Modified("".to_string(), v1.clone(), v2.clone()));
76 return results;
77 }
78 }
79
80 results
81}
82
83fn diff_recursive(
84 path: &str,
85 v1: &Value,
86 v2: &Value,
87 results: &mut Vec<DiffResult>,
88 ignore_keys_regex: Option<&Regex>,
89 epsilon: Option<f64>,
90 array_id_key: Option<&str>,
91) {
92 match (v1, v2) {
93 (Value::Object(map1), Value::Object(map2)) => {
94 diff_objects(path, map1, map2, results, ignore_keys_regex, epsilon, array_id_key);
95 }
96 (Value::Array(arr1), Value::Array(arr2)) => {
97 diff_arrays(path, arr1, arr2, results, ignore_keys_regex, epsilon, array_id_key);
98 }
99 _ => { }
100 }
101}
102
103fn diff_objects(
104 path: &str,
105 map1: &serde_json::Map<String, Value>,
106 map2: &serde_json::Map<String, Value>,
107 results: &mut Vec<DiffResult>,
108 ignore_keys_regex: Option<&Regex>,
109 epsilon: Option<f64>,
110 array_id_key: Option<&str>,
111) {
112 for (key, value1) in map1 {
114 let current_path = if path.is_empty() { key.clone() } else { format!("{}.{}", path, key) };
115 if let Some(regex) = ignore_keys_regex {
116 if regex.is_match(key) {
117 continue;
118 }
119 }
120 match map2.get(key) {
121 Some(value2) => {
122 if value1.is_object() && value2.is_object() || value1.is_array() && value2.is_array() {
124 diff_recursive(¤t_path, value1, value2, results, ignore_keys_regex, epsilon, array_id_key);
125 } else if !values_are_equal(value1, value2, epsilon) {
126 let type_match = match (value1, value2) {
127 (Value::Null, Value::Null) => true,
128 (Value::Bool(_), Value::Bool(_)) => true,
129 (Value::Number(_), Value::Number(_)) => true,
130 (Value::String(_), Value::String(_)) => true,
131 (Value::Array(_), Value::Array(_)) => true,
132 (Value::Object(_), Value::Object(_)) => true,
133 _ => false,
134 };
135
136 if !type_match {
137 results.push(DiffResult::TypeChanged(current_path, value1.clone(), value2.clone()));
138 } else {
139 results.push(DiffResult::Modified(current_path, value1.clone(), value2.clone()));
140 }
141 }
142 }
143 None => {
144 results.push(DiffResult::Removed(current_path, value1.clone()));
145 }
146 }
147 }
148
149 for (key, value2) in map2 {
151 if !map1.contains_key(key) {
152 let current_path = if path.is_empty() { key.clone() } else { format!("{}.{}", path, key) };
153 results.push(DiffResult::Added(current_path, value2.clone()));
154 }
155 }
156}
157
158fn diff_arrays(
159 path: &str,
160 arr1: &Vec<Value>,
161 arr2: &Vec<Value>,
162 results: &mut Vec<DiffResult>,
163 ignore_keys_regex: Option<&Regex>,
164 epsilon: Option<f64>,
165 array_id_key: Option<&str>,
166) {
167 if let Some(id_key) = array_id_key {
168 let mut map1: HashMap<Value, &Value> = HashMap::new();
169 let mut no_id_elements1: Vec<(usize, &Value)> = Vec::new();
170 for (i, val) in arr1.iter().enumerate() {
171 if let Some(id_val) = val.get(id_key) {
172 map1.insert(id_val.clone(), val);
173 } else {
174 no_id_elements1.push((i, val));
175 }
176 }
177
178 let mut map2: HashMap<Value, &Value> = HashMap::new();
179 let mut no_id_elements2: Vec<(usize, &Value)> = Vec::new();
180 for (i, val) in arr2.iter().enumerate() {
181 if let Some(id_val) = val.get(id_key) {
182 map2.insert(id_val.clone(), val);
183 } else {
184 no_id_elements2.push((i, val));
185 }
186 }
187
188 for (id_val, val1) in &map1 {
190 let current_path = format!("{}[{}={}]", path, id_key, id_val);
191 match map2.get(&id_val) {
192 Some(val2) => {
193 if val1.is_object() && val2.is_object() || val1.is_array() && val2.is_array() {
195 diff_recursive(¤t_path, val1, val2, results, ignore_keys_regex, epsilon, array_id_key);
196 } else if !values_are_equal(val1, val2, epsilon) {
197 let type_match = match (val1, val2) {
198 (Value::Null, Value::Null) => true,
199 (Value::Bool(_), Value::Bool(_)) => true,
200 (Value::Number(_), Value::Number(_)) => true,
201 (Value::String(_), Value::String(_)) => true,
202 (Value::Array(_), Value::Array(_)) => true,
203 (Value::Object(_), Value::Object(_)) => true,
204 _ => false,
205 };
206
207 if !type_match {
208 results.push(DiffResult::TypeChanged(current_path, (*val1).clone(), (*val2).clone()));
209 } else {
210 results.push(DiffResult::Modified(current_path, (*val1).clone(), (*val2).clone()));
211 }
212 }
213 }
214 None => {
215 results.push(DiffResult::Removed(current_path, (*val1).clone()));
216 }
217 }
218 }
219
220 for (id_val, val2) in map2 {
222 if !map1.contains_key(&id_val) {
223 let current_path = format!("{}[{}={}]", path, id_key, id_val);
224 results.push(DiffResult::Added(current_path, val2.clone()));
225 }
226 }
227
228 let max_len = no_id_elements1.len().max(no_id_elements2.len());
230 for i in 0..max_len {
231 match (no_id_elements1.get(i), no_id_elements2.get(i)) {
232 (Some((idx1, val1)), Some((_idx2, val2))) => {
233 let current_path = format!("{}[{}]", path, idx1);
234 if val1.is_object() && val2.is_object() || val1.is_array() && val2.is_array() {
235 diff_recursive(¤t_path, val1, val2, results, ignore_keys_regex, epsilon, array_id_key);
236 } else if !values_are_equal(val1, val2, epsilon) {
237 let type_match = match (val1, val2) {
238 (Value::Null, Value::Null) => true,
239 (Value::Bool(_), Value::Bool(_)) => true,
240 (Value::Number(_), Value::Number(_)) => true,
241 (Value::String(_), Value::String(_)) => true,
242 (Value::Array(_), Value::Array(_)) => true,
243 (Value::Object(_), Value::Object(_)) => true,
244 _ => false,
245 };
246
247 if !type_match {
248 results.push(DiffResult::TypeChanged(current_path, (*val1).clone(), (*val2).clone()));
249 } else {
250 results.push(DiffResult::Modified(current_path, (*val1).clone(), (*val2).clone()));
251 }
252 }
253 }
254 (Some((idx1, val1)), None) => {
255 let current_path = format!("{}[{}]", path, idx1);
256 results.push(DiffResult::Removed(current_path, (*val1).clone()));
257 }
258 (None, Some((idx2, val2))) => {
259 let current_path = format!("{}[{}]", path, idx2);
260 results.push(DiffResult::Added(current_path, (*val2).clone()));
261 }
262 (None, None) => break,
263 }
264 }
265 } else {
266 let max_len = arr1.len().max(arr2.len());
268 for i in 0..max_len {
269 let current_path = format!("{}[{}]", path, i);
270 match (arr1.get(i), arr2.get(i)) {
271 (Some(val1), Some(val2)) => {
272 if val1.is_object() && val2.is_object() || val1.is_array() && val2.is_array() {
274 diff_recursive(¤t_path, val1, val2, results, ignore_keys_regex, epsilon, array_id_key);
275 } else if !values_are_equal(val1, val2, epsilon) {
276 let type_match = match (val1, val2) {
277 (Value::Null, Value::Null) => true,
278 (Value::Bool(_), Value::Bool(_)) => true,
279 (Value::Number(_), Value::Number(_)) => true,
280 (Value::String(_), Value::String(_)) => true,
281 (Value::Array(_), Value::Array(_)) => true,
282 (Value::Object(_), Value::Object(_)) => true,
283 _ => false,
284 };
285
286 if !type_match {
287 results.push(DiffResult::TypeChanged(current_path, val1.clone(), val2.clone()));
288 } else {
289 results.push(DiffResult::Modified(current_path, val1.clone(), val2.clone()));
290 }
291 }
292 }
293 (Some(val1), None) => {
294 results.push(DiffResult::Removed(current_path, val1.clone()));
295 }
296 (None, Some(val2)) => {
297 results.push(DiffResult::Added(current_path, val2.clone()));
298 }
299 (None, None) => { }
300 }
301 }
302 }
303}
304
305fn values_are_equal(v1: &Value, v2: &Value, epsilon: Option<f64>) -> bool {
306 if let (Some(e), Value::Number(n1), Value::Number(n2)) = (epsilon, v1, v2) {
307 if let (Some(f1), Some(f2)) = (n1.as_f64(), n2.as_f64()) {
308 return (f1 - f2).abs() < e;
309 }
310 }
311 v1 == v2
312}
313
314pub fn value_type_name(value: &Value) -> &str {
315 match value {
316 Value::Null => "Null",
317 Value::Bool(_) => "Boolean",
318 Value::Number(_) => "Number",
319 Value::String(_) => "String",
320 Value::Array(_) => "Array",
321 Value::Object(_) => "Object",
322 }
323}
324
325pub fn parse_ini(content: &str) -> Result<Value> {
326 use configparser::ini::Ini;
327
328 let mut ini = Ini::new();
329 ini.read(content.to_string())
330 .map_err(|e| anyhow!("Failed to parse INI: {}", e))?;
331
332 let mut root_map = serde_json::Map::new();
333
334 for section_name in ini.sections() {
335 let mut section_map = serde_json::Map::new();
336
337 if let Some(section) = ini.get_map_ref().get(§ion_name) {
338 for (key, value) in section {
339 if let Some(v) = value {
340 section_map.insert(key.clone(), Value::String(v.clone()));
341 } else {
342 section_map.insert(key.clone(), Value::Null);
343 }
344 }
345 }
346
347 root_map.insert(section_name, Value::Object(section_map));
348 }
349
350 Ok(Value::Object(root_map))
351}
352
353pub fn parse_xml(content: &str) -> Result<Value> {
354 let value: Value = from_str(content)?;
355 Ok(value)
356}
357
358pub fn parse_csv(content: &str) -> Result<Value> {
359 let mut reader = ReaderBuilder::new().from_reader(content.as_bytes());
360 let mut records = Vec::new();
361
362 let headers = reader.headers()?.clone();
363 let has_headers = !headers.is_empty();
364
365 for result in reader.into_records() {
366 let record = result?;
367 if has_headers {
368 let mut obj = serde_json::Map::new();
369 for (i, header) in headers.iter().enumerate() {
370 if let Some(value) = record.get(i) {
371 obj.insert(header.to_string(), Value::String(value.to_string()));
372 }
373 }
374 records.push(Value::Object(obj));
375 } else {
376 let mut arr = Vec::new();
377 for field in record.iter() {
378 arr.push(Value::String(field.to_string()));
379 }
380 records.push(Value::Array(arr));
381 }
382 }
383 Ok(Value::Array(records))
384}
385
386pub fn parse_pytorch_model(file_path: &Path) -> Result<HashMap<String, TensorStats>> {
392 let _device = Device::Cpu;
393 let mut model_tensors = HashMap::new();
394
395 if let Ok(data) = std::fs::read(file_path) {
397 if let Ok(safetensors) = SafeTensors::deserialize(&data) {
398 for (name, tensor_view) in safetensors.tensors() {
399 let shape: Vec<usize> = tensor_view.shape().to_vec();
400 let dtype = match tensor_view.dtype() {
401 safetensors::Dtype::F32 => "f32".to_string(),
402 safetensors::Dtype::F64 => "f64".to_string(),
403 safetensors::Dtype::I32 => "i32".to_string(),
404 safetensors::Dtype::I64 => "i64".to_string(),
405 _ => "unknown".to_string(),
406 };
407
408 let total_params = shape.iter().product();
410 let stats = TensorStats {
411 mean: 0.0, std: 0.0, min: 0.0, max: 0.0, shape,
416 dtype,
417 total_params,
418 };
419
420 model_tensors.insert(name.to_string(), stats);
421 }
422 return Ok(model_tensors);
423 }
424 }
425
426 Err(anyhow!("Failed to parse PyTorch model file: {}", file_path.display()))
430}
431
432pub fn parse_safetensors_model(file_path: &Path) -> Result<HashMap<String, TensorStats>> {
434 let data = std::fs::read(file_path)?;
435 let safetensors = SafeTensors::deserialize(&data)?;
436 let mut model_tensors = HashMap::new();
437
438 for (name, tensor_view) in safetensors.tensors() {
439 let shape: Vec<usize> = tensor_view.shape().to_vec();
440 let dtype = match tensor_view.dtype() {
441 safetensors::Dtype::F32 => "f32".to_string(),
442 safetensors::Dtype::F64 => "f64".to_string(),
443 safetensors::Dtype::I32 => "i32".to_string(),
444 safetensors::Dtype::I64 => "i64".to_string(),
445 _ => "unknown".to_string(),
446 };
447
448 let total_params = shape.iter().product();
449
450 let data_slice = tensor_view.data();
452 let (mean, std, min, max) = match tensor_view.dtype() {
453 safetensors::Dtype::F32 => {
454 let float_data: &[f32] = bytemuck::cast_slice(data_slice);
455 calculate_f32_stats(float_data)
456 },
457 safetensors::Dtype::F64 => {
458 let float_data: &[f64] = bytemuck::cast_slice(data_slice);
459 calculate_f64_stats(float_data)
460 },
461 _ => (0.0, 0.0, 0.0, 0.0), };
463
464 let stats = TensorStats {
465 mean,
466 std,
467 min,
468 max,
469 shape,
470 dtype,
471 total_params,
472 };
473
474 model_tensors.insert(name.to_string(), stats);
475 }
476
477 Ok(model_tensors)
478}
479
480pub fn diff_ml_models(
482 model1_path: &Path,
483 model2_path: &Path,
484 epsilon: Option<f64>,
485) -> Result<Vec<DiffResult>> {
486 let model1_tensors = if model1_path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
487 parse_safetensors_model(model1_path)?
488 } else {
489 parse_pytorch_model(model1_path)?
490 };
491
492 let model2_tensors = if model2_path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
493 parse_safetensors_model(model2_path)?
494 } else {
495 parse_pytorch_model(model2_path)?
496 };
497
498 let mut results = Vec::new();
499 let eps = epsilon.unwrap_or(1e-6);
500
501 for (name, stats) in &model2_tensors {
503 if !model1_tensors.contains_key(name) {
504 results.push(DiffResult::Added(
505 format!("tensor.{}", name),
506 serde_json::to_value(stats)?,
507 ));
508 }
509 }
510
511 for (name, stats) in &model1_tensors {
513 if !model2_tensors.contains_key(name) {
514 results.push(DiffResult::Removed(
515 format!("tensor.{}", name),
516 serde_json::to_value(stats)?,
517 ));
518 }
519 }
520
521 for (name, stats1) in &model1_tensors {
523 if let Some(stats2) = model2_tensors.get(name) {
524 if stats1.shape != stats2.shape {
526 results.push(DiffResult::TensorShapeChanged(
527 format!("tensor.{}", name),
528 stats1.shape.clone(),
529 stats2.shape.clone(),
530 ));
531 }
532
533 if (stats1.mean - stats2.mean).abs() > eps ||
535 (stats1.std - stats2.std).abs() > eps ||
536 (stats1.min - stats2.min).abs() > eps ||
537 (stats1.max - stats2.max).abs() > eps {
538 results.push(DiffResult::TensorStatsChanged(
539 format!("tensor.{}", name),
540 stats1.clone(),
541 stats2.clone(),
542 ));
543 }
544 }
545 }
546
547 Ok(results)
548}
549
550pub fn diff_ml_models_enhanced(
552 model1_path: &Path,
553 model2_path: &Path,
554 epsilon: Option<f64>,
555 show_layer_impact: bool,
556 quantization_analysis: bool,
557 detailed_stats: bool,
558) -> Result<Vec<DiffResult>> {
559 let model1_tensors = if model1_path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
560 parse_safetensors_model(model1_path)?
561 } else {
562 parse_pytorch_model(model1_path)?
563 };
564
565 let model2_tensors = if model2_path.extension().and_then(|s| s.to_str()) == Some("safetensors") {
566 parse_safetensors_model(model2_path)?
567 } else {
568 parse_pytorch_model(model2_path)?
569 };
570
571 let mut results = Vec::new();
572 let eps = epsilon.unwrap_or(1e-6);
573
574 if detailed_stats {
576 let model1_info = calculate_model_info(&model1_tensors);
577 let model2_info = calculate_model_info(&model2_tensors);
578
579 if model1_info.total_parameters != model2_info.total_parameters ||
580 model1_info.layer_count != model2_info.layer_count {
581 results.push(DiffResult::ModelArchitectureChanged(
582 "model".to_string(),
583 model1_info,
584 model2_info,
585 ));
586 }
587 }
588
589 for (name, stats) in &model2_tensors {
591 if !model1_tensors.contains_key(name) {
592 results.push(DiffResult::Added(
593 format!("tensor.{}", name),
594 serde_json::to_value(stats)?,
595 ));
596 }
597 }
598
599 for (name, stats) in &model1_tensors {
601 if !model2_tensors.contains_key(name) {
602 results.push(DiffResult::Removed(
603 format!("tensor.{}", name),
604 serde_json::to_value(stats)?,
605 ));
606 }
607 }
608
609 for (name, stats1) in &model1_tensors {
611 if let Some(stats2) = model2_tensors.get(name) {
612 if stats1.shape != stats2.shape {
614 results.push(DiffResult::TensorShapeChanged(
615 format!("tensor.{}", name),
616 stats1.shape.clone(),
617 stats2.shape.clone(),
618 ));
619 }
620
621 let mean_change = (stats1.mean - stats2.mean).abs();
623 let std_change = (stats1.std - stats2.std).abs();
624 let min_change = (stats1.min - stats2.min).abs();
625 let max_change = (stats1.max - stats2.max).abs();
626
627 let stats_changed = mean_change > eps || std_change > eps ||
628 min_change > eps || max_change > eps;
629
630 if stats_changed {
631 if show_layer_impact {
632 let impact_score = calculate_layer_impact(stats1, stats2);
634 let enhanced_key = format!("tensor.{} [impact: {:.4}]", name, impact_score);
635 results.push(DiffResult::TensorStatsChanged(
636 enhanced_key,
637 stats1.clone(),
638 stats2.clone(),
639 ));
640 } else {
641 results.push(DiffResult::TensorStatsChanged(
642 format!("tensor.{}", name),
643 stats1.clone(),
644 stats2.clone(),
645 ));
646 }
647 }
648
649 if quantization_analysis {
651 let quantization_info = analyze_quantization_impact(stats1, stats2);
652 if !quantization_info.is_empty() {
653 results.push(DiffResult::Modified(
654 format!("quantization.{}", name),
655 serde_json::to_value(&quantization_info)?,
656 serde_json::Value::Null,
657 ));
658 }
659 }
660 }
661 }
662
663 Ok(results)
664}
665
666fn calculate_layer_impact(stats1: &TensorStats, stats2: &TensorStats) -> f64 {
668 let mean_change = (stats1.mean - stats2.mean).abs();
669 let std_change = (stats1.std - stats2.std).abs();
670 let param_ratio = stats1.total_params as f64;
671
672 (mean_change + std_change) * param_ratio.log10().max(1.0)
674}
675
676fn analyze_quantization_impact(stats1: &TensorStats, stats2: &TensorStats) -> HashMap<String, f64> {
678 let mut analysis = HashMap::new();
679
680 let precision_loss = (stats1.max - stats1.min) / (stats2.max - stats2.min);
682 if precision_loss > 1.5 {
683 analysis.insert("precision_loss_ratio".to_string(), precision_loss);
684 }
685
686 let range_compression = ((stats1.max - stats1.min) - (stats2.max - stats2.min)).abs();
688 if range_compression > 0.1 {
689 analysis.insert("range_compression".to_string(), range_compression);
690 }
691
692 analysis
693}
694
695fn calculate_model_info(tensors: &HashMap<String, TensorStats>) -> ModelInfo {
697 let total_parameters: usize = tensors.values().map(|stats| stats.total_params).sum();
698 let layer_count = tensors.len();
699
700 let mut layer_types = HashMap::new();
701 for name in tensors.keys() {
702 let layer_type = extract_layer_type(name);
703 *layer_types.entry(layer_type).or_insert(0) += 1;
704 }
705
706 let model_size_bytes = total_parameters * 4;
708
709 ModelInfo {
710 total_parameters,
711 layer_count,
712 layer_types,
713 model_size_bytes,
714 }
715}
716
/// Classifies a tensor name into a coarse layer category.
///
/// The name is checked for the substrings of each category in order and the
/// first category with a hit wins: `conv`, `linear`, `norm`, `attention`,
/// `embedding`. Names matching none of them are labelled `other`.
fn extract_layer_type(tensor_name: &str) -> String {
    // Order matters: earlier categories take precedence over later ones.
    const RULES: [(&str, &[&str]); 5] = [
        ("conv", &["conv", "Conv"]),
        ("linear", &["linear", "Linear", "fc"]),
        ("norm", &["norm", "Norm", "bn"]),
        ("attention", &["attention", "attn"]),
        ("embedding", &["embedding", "embed"]),
    ];

    RULES
        .iter()
        .find(|(_, needles)| needles.iter().any(|needle| tensor_name.contains(*needle)))
        .map_or("other", |&(label, _)| label)
        .to_string()
}
733
/// Computes `(mean, std, min, max)` of `f32` samples, accumulated in `f64`.
///
/// `std` is the population standard deviation (the variance is divided by the
/// number of samples). Empty input yields `(0.0, 0.0, 0.0, 0.0)`.
///
/// NaN samples never cause a panic: `min` and `max` skip them (and are NaN
/// only when every sample is NaN), while `mean` and `std` become NaN.
fn calculate_f32_stats(data: &[f32]) -> (f64, f64, f64, f64) {
    if data.is_empty() {
        return (0.0, 0.0, 0.0, 0.0);
    }

    let count = data.len() as f64;
    let mean = data.iter().map(|&x| f64::from(x)).sum::<f64>() / count;

    let variance = data
        .iter()
        .map(|&x| {
            let diff = f64::from(x) - mean;
            diff * diff
        })
        .sum::<f64>()
        / count;
    let std = variance.sqrt();

    // `f32::min`/`f32::max` return the non-NaN operand, so seeding the fold
    // with NaN gives the extreme of the non-NaN samples.
    let min = f64::from(data.iter().copied().fold(f32::NAN, f32::min));
    let max = f64::from(data.iter().copied().fold(f32::NAN, f32::max));

    (mean, std, min, max)
}
756
/// Computes `(mean, std, min, max)` of `f64` samples.
///
/// `std` is the population standard deviation (the variance is divided by the
/// number of samples). Empty input yields `(0.0, 0.0, 0.0, 0.0)`.
///
/// NaN samples never cause a panic: `min` and `max` skip them (and are NaN
/// only when every sample is NaN), while `mean` and `std` become NaN.
fn calculate_f64_stats(data: &[f64]) -> (f64, f64, f64, f64) {
    if data.is_empty() {
        return (0.0, 0.0, 0.0, 0.0);
    }

    let count = data.len() as f64;
    let mean = data.iter().sum::<f64>() / count;

    let variance = data
        .iter()
        .map(|&x| {
            let diff = x - mean;
            diff * diff
        })
        .sum::<f64>()
        / count;
    let std = variance.sqrt();

    // `f64::min`/`f64::max` return the non-NaN operand, so seeding the fold
    // with NaN gives the extreme of the non-NaN samples.
    let min = data.iter().copied().fold(f64::NAN, f64::min);
    let max = data.iter().copied().fold(f64::NAN, f64::max);

    (mean, std, min, max)
}
778
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the summary of a `[128, 64]` f32 tensor with the given statistics.
    fn layer_stats(mean: f64, std: f64, min: f64, max: f64) -> TensorStats {
        TensorStats {
            mean,
            std,
            min,
            max,
            shape: vec![128, 64],
            dtype: "f32".to_string(),
            total_params: 8192,
        }
    }

    /// Fields set on construction are readable unchanged.
    #[test]
    fn test_tensor_stats_creation() {
        let stats = TensorStats {
            mean: 0.5,
            std: 1.0,
            min: -2.0,
            max: 3.0,
            shape: vec![10, 20],
            dtype: "f32".to_string(),
            total_params: 200,
        };

        assert_eq!(stats.mean, 0.5);
        assert_eq!(stats.total_params, 200);
        assert_eq!(stats.shape, vec![10, 20]);
    }

    /// `TensorStatsChanged` keeps the name and both summaries.
    #[test]
    fn test_diff_result_variants() {
        let before = layer_stats(0.0, 1.0, -2.0, 2.0);
        let after = layer_stats(0.1, 1.1, -1.9, 2.1);

        let diff = DiffResult::TensorStatsChanged("linear1.weight".to_string(), before, after);

        if let DiffResult::TensorStatsChanged(name, s1, s2) = diff {
            assert_eq!(name, "linear1.weight");
            assert_eq!(s1.mean, 0.0);
            assert_eq!(s2.mean, 0.1);
        } else {
            panic!("Expected TensorStatsChanged variant");
        }
    }

    /// `TensorShapeChanged` keeps the name and both shapes.
    #[test]
    fn test_tensor_shape_changed() {
        let diff = DiffResult::TensorShapeChanged(
            "linear2.weight".to_string(),
            vec![256, 128],
            vec![512, 128],
        );

        if let DiffResult::TensorShapeChanged(name, shape1, shape2) = diff {
            assert_eq!(name, "linear2.weight");
            assert_eq!(shape1, vec![256, 128]);
            assert_eq!(shape2, vec![512, 128]);
        } else {
            panic!("Expected TensorShapeChanged variant");
        }
    }

    /// Mean, extremes and population standard deviation of five f32 samples.
    #[test]
    fn test_calculate_f32_stats() {
        let samples: [f32; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
        let (mean, std, min, max) = calculate_f32_stats(&samples);

        assert_eq!(mean, 3.0);
        assert_eq!(min, 1.0);
        assert_eq!(max, 5.0);
        assert!((std - (2.0_f64).sqrt()).abs() < 1e-10);
    }

    /// Mean, extremes and population standard deviation of three f64 samples.
    #[test]
    fn test_calculate_f64_stats() {
        let samples: [f64; 3] = [0.0, 1.0, 2.0];
        let (mean, std, min, max) = calculate_f64_stats(&samples);

        assert_eq!(mean, 1.0);
        assert_eq!(min, 0.0);
        assert_eq!(max, 2.0);
        assert!((std - (2.0_f64 / 3.0).sqrt()).abs() < 1e-10);
    }

    /// Missing model files surface as an error rather than a panic.
    #[test]
    fn test_error_handling_nonexistent_files() {
        let first = Path::new("nonexistent1.safetensors");
        let second = Path::new("nonexistent2.safetensors");

        assert!(diff_ml_models(first, second, None).is_err());
    }
}