diffai_core/parsers/safetensors.rs

1use anyhow::Result;
2use safetensors::{Dtype, SafeTensors};
3use serde_json::{json, Value};
4use std::path::Path;
5
6/// Parse SafeTensors model file - FOR INTERNAL USE ONLY (diffai-specific)
7pub fn parse_safetensors_model(file_path: &Path) -> Result<Value> {
8    let buffer = std::fs::read(file_path)?;
9    let safetensors = SafeTensors::deserialize(&buffer)?;
10
11    let mut result = serde_json::Map::new();
12    let mut tensors = serde_json::Map::new();
13
14    for tensor_name in safetensors.names() {
15        let tensor_view = safetensors.tensor(tensor_name)?;
16        let mut tensor_info = serde_json::Map::new();
17
18        tensor_info.insert(
19            "shape".to_string(),
20            Value::Array(
21                tensor_view
22                    .shape()
23                    .iter()
24                    .map(|&s| Value::Number(s.into()))
25                    .collect(),
26            ),
27        );
28        tensor_info.insert(
29            "dtype".to_string(),
30            Value::String(format!("{:?}", tensor_view.dtype())),
31        );
32
33        // Calculate tensor statistics from raw data
34        if let Some(stats) = compute_tensor_stats(&tensor_view) {
35            tensor_info.insert("data_summary".to_string(), stats);
36        }
37
38        tensors.insert(tensor_name.to_string(), Value::Object(tensor_info));
39    }
40
41    result.insert(
42        "model_type".to_string(),
43        Value::String("safetensors".to_string()),
44    );
45    result.insert("tensors".to_string(), Value::Object(tensors));
46
47    Ok(Value::Object(result))
48}
49
50fn compute_tensor_stats(tensor: &safetensors::tensor::TensorView) -> Option<Value> {
51    let data = tensor.data();
52    let dtype = tensor.dtype();
53
54    // Convert raw bytes to f64 values based on dtype
55    let values: Vec<f64> = match dtype {
56        Dtype::F32 => {
57            let floats: &[f32] = bytemuck::cast_slice(data);
58            floats.iter().map(|&x| x as f64).collect()
59        }
60        Dtype::F64 => {
61            let floats: &[f64] = bytemuck::cast_slice(data);
62            floats.to_vec()
63        }
64        Dtype::F16 => {
65            // F16 needs special handling - use half crate or convert manually
66            // For simplicity, skip F16 stats for now
67            return None;
68        }
69        Dtype::BF16 => {
70            // BF16 needs special handling
71            return None;
72        }
73        Dtype::I32 => {
74            let ints: &[i32] = bytemuck::cast_slice(data);
75            ints.iter().map(|&x| x as f64).collect()
76        }
77        Dtype::I64 => {
78            let ints: &[i64] = bytemuck::cast_slice(data);
79            ints.iter().map(|&x| x as f64).collect()
80        }
81        Dtype::I16 => {
82            let ints: &[i16] = bytemuck::cast_slice(data);
83            ints.iter().map(|&x| x as f64).collect()
84        }
85        Dtype::I8 => data.iter().map(|&x| x as i8 as f64).collect(),
86        Dtype::U8 => data.iter().map(|&x| x as f64).collect(),
87        Dtype::U16 => {
88            let ints: &[u16] = bytemuck::cast_slice(data);
89            ints.iter().map(|&x| x as f64).collect()
90        }
91        Dtype::U32 => {
92            let ints: &[u32] = bytemuck::cast_slice(data);
93            ints.iter().map(|&x| x as f64).collect()
94        }
95        Dtype::U64 => {
96            let ints: &[u64] = bytemuck::cast_slice(data);
97            ints.iter().map(|&x| x as f64).collect()
98        }
99        _ => return None,
100    };
101
102    if values.is_empty() {
103        return None;
104    }
105
106    let n = values.len() as f64;
107    let mean = values.iter().sum::<f64>() / n;
108    let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
109    let std = variance.sqrt();
110    let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
111    let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
112
113    Some(json!({
114        "mean": mean,
115        "std": std,
116        "min": min,
117        "max": max
118    }))
119}