1use anyhow::{anyhow, Result};
2use diffx_core::{diff as base_diff, DiffOptions as BaseDiffOptions};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::fs;
6use std::path::Path;
7
8use crate::ml_analysis::{
9 analyze_activation_pattern_analysis, analyze_attention_patterns,
10 analyze_batch_normalization_analysis, analyze_convergence_patterns, analyze_ensemble_patterns,
11 analyze_gradient_patterns, analyze_learning_rate_changes, analyze_memory_usage_changes,
12 analyze_model_architecture_changes, analyze_model_complexity_assessment,
13 analyze_quantization_patterns, analyze_regularization_impact, analyze_training_metrics,
14 analyze_weight_distribution_analysis,
15};
16use crate::parsers::{detect_format_from_path, parse_file_by_format};
17use crate::types::{DiffOptions, DiffResult, TensorStats};
18
19pub fn diff_paths(
30 old_path: &str,
31 new_path: &str,
32 options: Option<&DiffOptions>,
33) -> Result<Vec<DiffResult>> {
34 let path1 = Path::new(old_path);
35 let path2 = Path::new(new_path);
36
37 match (path1.is_dir(), path2.is_dir()) {
38 (true, true) => diff_directories(path1, path2, options),
39 (false, false) => diff_files(path1, path2, options),
40 (true, false) => Err(anyhow!(
41 "Cannot compare directory '{}' with file '{}'",
42 old_path,
43 new_path
44 )),
45 (false, true) => Err(anyhow!(
46 "Cannot compare file '{}' with directory '{}'",
47 old_path,
48 new_path
49 )),
50 }
51}
52
53pub fn diff(old: &Value, new: &Value, options: Option<&DiffOptions>) -> Result<Vec<DiffResult>> {
58 let default_options = DiffOptions::default();
59 let opts = options.unwrap_or(&default_options);
60
61 let base_opts = convert_to_base_options(opts);
63 let base_results = base_diff(old, new, Some(&base_opts))?;
64
65 let mut results: Vec<DiffResult> = base_results.into_iter().map(|r| r.into()).collect();
67
68 if should_analyze_ml_features(old, new, opts) {
70 analyze_ml_features(old, new, &mut results, opts)?;
71 }
72
73 Ok(results)
74}
75
76fn convert_to_base_options(opts: &DiffOptions) -> BaseDiffOptions {
78 BaseDiffOptions {
79 epsilon: opts.epsilon,
80 array_id_key: opts.array_id_key.clone(),
81 ignore_keys_regex: opts.ignore_keys_regex.clone(),
82 path_filter: opts.path_filter.clone(),
83 recursive: None,
84 output_format: opts.output_format.map(|f| f.to_base_format()),
85 diffx_options: None,
86 }
87}
88
/// Heuristic gate deciding whether the ML-specific analysis passes
/// should run on these two values.
///
/// NOTE(review): this function ends with an unconditional `true`, so
/// every early-`return true` check above it is dead code and ML
/// analysis always runs regardless of input. The key scans are kept —
/// presumably for a future where non-ML inputs should skip analysis —
/// confirm the intended behavior before relying on them.
fn should_analyze_ml_features(old: &Value, new: &Value, _opts: &DiffOptions) -> bool {
    if let (Value::Object(old_obj), Value::Object(new_obj)) = (old, new) {
        // Top-level keys typical of PyTorch checkpoints / training state.
        let pytorch_keys = [
            "binary_size",
            "file_size",
            "detected_components",
            "estimated_layers",
            "structure_fingerprint",
            "pickle_protocol",
            "state_dict",
            "model",
            "optimizer",
            "scheduler",
            "epoch",
            "loss",
            "accuracy",
        ];
        for key in &pytorch_keys {
            if old_obj.contains_key(*key) || new_obj.contains_key(*key) {
                return true;
            }
        }

        // Safetensors-style payloads expose a top-level "tensors" map.
        let safetensors_keys = ["tensors"];
        for key in &safetensors_keys {
            if old_obj.contains_key(*key) || new_obj.contains_key(*key) {
                return true;
            }
        }

        // Substrings typical of layer parameter names anywhere in a key.
        let tensor_keys = [
            "weight",
            "bias",
            "running_mean",
            "running_var",
            "num_batches_tracked",
        ];
        for (key, _) in old_obj.iter().chain(new_obj.iter()) {
            for tensor_key in &tensor_keys {
                if key.contains(tensor_key) {
                    return true;
                }
            }
        }

        // Dotted parameter paths such as "layer1.weight".
        for (key, _) in old_obj.iter().chain(new_obj.iter()) {
            if key.starts_with("tensors.") || key.contains(".weight") || key.contains(".bias") {
                return true;
            }
        }
    }

    // Fallback: always analyze (this makes the checks above redundant).
    true
}
150
151fn analyze_ml_features(
153 old: &Value,
154 new: &Value,
155 results: &mut Vec<DiffResult>,
156 _options: &DiffOptions,
157) -> Result<()> {
158 if let (Value::Object(old_obj), Value::Object(new_obj)) = (old, new) {
160 for (key, old_val) in old_obj {
162 if let Some(new_val) = new_obj.get(key) {
163 if is_tensor_like(old_val) && is_tensor_like(new_val) {
164 analyze_tensor_changes(key, old_val, new_val, results);
165 }
166 }
167 }
168
169 analyze_nested_tensor_containers(old_obj, new_obj, results);
171
172 analyze_model_architecture_changes(old, new, results);
174 analyze_learning_rate_changes(old, new, results);
175 analyze_convergence_patterns(old, new, results);
176 analyze_memory_usage_changes(old, new, results);
177 analyze_ensemble_patterns(old, new, results);
178 analyze_quantization_patterns(old, new, results);
179 analyze_attention_patterns(old, new, results);
180 analyze_gradient_patterns(old, new, results);
181
182 analyze_batch_normalization_analysis(old, new, results);
184 analyze_regularization_impact(old, new, results);
185 analyze_activation_pattern_analysis(old, new, results);
186 analyze_weight_distribution_analysis(old, new, results);
187 analyze_model_complexity_assessment(old, new, results);
188
189 analyze_training_metrics(old, new, results);
191 }
192
193 Ok(())
194}
195
196fn diff_files(
197 path1: &Path,
198 path2: &Path,
199 options: Option<&DiffOptions>,
200) -> Result<Vec<DiffResult>> {
201 let format1 = detect_format_from_path(path1)?;
203 let format2 = detect_format_from_path(path2)?;
204
205 if std::mem::discriminant(&format1) != std::mem::discriminant(&format2) {
207 return Err(anyhow!(
208 "Cannot compare files with different formats: {:?} vs {:?}",
209 format1,
210 format2
211 ));
212 }
213
214 let value1 = parse_file_by_format(path1, format1)?;
216 let value2 = parse_file_by_format(path2, format2)?;
217
218 diff(&value1, &value2, options)
220}
221
222fn diff_directories(
223 dir1: &Path,
224 dir2: &Path,
225 options: Option<&DiffOptions>,
226) -> Result<Vec<DiffResult>> {
227 let mut results = Vec::new();
228
229 let files1 = get_all_files_recursive(dir1)?;
231 let files2 = get_all_files_recursive(dir2)?;
232
233 let files1_map: HashMap<String, &Path> = files1
235 .iter()
236 .filter_map(|path| {
237 path.strip_prefix(dir1)
238 .ok()
239 .map(|rel| (rel.to_string_lossy().to_string(), path.as_path()))
240 })
241 .collect();
242
243 let files2_map: HashMap<String, &Path> = files2
244 .iter()
245 .filter_map(|path| {
246 path.strip_prefix(dir2)
247 .ok()
248 .map(|rel| (rel.to_string_lossy().to_string(), path.as_path()))
249 })
250 .collect();
251
252 for (rel_path, abs_path1) in &files1_map {
254 if !files2_map.contains_key(rel_path) {
255 if let Ok(format) = detect_format_from_path(abs_path1) {
256 if let Ok(value) = parse_file_by_format(abs_path1, format) {
257 results.push(DiffResult::Removed(rel_path.clone(), value));
258 }
259 }
260 }
261 }
262
263 for (rel_path, abs_path2) in &files2_map {
265 if !files1_map.contains_key(rel_path) {
266 if let Ok(format) = detect_format_from_path(abs_path2) {
267 if let Ok(value) = parse_file_by_format(abs_path2, format) {
268 results.push(DiffResult::Added(rel_path.clone(), value));
269 }
270 }
271 }
272 }
273
274 for (rel_path, abs_path1) in &files1_map {
276 if let Some(abs_path2) = files2_map.get(rel_path) {
277 match diff_files(abs_path1, abs_path2, options) {
278 Ok(mut file_results) => {
279 for result in &mut file_results {
281 match result {
282 DiffResult::Added(path, _) => *path = format!("{rel_path}/{path}"),
283 DiffResult::Removed(path, _) => *path = format!("{rel_path}/{path}"),
284 DiffResult::Modified(path, _, _) => {
285 *path = format!("{rel_path}/{path}")
286 }
287 DiffResult::TypeChanged(path, _, _) => {
288 *path = format!("{rel_path}/{path}")
289 }
290 DiffResult::TensorShapeChanged(path, _, _) => {
292 *path = format!("{rel_path}/{path}")
293 }
294 DiffResult::TensorStatsChanged(path, _, _) => {
295 *path = format!("{rel_path}/{path}")
296 }
297 DiffResult::TensorDataChanged(path, _, _) => {
298 *path = format!("{rel_path}/{path}")
299 }
300 DiffResult::ModelArchitectureChanged(path, _, _) => {
301 *path = format!("{rel_path}/{path}")
302 }
303 DiffResult::WeightSignificantChange(path, _) => {
304 *path = format!("{rel_path}/{path}")
305 }
306 DiffResult::ActivationFunctionChanged(path, _, _) => {
307 *path = format!("{rel_path}/{path}")
308 }
309 DiffResult::LearningRateChanged(path, _, _) => {
310 *path = format!("{rel_path}/{path}")
311 }
312 DiffResult::OptimizerChanged(path, _, _) => {
313 *path = format!("{rel_path}/{path}")
314 }
315 DiffResult::LossChange(path, _, _) => {
316 *path = format!("{rel_path}/{path}")
317 }
318 DiffResult::AccuracyChange(path, _, _) => {
319 *path = format!("{rel_path}/{path}")
320 }
321 DiffResult::ModelVersionChanged(path, _, _) => {
322 *path = format!("{rel_path}/{path}")
323 }
324 }
325 }
326 results.extend(file_results);
327 }
328 Err(_) => {
329 continue;
331 }
332 }
333 }
334 }
335
336 Ok(results)
337}
338
339fn get_all_files_recursive(dir: &Path) -> Result<Vec<std::path::PathBuf>> {
340 let mut files = Vec::new();
341
342 if dir.is_dir() {
343 for entry in fs::read_dir(dir)? {
344 let entry = entry?;
345 let path = entry.path();
346
347 if path.is_dir() {
348 files.extend(get_all_files_recursive(&path)?);
349 } else if path.is_file() {
350 files.push(path);
351 }
352 }
353 }
354
355 Ok(files)
356}
357
358fn is_tensor_like(value: &Value) -> bool {
360 if let Value::Object(obj) = value {
361 let has_shape =
363 obj.contains_key("shape") || obj.contains_key("dims") || obj.contains_key("size");
364 let has_data =
365 obj.contains_key("data") || obj.contains_key("values") || obj.contains_key("tensor");
366 let has_dtype = obj.contains_key("dtype")
367 || obj.contains_key("type")
368 || obj.contains_key("element_type");
369
370 has_shape && (has_data || has_dtype) ||
372 obj.contains_key("weight") || obj.contains_key("bias") ||
374 obj.contains_key("mean") || obj.contains_key("std") ||
375 obj.contains_key("min") || obj.contains_key("max")
376 } else {
377 false
378 }
379}
380
381fn analyze_tensor_changes(
383 path: &str,
384 old_tensor: &Value,
385 new_tensor: &Value,
386 results: &mut Vec<DiffResult>,
387) {
388 if let (Some(old_data), Some(new_data)) = (
390 extract_tensor_data(old_tensor),
391 extract_tensor_data(new_tensor),
392 ) {
393 let old_shape = extract_tensor_shape(old_tensor).unwrap_or_default();
394 let new_shape = extract_tensor_shape(new_tensor).unwrap_or_default();
395 let dtype = extract_tensor_dtype(old_tensor).unwrap_or_else(|| "f32".to_string());
396
397 if old_shape != new_shape {
399 results.push(DiffResult::TensorShapeChanged(
400 path.to_string(),
401 old_shape,
402 new_shape,
403 ));
404 return;
405 }
406
407 let old_stats = TensorStats::new(&old_data, old_shape.clone(), dtype.clone());
409 let new_stats = TensorStats::new(&new_data, new_shape, dtype);
410
411 if stats_changed_significantly(&old_stats, &new_stats) {
413 results.push(DiffResult::TensorStatsChanged(
414 path.to_string(),
415 old_stats,
416 new_stats,
417 ));
418 } else {
419 results.push(DiffResult::TensorDataChanged(
421 path.to_string(),
422 old_stats.mean,
423 new_stats.mean,
424 ));
425 }
426 }
427}
428
429pub fn extract_tensor_data(tensor: &Value) -> Option<Vec<f64>> {
430 match tensor {
431 Value::Array(arr) => {
433 let mut data = Vec::new();
434 extract_numbers_from_nested_array(arr, &mut data);
435 if !data.is_empty() {
436 Some(data)
437 } else {
438 None
439 }
440 }
441
442 Value::Object(obj) => {
444 let data_fields = ["data", "values", "tensor", "_data", "storage"];
446 for field in &data_fields {
447 if let Some(data_value) = obj.get(*field) {
448 if let Some(extracted) = extract_tensor_data(data_value) {
449 return Some(extracted);
450 }
451 }
452 }
453
454 if let Some(data_str) = obj.get("data").and_then(|v| v.as_str()) {
456 if let Ok(decoded) = base64_decode_tensor_data(data_str) {
457 return Some(decoded);
458 }
459 }
460
461 if let Some(data_str) = obj.get("hex_data").and_then(|v| v.as_str()) {
463 if let Ok(decoded) = hex_decode_tensor_data(data_str) {
464 return Some(decoded);
465 }
466 }
467
468 if obj.contains_key("requires_grad") || obj.contains_key("grad_fn") {
470 if let Some(Value::Array(shape)) = obj.get("shape") {
472 if let Some(flattened) = extract_flattened_tensor_values(obj, shape) {
473 return Some(flattened);
474 }
475 }
476 }
477
478 None
479 }
480
481 Value::Number(num) => {
483 if let Some(f) = num.as_f64() {
484 Some(vec![f])
485 } else {
486 None
487 }
488 }
489
490 _ => None,
491 }
492}
493
494fn extract_numbers_from_nested_array(arr: &[Value], result: &mut Vec<f64>) {
496 for item in arr {
497 match item {
498 Value::Number(num) => {
499 if let Some(f) = num.as_f64() {
500 result.push(f);
501 }
502 }
503 Value::Array(nested_arr) => {
504 extract_numbers_from_nested_array(nested_arr, result);
505 }
506 _ => {}
507 }
508 }
509}
510
/// Placeholder for decoding base64-encoded tensor payloads. Currently
/// always returns an error, so callers fall through to other
/// extraction strategies.
fn base64_decode_tensor_data(_data_str: &str) -> Result<Vec<f64>, Box<dyn std::error::Error>> {
    Err("Base64 tensor decoding not yet implemented".into())
}
517
/// Placeholder for decoding hex-encoded tensor payloads. Currently
/// always returns an error, so callers fall through to other
/// extraction strategies.
fn hex_decode_tensor_data(_data_str: &str) -> Result<Vec<f64>, Box<dyn std::error::Error>> {
    Err("Hex tensor decoding not yet implemented".into())
}
523
524fn extract_flattened_tensor_values(
526 obj: &serde_json::Map<String, Value>,
527 shape: &[Value],
528) -> Option<Vec<f64>> {
529 let total_elements: usize = shape
531 .iter()
532 .filter_map(|v| v.as_u64())
533 .map(|n| n as usize)
534 .product();
535
536 if total_elements == 0 {
537 return None;
538 }
539
540 let storage_fields = ["_storage", "storage", "_data"];
542 for field in &storage_fields {
543 if let Some(storage_value) = obj.get(*field) {
544 if let Some(data) = extract_tensor_data(storage_value) {
545 let limited_data: Vec<f64> = data.into_iter().take(total_elements).collect();
547 if !limited_data.is_empty() {
548 return Some(limited_data);
549 }
550 }
551 }
552 }
553
554 None
555}
556
557pub fn extract_tensor_shape(tensor: &Value) -> Option<Vec<usize>> {
558 tensor.get("shape").and_then(|s| s.as_array()).map(|arr| {
560 arr.iter()
561 .filter_map(|v| v.as_u64().map(|n| n as usize))
562 .collect()
563 })
564}
565
566fn extract_tensor_dtype(tensor: &Value) -> Option<String> {
567 tensor
569 .get("dtype")
570 .and_then(|dt| dt.as_str())
571 .map(|s| s.to_string())
572}
573
574fn stats_changed_significantly(old_stats: &TensorStats, new_stats: &TensorStats) -> bool {
575 let mean_change = (old_stats.mean - new_stats.mean).abs() / old_stats.mean.abs().max(1e-8);
576 let std_change = (old_stats.std - new_stats.std).abs() / old_stats.std.abs().max(1e-8);
577
578 mean_change > 0.01 || std_change > 0.01
580}
581
582fn analyze_nested_tensor_containers(
584 old_obj: &serde_json::Map<String, Value>,
585 new_obj: &serde_json::Map<String, Value>,
586 results: &mut Vec<DiffResult>,
587) {
588 let container_keys = [
590 "arrays",
591 "variables",
592 "tensors",
593 "model_state_dict",
594 "state_dict",
595 "layer_data",
596 "layers",
597 "weights",
598 "parameters",
599 ];
600
601 for container_key in &container_keys {
602 if let (Some(Value::Object(old_container)), Some(Value::Object(new_container))) =
603 (old_obj.get(*container_key), new_obj.get(*container_key))
604 {
605 for (name, old_item) in old_container {
607 if let Some(new_item) = new_container.get(name) {
608 let path = format!("{container_key}.{name}");
609 analyze_tensor_metadata_changes(&path, old_item, new_item, results);
610 }
611 }
612 }
613 }
614}
615
/// Compares two tensor entries via their serialized metadata (shape,
/// mean, inline data, nested statistics) and appends one result per
/// detected difference.
///
/// The four checks below are independent and push in a fixed order, so
/// a single entry can produce multiple results — e.g. a mean change
/// (TensorDataChanged) AND a stats change from the inline data
/// (TensorStatsChanged). NOTE(review): presumably intentional, but
/// confirm the duplicate reporting is wanted.
fn analyze_tensor_metadata_changes(
    path: &str,
    old_item: &Value,
    new_item: &Value,
    results: &mut Vec<DiffResult>,
) {
    if let (Value::Object(old_obj), Value::Object(new_obj)) = (old_item, new_item) {
        // 1. Shape comparison (only when both sides declare a shape).
        if let (Some(old_shape), Some(new_shape)) = (old_obj.get("shape"), new_obj.get("shape")) {
            let old_shape_vec = extract_shape_from_value(old_shape);
            let new_shape_vec = extract_shape_from_value(new_shape);

            if old_shape_vec != new_shape_vec {
                results.push(DiffResult::TensorShapeChanged(
                    path.to_string(),
                    old_shape_vec,
                    new_shape_vec,
                ));
            }
        }

        // 2. Direct mean comparison with an absolute 1e-10 tolerance.
        if let (Some(old_mean), Some(new_mean)) = (
            old_obj.get("mean").and_then(|v| v.as_f64()),
            new_obj.get("mean").and_then(|v| v.as_f64()),
        ) {
            if (old_mean - new_mean).abs() > 1e-10 {
                results.push(DiffResult::TensorDataChanged(
                    path.to_string(),
                    old_mean,
                    new_mean,
                ));
            }
        }

        // 3. Inline numeric "data" arrays: recompute statistics on both
        // sides and report if they differ significantly.
        if let (Some(Value::Array(old_data)), Some(Value::Array(new_data))) =
            (old_obj.get("data"), new_obj.get("data"))
        {
            // Non-numeric entries are dropped silently.
            let old_vals: Vec<f64> = old_data.iter().filter_map(|v| v.as_f64()).collect();
            let new_vals: Vec<f64> = new_data.iter().filter_map(|v| v.as_f64()).collect();

            if !old_vals.is_empty() && !new_vals.is_empty() {
                let old_shape = old_obj
                    .get("shape")
                    .map(extract_shape_from_value)
                    .unwrap_or_default();
                let new_shape = new_obj
                    .get("shape")
                    .map(extract_shape_from_value)
                    .unwrap_or_default();
                // dtype is taken from the old side; "float32" fallback.
                let dtype = old_obj
                    .get("dtype")
                    .and_then(|d| d.as_str())
                    .unwrap_or("float32")
                    .to_string();

                let old_stats = TensorStats::new(&old_vals, old_shape.clone(), dtype.clone());
                let new_stats = TensorStats::new(&new_vals, new_shape, dtype);

                if stats_changed_significantly(&old_stats, &new_stats) {
                    results.push(DiffResult::TensorStatsChanged(
                        path.to_string(),
                        old_stats,
                        new_stats,
                    ));
                }
            }
        }

        // 4. Precomputed "statistics" sub-objects: compare as-is without
        // recomputing from raw data.
        if let (Some(Value::Object(old_stats)), Some(Value::Object(new_stats))) =
            (old_obj.get("statistics"), new_obj.get("statistics"))
        {
            let old_tensor_stats = extract_stats_from_object(old_stats, old_obj);
            let new_tensor_stats = extract_stats_from_object(new_stats, new_obj);

            if stats_changed_significantly(&old_tensor_stats, &new_tensor_stats) {
                results.push(DiffResult::TensorStatsChanged(
                    path.to_string(),
                    old_tensor_stats,
                    new_tensor_stats,
                ));
            }
        }

    }
}
707
708fn extract_shape_from_value(shape: &Value) -> Vec<usize> {
709 match shape {
710 Value::Array(arr) => arr
711 .iter()
712 .filter_map(|v| v.as_u64().map(|n| n as usize))
713 .collect(),
714 _ => vec![],
715 }
716}
717
718fn extract_stats_from_object(
719 stats_obj: &serde_json::Map<String, Value>,
720 parent_obj: &serde_json::Map<String, Value>,
721) -> TensorStats {
722 let mean = stats_obj
723 .get("mean")
724 .and_then(|v| v.as_f64())
725 .unwrap_or(0.0);
726 let std = stats_obj.get("std").and_then(|v| v.as_f64()).unwrap_or(0.0);
727 let min = stats_obj.get("min").and_then(|v| v.as_f64()).unwrap_or(0.0);
728 let max = stats_obj.get("max").and_then(|v| v.as_f64()).unwrap_or(0.0);
729
730 let shape = parent_obj
731 .get("shape")
732 .map(extract_shape_from_value)
733 .unwrap_or_default();
734
735 let dtype = parent_obj
736 .get("dtype")
737 .and_then(|d| d.as_str())
738 .unwrap_or("unknown")
739 .to_string();
740
741 let element_count = shape.iter().product();
742
743 TensorStats {
744 mean,
745 std,
746 min,
747 max,
748 shape,
749 dtype,
750 element_count,
751 }
752}