ipfrs_interface/
safetensors.rs1use bytes::Bytes;
10use ipfrs_core::error::{Error, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14#[derive(Debug)]
16pub struct SafetensorsFile {
17 header: SafetensorsHeader,
19 data: Bytes,
21 header_size: usize,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct SafetensorsHeader {
28 #[serde(flatten)]
30 pub tensors: HashMap<String, TensorInfo>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct TensorInfo {
36 pub dtype: String,
38 pub shape: Vec<usize>,
40 pub data_offsets: [usize; 2], }
43
44impl SafetensorsFile {
45 pub fn from_bytes(data: Bytes) -> Result<Self> {
47 if data.len() < 8 {
48 return Err(Error::InvalidInput(
49 "Data too short for safetensors format".to_string(),
50 ));
51 }
52
53 let header_len = u64::from_le_bytes(data[0..8].try_into().unwrap()) as usize;
55
56 if data.len() < 8 + header_len {
57 return Err(Error::InvalidInput(
58 "Incomplete safetensors header".to_string(),
59 ));
60 }
61
62 let header_bytes = &data[8..8 + header_len];
64 let header: SafetensorsHeader = serde_json::from_slice(header_bytes).map_err(|e| {
65 Error::InvalidInput(format!("Failed to parse safetensors header: {}", e))
66 })?;
67
68 Self::validate_header(&header, data.len() - 8 - header_len)?;
70
71 Ok(SafetensorsFile {
72 header,
73 data,
74 header_size: 8 + header_len,
75 })
76 }
77
78 fn validate_header(header: &SafetensorsHeader, data_section_size: usize) -> Result<()> {
80 for (name, info) in &header.tensors {
81 let [start, end] = info.data_offsets;
82
83 if start >= end {
84 return Err(Error::InvalidInput(format!(
85 "Invalid offsets for tensor '{}': start={}, end={}",
86 name, start, end
87 )));
88 }
89
90 if end > data_section_size {
91 return Err(Error::InvalidInput(format!(
92 "Tensor '{}' offset {} exceeds data section size {}",
93 name, end, data_section_size
94 )));
95 }
96
97 let expected_size = Self::calculate_tensor_size(&info.shape, &info.dtype);
99 let actual_size = end - start;
100
101 if actual_size != expected_size {
102 return Err(Error::InvalidInput(format!(
103 "Tensor '{}' size mismatch: expected {}, got {}",
104 name, expected_size, actual_size
105 )));
106 }
107 }
108
109 Ok(())
110 }
111
112 fn calculate_tensor_size(shape: &[usize], dtype: &str) -> usize {
114 let num_elements: usize = shape.iter().product();
115 let element_size = Self::dtype_size(dtype);
116 num_elements * element_size
117 }
118
119 fn dtype_size(dtype: &str) -> usize {
121 match dtype {
122 "F16" | "BF16" => 2,
123 "F32" | "I32" | "U32" => 4,
124 "F64" | "I64" | "U64" => 8,
125 "I8" | "U8" => 1,
126 "I16" | "U16" => 2,
127 "BOOL" => 1,
128 _ => 4, }
130 }
131
132 pub fn get_tensor(&self, name: &str) -> Result<TensorData> {
134 let info = self.header.tensors.get(name).ok_or_else(|| {
135 Error::NotFound(format!("Tensor '{}' not found in safetensors file", name))
136 })?;
137
138 let [start, end] = info.data_offsets;
139 let data_start = self.header_size + start;
140 let data_end = self.header_size + end;
141
142 if data_end > self.data.len() {
143 return Err(Error::InvalidInput(format!(
144 "Tensor data range {}..{} exceeds file size {}",
145 data_start,
146 data_end,
147 self.data.len()
148 )));
149 }
150
151 Ok(TensorData {
152 dtype: info.dtype.clone(),
153 shape: info.shape.clone(),
154 data: self.data.slice(data_start..data_end),
155 })
156 }
157
158 pub fn tensor_names(&self) -> Vec<String> {
160 self.header
161 .tensors
162 .keys()
163 .filter(|k| k.as_str() != "__metadata__")
164 .cloned()
165 .collect()
166 }
167
168 pub fn get_tensor_info(&self, name: &str) -> Option<&TensorInfo> {
170 self.header.tensors.get(name)
171 }
172
173 pub fn header(&self) -> &SafetensorsHeader {
175 &self.header
176 }
177
178 pub fn raw_data(&self) -> &Bytes {
180 &self.data
181 }
182}
183
184#[derive(Debug, Clone)]
186pub struct TensorData {
187 pub dtype: String,
189 pub shape: Vec<usize>,
191 pub data: Bytes,
193}
194
195impl TensorData {
196 pub fn num_elements(&self) -> usize {
198 self.shape.iter().product()
199 }
200
201 pub fn size_bytes(&self) -> usize {
203 self.data.len()
204 }
205
206 pub fn element_size(&self) -> usize {
208 if self.num_elements() == 0 {
209 return 0;
210 }
211 self.size_bytes() / self.num_elements()
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the byte width reported for a selection of dtype tags.
    #[test]
    fn test_dtype_size() {
        let expected = [
            ("F32", 4),
            ("F64", 8),
            ("F16", 2),
            ("I32", 4),
            ("U8", 1),
            ("BOOL", 1),
        ];

        for &(tag, width) in &expected {
            assert_eq!(SafetensorsFile::dtype_size(tag), width);
        }
    }

    /// The byte size is the element count multiplied by the element width.
    #[test]
    fn test_calculate_tensor_size() {
        let cases: [(&[usize], &str, usize); 2] = [
            (&[10, 20], "F32", 10 * 20 * 4),
            (&[5, 5, 5], "F64", 5 * 5 * 5 * 8),
        ];

        for &(shape, dtype, bytes) in &cases {
            assert_eq!(SafetensorsFile::calculate_tensor_size(shape, dtype), bytes);
        }
    }

    /// A 2x3 tensor backed by 24 bytes has 6 elements of 4 bytes each.
    #[test]
    fn test_tensor_data_num_elements() {
        let tensor = TensorData {
            dtype: String::from("F32"),
            shape: vec![2, 3],
            data: Bytes::from(vec![0u8; 24]),
        };

        assert_eq!(tensor.num_elements(), 6);
        assert_eq!(tensor.size_bytes(), 24);
        assert_eq!(tensor.element_size(), 4);
    }
}