diffai_core/
diff.rs

1use anyhow::{anyhow, Result};
2use diffx_core::{diff as base_diff, DiffOptions as BaseDiffOptions};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::fs;
6use std::path::Path;
7
8use crate::ml_analysis::{
9    analyze_activation_pattern_analysis, analyze_attention_patterns,
10    analyze_batch_normalization_analysis, analyze_convergence_patterns, analyze_ensemble_patterns,
11    analyze_gradient_patterns, analyze_learning_rate_changes, analyze_memory_usage_changes,
12    analyze_model_architecture_changes, analyze_model_complexity_assessment,
13    analyze_quantization_patterns, analyze_regularization_impact, analyze_training_metrics,
14    analyze_weight_distribution_analysis,
15};
16use crate::parsers::{detect_format_from_path, parse_file_by_format};
17use crate::types::{DiffOptions, DiffResult, TensorStats};
18
19// ============================================================================
20// UNIFIED API - Main Function
21// ============================================================================
22
23/// Unified diff function for diffai (path-based entry point)
24///
25/// This is the main entry point that handles both files and directories automatically.
26/// - File vs File: Regular file comparison
27/// - Directory vs Directory: Recursive directory comparison  
28/// - File vs Directory: Returns error
29pub fn diff_paths(
30    old_path: &str,
31    new_path: &str,
32    options: Option<&DiffOptions>,
33) -> Result<Vec<DiffResult>> {
34    let path1 = Path::new(old_path);
35    let path2 = Path::new(new_path);
36
37    match (path1.is_dir(), path2.is_dir()) {
38        (true, true) => diff_directories(path1, path2, options),
39        (false, false) => diff_files(path1, path2, options),
40        (true, false) => Err(anyhow!(
41            "Cannot compare directory '{}' with file '{}'",
42            old_path,
43            new_path
44        )),
45        (false, true) => Err(anyhow!(
46            "Cannot compare file '{}' with directory '{}'",
47            old_path,
48            new_path
49        )),
50    }
51}
52
53/// Unified diff function for diffai (Value-based)
54///
55/// This function operates on pre-parsed JSON values.
56/// For file/directory operations, use diff_paths() instead.
57pub fn diff(old: &Value, new: &Value, options: Option<&DiffOptions>) -> Result<Vec<DiffResult>> {
58    let default_options = DiffOptions::default();
59    let opts = options.unwrap_or(&default_options);
60
61    // diffx-coreの基本diff機能を活用してコード重複を削減
62    let base_opts = convert_to_base_options(opts);
63    let base_results = base_diff(old, new, Some(&base_opts))?;
64
65    // diffx-coreの結果をdiffai形式に変換
66    let mut results: Vec<DiffResult> = base_results.into_iter().map(|r| r.into()).collect();
67
68    // AI/ML分析が有効な場合のみ追加処理を実行
69    if should_analyze_ml_features(old, new, opts) {
70        analyze_ml_features(old, new, &mut results, opts)?;
71    }
72
73    Ok(results)
74}
75
// Convert diffai's DiffOptions into diffx-core's BaseDiffOptions.
// Only the fields diffx-core understands are forwarded; diffai-specific
// behavior (the ML analysis pass) is handled separately in this module.
fn convert_to_base_options(opts: &DiffOptions) -> BaseDiffOptions {
    BaseDiffOptions {
        epsilon: opts.epsilon,
        array_id_key: opts.array_id_key.clone(),
        ignore_keys_regex: opts.ignore_keys_regex.clone(),
        path_filter: opts.path_filter.clone(),
        // Recursion is driven by diff_paths()/diff_directories(), not diffx-core.
        recursive: None,
        output_format: opts.output_format.map(|f| f.to_base_format()),
        // No diffx-specific sub-options are passed through.
        diffx_options: None,
    }
}
88
89// AI/ML分析が必要かどうかを判定
90fn should_analyze_ml_features(old: &Value, new: &Value, _opts: &DiffOptions) -> bool {
91    // lawkitパターン:MLファイル形式なら常に分析実行
92    if let (Value::Object(old_obj), Value::Object(new_obj)) = (old, new) {
93        // PyTorchファイル構造のキーが含まれている場合
94        let pytorch_keys = [
95            "binary_size",
96            "file_size",
97            "detected_components",
98            "estimated_layers",
99            "structure_fingerprint",
100            "pickle_protocol",
101            "state_dict",
102            "model",
103            "optimizer",
104            "scheduler",
105            "epoch",
106            "loss",
107            "accuracy",
108        ];
109        for key in &pytorch_keys {
110            if old_obj.contains_key(*key) || new_obj.contains_key(*key) {
111                return true;
112            }
113        }
114
115        // SafeTensorsファイル構造のキーが含まれている場合
116        let safetensors_keys = ["tensors"];
117        for key in &safetensors_keys {
118            if old_obj.contains_key(*key) || new_obj.contains_key(*key) {
119                return true;
120            }
121        }
122
123        // テンソル関連のキーが含まれている場合(直接のテンソル名)
124        let tensor_keys = [
125            "weight",
126            "bias",
127            "running_mean",
128            "running_var",
129            "num_batches_tracked",
130        ];
131        for (key, _) in old_obj.iter().chain(new_obj.iter()) {
132            for tensor_key in &tensor_keys {
133                if key.contains(tensor_key) {
134                    return true;
135                }
136            }
137        }
138
139        // テンソル階層構造の検出 (tensors.layer.weight パターン)
140        for (key, _) in old_obj.iter().chain(new_obj.iter()) {
141            if key.starts_with("tensors.") || key.contains(".weight") || key.contains(".bias") {
142                return true;
143            }
144        }
145    }
146
147    // ML関連のファイルは基本的に分析対象とする
148    true
149}
150
// Run the integrated ML feature-analysis pass and append findings to `results`.
//
// Following the lawkit pattern, every analysis runs unconditionally for
// object-shaped inputs (file format is auto-detected upstream); convention is
// preferred over per-user configuration, which is why `_options` is unused.
fn analyze_ml_features(
    old: &Value,
    new: &Value,
    results: &mut Vec<DiffResult>,
    _options: &DiffOptions,
) -> Result<()> {
    // lawkit pattern: ML analysis always runs (decided automatically per file format)
    if let (Value::Object(old_obj), Value::Object(new_obj)) = (old, new) {
        // Analyze tensor changes (state_dict entries etc.) for keys present on both sides
        for (key, old_val) in old_obj {
            if let Some(new_val) = new_obj.get(key) {
                if is_tensor_like(old_val) && is_tensor_like(new_val) {
                    analyze_tensor_changes(key, old_val, new_val, results);
                }
            }
        }

        // NumPy/MATLAB formats: analyze nested structures inside arrays/variables
        analyze_nested_tensor_containers(old_obj, new_obj, results);

        // Run every ML analysis automatically (lawkit pattern: convention over user configuration)
        analyze_model_architecture_changes(old, new, results);
        analyze_learning_rate_changes(old, new, results);
        analyze_convergence_patterns(old, new, results);
        analyze_memory_usage_changes(old, new, results);
        analyze_ensemble_patterns(old, new, results);
        analyze_quantization_patterns(old, new, results);
        analyze_attention_patterns(old, new, results);
        analyze_gradient_patterns(old, new, results);

        // Additional ML analysis features
        analyze_batch_normalization_analysis(old, new, results);
        analyze_regularization_impact(old, new, results);
        analyze_activation_pattern_analysis(old, new, results);
        analyze_weight_distribution_analysis(old, new, results);
        analyze_model_complexity_assessment(old, new, results);

        // Training metrics (loss, accuracy, version)
        analyze_training_metrics(old, new, results);
    }

    Ok(())
}
195
196fn diff_files(
197    path1: &Path,
198    path2: &Path,
199    options: Option<&DiffOptions>,
200) -> Result<Vec<DiffResult>> {
201    // Detect formats based on file extensions
202    let format1 = detect_format_from_path(path1)?;
203    let format2 = detect_format_from_path(path2)?;
204
205    // Ensure both files have the same format
206    if std::mem::discriminant(&format1) != std::mem::discriminant(&format2) {
207        return Err(anyhow!(
208            "Cannot compare files with different formats: {:?} vs {:?}",
209            format1,
210            format2
211        ));
212    }
213
214    // Parse files based on detected formats
215    let value1 = parse_file_by_format(path1, format1)?;
216    let value2 = parse_file_by_format(path2, format2)?;
217
218    // Use existing diff implementation
219    diff(&value1, &value2, options)
220}
221
222fn diff_directories(
223    dir1: &Path,
224    dir2: &Path,
225    options: Option<&DiffOptions>,
226) -> Result<Vec<DiffResult>> {
227    let mut results = Vec::new();
228
229    // Get all files in both directories recursively
230    let files1 = get_all_files_recursive(dir1)?;
231    let files2 = get_all_files_recursive(dir2)?;
232
233    // Create maps for easier lookup (relative path -> absolute path)
234    let files1_map: HashMap<String, &Path> = files1
235        .iter()
236        .filter_map(|path| {
237            path.strip_prefix(dir1)
238                .ok()
239                .map(|rel| (rel.to_string_lossy().to_string(), path.as_path()))
240        })
241        .collect();
242
243    let files2_map: HashMap<String, &Path> = files2
244        .iter()
245        .filter_map(|path| {
246            path.strip_prefix(dir2)
247                .ok()
248                .map(|rel| (rel.to_string_lossy().to_string(), path.as_path()))
249        })
250        .collect();
251
252    // Find files that exist in dir1 but not in dir2 (removed)
253    for (rel_path, abs_path1) in &files1_map {
254        if !files2_map.contains_key(rel_path) {
255            if let Ok(format) = detect_format_from_path(abs_path1) {
256                if let Ok(value) = parse_file_by_format(abs_path1, format) {
257                    results.push(DiffResult::Removed(rel_path.clone(), value));
258                }
259            }
260        }
261    }
262
263    // Find files that exist in dir2 but not in dir1 (added)
264    for (rel_path, abs_path2) in &files2_map {
265        if !files1_map.contains_key(rel_path) {
266            if let Ok(format) = detect_format_from_path(abs_path2) {
267                if let Ok(value) = parse_file_by_format(abs_path2, format) {
268                    results.push(DiffResult::Added(rel_path.clone(), value));
269                }
270            }
271        }
272    }
273
274    // Find files that exist in both directories (compare contents)
275    for (rel_path, abs_path1) in &files1_map {
276        if let Some(abs_path2) = files2_map.get(rel_path) {
277            match diff_files(abs_path1, abs_path2, options) {
278                Ok(mut file_results) => {
279                    // Prefix all paths with the relative path
280                    for result in &mut file_results {
281                        match result {
282                            DiffResult::Added(path, _) => *path = format!("{rel_path}/{path}"),
283                            DiffResult::Removed(path, _) => *path = format!("{rel_path}/{path}"),
284                            DiffResult::Modified(path, _, _) => {
285                                *path = format!("{rel_path}/{path}")
286                            }
287                            DiffResult::TypeChanged(path, _, _) => {
288                                *path = format!("{rel_path}/{path}")
289                            }
290                            // AI/ML specific result types
291                            DiffResult::TensorShapeChanged(path, _, _) => {
292                                *path = format!("{rel_path}/{path}")
293                            }
294                            DiffResult::TensorStatsChanged(path, _, _) => {
295                                *path = format!("{rel_path}/{path}")
296                            }
297                            DiffResult::TensorDataChanged(path, _, _) => {
298                                *path = format!("{rel_path}/{path}")
299                            }
300                            DiffResult::ModelArchitectureChanged(path, _, _) => {
301                                *path = format!("{rel_path}/{path}")
302                            }
303                            DiffResult::WeightSignificantChange(path, _) => {
304                                *path = format!("{rel_path}/{path}")
305                            }
306                            DiffResult::ActivationFunctionChanged(path, _, _) => {
307                                *path = format!("{rel_path}/{path}")
308                            }
309                            DiffResult::LearningRateChanged(path, _, _) => {
310                                *path = format!("{rel_path}/{path}")
311                            }
312                            DiffResult::OptimizerChanged(path, _, _) => {
313                                *path = format!("{rel_path}/{path}")
314                            }
315                            DiffResult::LossChange(path, _, _) => {
316                                *path = format!("{rel_path}/{path}")
317                            }
318                            DiffResult::AccuracyChange(path, _, _) => {
319                                *path = format!("{rel_path}/{path}")
320                            }
321                            DiffResult::ModelVersionChanged(path, _, _) => {
322                                *path = format!("{rel_path}/{path}")
323                            }
324                        }
325                    }
326                    results.extend(file_results);
327                }
328                Err(_) => {
329                    // If file comparison fails, skip this file
330                    continue;
331                }
332            }
333        }
334    }
335
336    Ok(results)
337}
338
339fn get_all_files_recursive(dir: &Path) -> Result<Vec<std::path::PathBuf>> {
340    let mut files = Vec::new();
341
342    if dir.is_dir() {
343        for entry in fs::read_dir(dir)? {
344            let entry = entry?;
345            let path = entry.path();
346
347            if path.is_dir() {
348                files.extend(get_all_files_recursive(&path)?);
349            } else if path.is_file() {
350                files.push(path);
351            }
352        }
353    }
354
355    Ok(files)
356}
357
358// Helper function to detect tensor-like data structures
359fn is_tensor_like(value: &Value) -> bool {
360    if let Value::Object(obj) = value {
361        // Check for common tensor-like properties
362        let has_shape =
363            obj.contains_key("shape") || obj.contains_key("dims") || obj.contains_key("size");
364        let has_data =
365            obj.contains_key("data") || obj.contains_key("values") || obj.contains_key("tensor");
366        let has_dtype = obj.contains_key("dtype")
367            || obj.contains_key("type")
368            || obj.contains_key("element_type");
369
370        // Consider it tensor-like if it has at least shape and data, or if it has common ML keys
371        has_shape && (has_data || has_dtype) ||
372        // Also check for PyTorch/Safetensors/NumPy-specific keys
373        obj.contains_key("weight") || obj.contains_key("bias") ||
374        obj.contains_key("mean") || obj.contains_key("std") ||
375        obj.contains_key("min") || obj.contains_key("max")
376    } else {
377        false
378    }
379}
380
// AI/ML specific analysis functions

