diffai_core/parsers/
safetensors.rs1use anyhow::Result;
2use safetensors::{Dtype, SafeTensors};
3use serde_json::{json, Value};
4use std::path::Path;
5
6pub fn parse_safetensors_model(file_path: &Path) -> Result<Value> {
8 let buffer = std::fs::read(file_path)?;
9 let safetensors = SafeTensors::deserialize(&buffer)?;
10
11 let mut result = serde_json::Map::new();
12 let mut tensors = serde_json::Map::new();
13
14 for tensor_name in safetensors.names() {
15 let tensor_view = safetensors.tensor(tensor_name)?;
16 let mut tensor_info = serde_json::Map::new();
17
18 tensor_info.insert(
19 "shape".to_string(),
20 Value::Array(
21 tensor_view
22 .shape()
23 .iter()
24 .map(|&s| Value::Number(s.into()))
25 .collect(),
26 ),
27 );
28 tensor_info.insert(
29 "dtype".to_string(),
30 Value::String(format!("{:?}", tensor_view.dtype())),
31 );
32
33 if let Some(stats) = compute_tensor_stats(&tensor_view) {
35 tensor_info.insert("data_summary".to_string(), stats);
36 }
37
38 tensors.insert(tensor_name.to_string(), Value::Object(tensor_info));
39 }
40
41 result.insert(
42 "model_type".to_string(),
43 Value::String("safetensors".to_string()),
44 );
45 result.insert("tensors".to_string(), Value::Object(tensors));
46
47 Ok(Value::Object(result))
48}
49
50fn compute_tensor_stats(tensor: &safetensors::tensor::TensorView) -> Option<Value> {
51 let data = tensor.data();
52 let dtype = tensor.dtype();
53
54 let values: Vec<f64> = match dtype {
56 Dtype::F32 => {
57 let floats: &[f32] = bytemuck::cast_slice(data);
58 floats.iter().map(|&x| x as f64).collect()
59 }
60 Dtype::F64 => {
61 let floats: &[f64] = bytemuck::cast_slice(data);
62 floats.to_vec()
63 }
64 Dtype::F16 => {
65 return None;
68 }
69 Dtype::BF16 => {
70 return None;
72 }
73 Dtype::I32 => {
74 let ints: &[i32] = bytemuck::cast_slice(data);
75 ints.iter().map(|&x| x as f64).collect()
76 }
77 Dtype::I64 => {
78 let ints: &[i64] = bytemuck::cast_slice(data);
79 ints.iter().map(|&x| x as f64).collect()
80 }
81 Dtype::I16 => {
82 let ints: &[i16] = bytemuck::cast_slice(data);
83 ints.iter().map(|&x| x as f64).collect()
84 }
85 Dtype::I8 => data.iter().map(|&x| x as i8 as f64).collect(),
86 Dtype::U8 => data.iter().map(|&x| x as f64).collect(),
87 Dtype::U16 => {
88 let ints: &[u16] = bytemuck::cast_slice(data);
89 ints.iter().map(|&x| x as f64).collect()
90 }
91 Dtype::U32 => {
92 let ints: &[u32] = bytemuck::cast_slice(data);
93 ints.iter().map(|&x| x as f64).collect()
94 }
95 Dtype::U64 => {
96 let ints: &[u64] = bytemuck::cast_slice(data);
97 ints.iter().map(|&x| x as f64).collect()
98 }
99 _ => return None,
100 };
101
102 if values.is_empty() {
103 return None;
104 }
105
106 let n = values.len() as f64;
107 let mean = values.iter().sum::<f64>() / n;
108 let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
109 let std = variance.sqrt();
110 let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
111 let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
112
113 Some(json!({
114 "mean": mean,
115 "std": std,
116 "min": min,
117 "max": max
118 }))
119}