// Compare two tensor-like values at `path` and record the most specific
// difference available: a shape change, a significant statistics change, or
// (as a fallback) a mean-value comparison.
fn analyze_tensor_changes(
    path: &str,
    old_tensor: &Value,
    new_tensor: &Value,
    results: &mut Vec<DiffResult>,
) {
    // Only analyze when numeric data can be extracted from both sides;
    // otherwise nothing is recorded for this path.
    if let (Some(old_data), Some(new_data)) = (
        extract_tensor_data(old_tensor),
        extract_tensor_data(new_tensor),
    ) {
        let old_shape = extract_tensor_shape(old_tensor).unwrap_or_default();
        let new_shape = extract_tensor_shape(new_tensor).unwrap_or_default();
        // dtype is taken from the old tensor only; "f32" is assumed when absent.
        let dtype = extract_tensor_dtype(old_tensor).unwrap_or_else(|| "f32".to_string());

        // A shape change supersedes any statistics comparison.
        if old_shape != new_shape {
            results.push(DiffResult::TensorShapeChanged(
                path.to_string(),
                old_shape,
                new_shape,
            ));
            return;
        }

        // Compute comprehensive statistics for both sides.
        let old_stats = TensorStats::new(&old_data, old_shape.clone(), dtype.clone());
        let new_stats = TensorStats::new(&new_data, new_shape, dtype);

        // Check if statistics changed significantly
        if stats_changed_significantly(&old_stats, &new_stats) {
            results.push(DiffResult::TensorStatsChanged(
                path.to_string(),
                old_stats,
                new_stats,
            ));
        } else {
            // NOTE(review): this fallback fires even when the means are equal,
            // so identical tensors still produce a TensorDataChanged entry —
            // confirm callers expect one entry per compared tensor pair.
            results.push(DiffResult::TensorDataChanged(
                path.to_string(),
                old_stats.mean,
                new_stats.mean,
            ));
        }
    }
}
428
429pub fn extract_tensor_data(tensor: &Value) -> Option<Vec<f64>> {
430    match tensor {
431        // Direct array format (NumPy, simple tensors)
432        Value::Array(arr) => {
433            let mut data = Vec::new();
434            extract_numbers_from_nested_array(arr, &mut data);
435            if !data.is_empty() {
436                Some(data)
437            } else {
438                None
439            }
440        }
441
442        // Structured tensor format (PyTorch/Safetensors)
443        Value::Object(obj) => {
444            // Check for various data field names
445            let data_fields = ["data", "values", "tensor", "_data", "storage"];
446            for field in &data_fields {
447                if let Some(data_value) = obj.get(*field) {
448                    if let Some(extracted) = extract_tensor_data(data_value) {
449                        return Some(extracted);
450                    }
451                }
452            }
453
454            // Check for base64 encoded binary data (Safetensors)
455            if let Some(data_str) = obj.get("data").and_then(|v| v.as_str()) {
456                if let Ok(decoded) = base64_decode_tensor_data(data_str) {
457                    return Some(decoded);
458                }
459            }
460
461            // Check for hex encoded binary data
462            if let Some(data_str) = obj.get("hex_data").and_then(|v| v.as_str()) {
463                if let Ok(decoded) = hex_decode_tensor_data(data_str) {
464                    return Some(decoded);
465                }
466            }
467
468            // For PyTorch state_dict format, extract actual tensor values
469            if obj.contains_key("requires_grad") || obj.contains_key("grad_fn") {
470                // This is likely a PyTorch tensor object
471                if let Some(Value::Array(shape)) = obj.get("shape") {
472                    if let Some(flattened) = extract_flattened_tensor_values(obj, shape) {
473                        return Some(flattened);
474                    }
475                }
476            }
477
478            None
479        }
480
481        // Single numerical value
482        Value::Number(num) => {
483            if let Some(f) = num.as_f64() {
484                Some(vec![f])
485            } else {
486                None
487            }
488        }
489
490        _ => None,
491    }
492}
493
494// Recursively extract numbers from nested arrays (handles multi-dimensional tensors)
495fn extract_numbers_from_nested_array(arr: &[Value], result: &mut Vec<f64>) {
496    for item in arr {
497        match item {
498            Value::Number(num) => {
499                if let Some(f) = num.as_f64() {
500                    result.push(f);
501                }
502            }
503            Value::Array(nested_arr) => {
504                extract_numbers_from_nested_array(nested_arr, result);
505            }
506            _ => {}
507        }
508    }
509}
510
// Decode base64 encoded tensor data (common in Safetensors format).
//
// Stub: always returns Err until a base64 decoder and binary tensor layout
// parser are wired in; callers treat the Err as "no data available".
fn base64_decode_tensor_data(_data_str: &str) -> Result<Vec<f64>, Box<dyn std::error::Error>> {
    // This would typically use a base64 decoder and binary format parser
    // For now, return error to indicate unsupported format
    Err("Base64 tensor decoding not yet implemented".into())
}
517
// Decode hex encoded tensor data.
//
// Stub: always returns Err until hex parsing and float conversion are
// implemented; callers treat the Err as "no data available".
fn hex_decode_tensor_data(_data_str: &str) -> Result<Vec<f64>, Box<dyn std::error::Error>> {
    // This would typically parse hex string and convert to float values
    Err("Hex tensor decoding not yet implemented".into())
}
523
524// Extract flattened tensor values from PyTorch tensor object
525fn extract_flattened_tensor_values(
526    obj: &serde_json::Map<String, Value>,
527    shape: &[Value],
528) -> Option<Vec<f64>> {
529    // Calculate total elements from shape
530    let total_elements: usize = shape
531        .iter()
532        .filter_map(|v| v.as_u64())
533        .map(|n| n as usize)
534        .product();
535
536    if total_elements == 0 {
537        return None;
538    }
539
540    // Look for various ways tensor data might be stored
541    let storage_fields = ["_storage", "storage", "_data"];
542    for field in &storage_fields {
543        if let Some(storage_value) = obj.get(*field) {
544            if let Some(data) = extract_tensor_data(storage_value) {
545                // Limit to expected number of elements
546                let limited_data: Vec<f64> = data.into_iter().take(total_elements).collect();
547                if !limited_data.is_empty() {
548                    return Some(limited_data);
549                }
550            }
551        }
552    }
553
554    None
555}
556
557pub fn extract_tensor_shape(tensor: &Value) -> Option<Vec<usize>> {
558    // Extract shape information from tensor metadata
559    tensor.get("shape").and_then(|s| s.as_array()).map(|arr| {
560        arr.iter()
561            .filter_map(|v| v.as_u64().map(|n| n as usize))
562            .collect()
563    })
564}
565
566fn extract_tensor_dtype(tensor: &Value) -> Option<String> {
567    // Extract data type from tensor metadata
568    tensor
569        .get("dtype")
570        .and_then(|dt| dt.as_str())
571        .map(|s| s.to_string())
572}
573
574fn stats_changed_significantly(old_stats: &TensorStats, new_stats: &TensorStats) -> bool {
575    let mean_change = (old_stats.mean - new_stats.mean).abs() / old_stats.mean.abs().max(1e-8);
576    let std_change = (old_stats.std - new_stats.std).abs() / old_stats.std.abs().max(1e-8);
577
578    // Consider significant if relative change > 1%
579    mean_change > 0.01 || std_change > 0.01
580}
581
582// Analyze nested tensor containers (NumPy arrays, MATLAB variables, etc.)
583fn analyze_nested_tensor_containers(
584    old_obj: &serde_json::Map<String, Value>,
585    new_obj: &serde_json::Map<String, Value>,
586    results: &mut Vec<DiffResult>,
587) {
588    // Common container keys for different formats
589    let container_keys = [
590        "arrays",
591        "variables",
592        "tensors",
593        "model_state_dict",
594        "state_dict",
595        "layer_data",
596        "layers",
597        "weights",
598        "parameters",
599    ];
600
601    for container_key in &container_keys {
602        if let (Some(Value::Object(old_container)), Some(Value::Object(new_container))) =
603            (old_obj.get(*container_key), new_obj.get(*container_key))
604        {
605            // Analyze each array/tensor in the container
606            for (name, old_item) in old_container {
607                if let Some(new_item) = new_container.get(name) {
608                    let path = format!("{container_key}.{name}");
609                    analyze_tensor_metadata_changes(&path, old_item, new_item, results);
610                }
611            }
612        }
613    }
614}
615
// Analyze tensor metadata changes (shape, dtype, statistics) between two
// tensor entries and append any detected differences to `results`.
//
// Four independent probes run, each only when both sides carry the relevant
// fields: explicit "shape" arrays, explicit "mean" values, raw "data" arrays
// (stats computed on the fly), and a precomputed "statistics" object.
fn analyze_tensor_metadata_changes(
    path: &str,
    old_item: &Value,
    new_item: &Value,
    results: &mut Vec<DiffResult>,
) {
    if let (Value::Object(old_obj), Value::Object(new_obj)) = (old_item, new_item) {
        // Check for shape changes
        if let (Some(old_shape), Some(new_shape)) = (old_obj.get("shape"), new_obj.get("shape")) {
            let old_shape_vec = extract_shape_from_value(old_shape);
            let new_shape_vec = extract_shape_from_value(new_shape);

            if old_shape_vec != new_shape_vec {
                results.push(DiffResult::TensorShapeChanged(
                    path.to_string(),
                    old_shape_vec,
                    new_shape_vec,
                ));
            }
        }

        // Check for direct mean changes (TensorDataChanged)
        if let (Some(old_mean), Some(new_mean)) = (
            old_obj.get("mean").and_then(|v| v.as_f64()),
            new_obj.get("mean").and_then(|v| v.as_f64()),
        ) {
            // 1e-10 tolerance absorbs floating-point noise between equal means
            if (old_mean - new_mean).abs() > 1e-10 {
                results.push(DiffResult::TensorDataChanged(
                    path.to_string(),
                    old_mean,
                    new_mean,
                ));
            }
        }

        // Check for data array changes (compute stats from raw data)
        if let (Some(Value::Array(old_data)), Some(Value::Array(new_data))) =
            (old_obj.get("data"), new_obj.get("data"))
        {
            // Non-numeric entries are dropped before computing statistics
            let old_vals: Vec<f64> = old_data.iter().filter_map(|v| v.as_f64()).collect();
            let new_vals: Vec<f64> = new_data.iter().filter_map(|v| v.as_f64()).collect();

            if !old_vals.is_empty() && !new_vals.is_empty() {
                let old_shape = old_obj
                    .get("shape")
                    .map(extract_shape_from_value)
                    .unwrap_or_default();
                let new_shape = new_obj
                    .get("shape")
                    .map(extract_shape_from_value)
                    .unwrap_or_default();
                // dtype is read from the old side only; defaults to "float32"
                let dtype = old_obj
                    .get("dtype")
                    .and_then(|d| d.as_str())
                    .unwrap_or("float32")
                    .to_string();

                let old_stats = TensorStats::new(&old_vals, old_shape.clone(), dtype.clone());
                let new_stats = TensorStats::new(&new_vals, new_shape, dtype);

                if stats_changed_significantly(&old_stats, &new_stats) {
                    results.push(DiffResult::TensorStatsChanged(
                        path.to_string(),
                        old_stats,
                        new_stats,
                    ));
                }
            }
        }

        // Check for statistics changes
        if let (Some(Value::Object(old_stats)), Some(Value::Object(new_stats))) =
            (old_obj.get("statistics"), new_obj.get("statistics"))
        {
            let old_tensor_stats = extract_stats_from_object(old_stats, old_obj);
            let new_tensor_stats = extract_stats_from_object(new_stats, new_obj);

            if stats_changed_significantly(&old_tensor_stats, &new_tensor_stats) {
                results.push(DiffResult::TensorStatsChanged(
                    path.to_string(),
                    old_tensor_stats,
                    new_tensor_stats,
                ));
            }
        }

        // Note: data_summary changes are already captured by JSON diff as Modified results,
        // so we don't generate TensorStatsChanged here to avoid duplication.
    }
}
707
708fn extract_shape_from_value(shape: &Value) -> Vec<usize> {
709    match shape {
710        Value::Array(arr) => arr
711            .iter()
712            .filter_map(|v| v.as_u64().map(|n| n as usize))
713            .collect(),
714        _ => vec![],
715    }
716}
717
718fn extract_stats_from_object(
719    stats_obj: &serde_json::Map<String, Value>,
720    parent_obj: &serde_json::Map<String, Value>,
721) -> TensorStats {
722    let mean = stats_obj
723        .get("mean")
724        .and_then(|v| v.as_f64())
725        .unwrap_or(0.0);
726    let std = stats_obj.get("std").and_then(|v| v.as_f64()).unwrap_or(0.0);
727    let min = stats_obj.get("min").and_then(|v| v.as_f64()).unwrap_or(0.0);
728    let max = stats_obj.get("max").and_then(|v| v.as_f64()).unwrap_or(0.0);
729
730    let shape = parent_obj
731        .get("shape")
732        .map(extract_shape_from_value)
733        .unwrap_or_default();
734
735    let dtype = parent_obj
736        .get("dtype")
737        .and_then(|d| d.as_str())
738        .unwrap_or("unknown")
739        .to_string();
740
741    let element_count = shape.iter().product();
742
743    TensorStats {
744        mean,
745        std,
746        min,
747        max,
748        shape,
749        dtype,
750        element_count,
751    }
752}