ipfrs_interface/
safetensors.rs

//! Safetensors integration for tensor serialization.
//!
//! Provides utilities for working with the safetensors format:
//! - Parsing and validating safetensors files
//! - Extracting tensor metadata
//! - Reading tensor data
//! - Creating safetensors from raw tensors
8
9use bytes::Bytes;
10use ipfrs_core::error::{Error, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14/// Safetensors file format handler
15#[derive(Debug)]
16pub struct SafetensorsFile {
17    /// Parsed header with tensor metadata
18    header: SafetensorsHeader,
19    /// Raw data bytes (header + tensors)
20    data: Bytes,
21    /// Header size in bytes
22    header_size: usize,
23}
24
25/// Safetensors header structure
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct SafetensorsHeader {
28    /// Tensor metadata indexed by name
29    #[serde(flatten)]
30    pub tensors: HashMap<String, TensorInfo>,
31}
32
33/// Information about a single tensor
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct TensorInfo {
36    /// Data type (e.g., "F32", "F64", "I32")
37    pub dtype: String,
38    /// Tensor shape (dimensions)
39    pub shape: Vec<usize>,
40    /// Start offset in the data section
41    pub data_offsets: [usize; 2], // [start, end]
42}
43
44impl SafetensorsFile {
45    /// Parse a safetensors file from bytes
46    pub fn from_bytes(data: Bytes) -> Result<Self> {
47        if data.len() < 8 {
48            return Err(Error::InvalidInput(
49                "Data too short for safetensors format".to_string(),
50            ));
51        }
52
53        // First 8 bytes = header length (little-endian u64)
54        let header_len = u64::from_le_bytes(data[0..8].try_into().unwrap()) as usize;
55
56        if data.len() < 8 + header_len {
57            return Err(Error::InvalidInput(
58                "Incomplete safetensors header".to_string(),
59            ));
60        }
61
62        // Parse JSON header
63        let header_bytes = &data[8..8 + header_len];
64        let header: SafetensorsHeader = serde_json::from_slice(header_bytes).map_err(|e| {
65            Error::InvalidInput(format!("Failed to parse safetensors header: {}", e))
66        })?;
67
68        // Validate header
69        Self::validate_header(&header, data.len() - 8 - header_len)?;
70
71        Ok(SafetensorsFile {
72            header,
73            data,
74            header_size: 8 + header_len,
75        })
76    }
77
78    /// Validate header offsets and data integrity
79    fn validate_header(header: &SafetensorsHeader, data_section_size: usize) -> Result<()> {
80        for (name, info) in &header.tensors {
81            let [start, end] = info.data_offsets;
82
83            if start >= end {
84                return Err(Error::InvalidInput(format!(
85                    "Invalid offsets for tensor '{}': start={}, end={}",
86                    name, start, end
87                )));
88            }
89
90            if end > data_section_size {
91                return Err(Error::InvalidInput(format!(
92                    "Tensor '{}' offset {} exceeds data section size {}",
93                    name, end, data_section_size
94                )));
95            }
96
97            // Validate size matches shape and dtype
98            let expected_size = Self::calculate_tensor_size(&info.shape, &info.dtype);
99            let actual_size = end - start;
100
101            if actual_size != expected_size {
102                return Err(Error::InvalidInput(format!(
103                    "Tensor '{}' size mismatch: expected {}, got {}",
104                    name, expected_size, actual_size
105                )));
106            }
107        }
108
109        Ok(())
110    }
111
112    /// Calculate expected tensor size in bytes
113    fn calculate_tensor_size(shape: &[usize], dtype: &str) -> usize {
114        let num_elements: usize = shape.iter().product();
115        let element_size = Self::dtype_size(dtype);
116        num_elements * element_size
117    }
118
119    /// Get size of a data type in bytes
120    fn dtype_size(dtype: &str) -> usize {
121        match dtype {
122            "F16" | "BF16" => 2,
123            "F32" | "I32" | "U32" => 4,
124            "F64" | "I64" | "U64" => 8,
125            "I8" | "U8" => 1,
126            "I16" | "U16" => 2,
127            "BOOL" => 1,
128            _ => 4, // Default to 4 bytes
129        }
130    }
131
132    /// Get tensor data by name
133    pub fn get_tensor(&self, name: &str) -> Result<TensorData> {
134        let info = self.header.tensors.get(name).ok_or_else(|| {
135            Error::NotFound(format!("Tensor '{}' not found in safetensors file", name))
136        })?;
137
138        let [start, end] = info.data_offsets;
139        let data_start = self.header_size + start;
140        let data_end = self.header_size + end;
141
142        if data_end > self.data.len() {
143            return Err(Error::InvalidInput(format!(
144                "Tensor data range {}..{} exceeds file size {}",
145                data_start,
146                data_end,
147                self.data.len()
148            )));
149        }
150
151        Ok(TensorData {
152            dtype: info.dtype.clone(),
153            shape: info.shape.clone(),
154            data: self.data.slice(data_start..data_end),
155        })
156    }
157
158    /// Get all tensor names
159    pub fn tensor_names(&self) -> Vec<String> {
160        self.header
161            .tensors
162            .keys()
163            .filter(|k| k.as_str() != "__metadata__")
164            .cloned()
165            .collect()
166    }
167
168    /// Get tensor metadata by name
169    pub fn get_tensor_info(&self, name: &str) -> Option<&TensorInfo> {
170        self.header.tensors.get(name)
171    }
172
173    /// Get the full header
174    pub fn header(&self) -> &SafetensorsHeader {
175        &self.header
176    }
177
178    /// Get raw file data
179    pub fn raw_data(&self) -> &Bytes {
180        &self.data
181    }
182}
183
184/// Tensor data extracted from safetensors
185#[derive(Debug, Clone)]
186pub struct TensorData {
187    /// Data type
188    pub dtype: String,
189    /// Shape (dimensions)
190    pub shape: Vec<usize>,
191    /// Raw tensor data
192    pub data: Bytes,
193}
194
195impl TensorData {
196    /// Get the number of elements in the tensor
197    pub fn num_elements(&self) -> usize {
198        self.shape.iter().product()
199    }
200
201    /// Get the size in bytes
202    pub fn size_bytes(&self) -> usize {
203        self.data.len()
204    }
205
206    /// Get element size in bytes
207    pub fn element_size(&self) -> usize {
208        if self.num_elements() == 0 {
209            return 0;
210        }
211        self.size_bytes() / self.num_elements()
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Known dtype tags map to their element width in bytes.
    #[test]
    fn test_dtype_size() {
        let cases = [
            ("F32", 4),
            ("F64", 8),
            ("F16", 2),
            ("I32", 4),
            ("U8", 1),
            ("BOOL", 1),
        ];

        for &(dtype, width) in cases.iter() {
            assert_eq!(SafetensorsFile::dtype_size(dtype), width);
        }
    }

    /// Tensor size is the element count multiplied by the element width.
    #[test]
    fn test_calculate_tensor_size() {
        let matrix_bytes = SafetensorsFile::calculate_tensor_size(&[10, 20], "F32");
        assert_eq!(matrix_bytes, 10 * 20 * 4);

        let cube_bytes = SafetensorsFile::calculate_tensor_size(&[5, 5, 5], "F64");
        assert_eq!(cube_bytes, 5 * 5 * 5 * 8);
    }

    /// The accessors of `TensorData` agree with its shape and buffer length.
    #[test]
    fn test_tensor_data_num_elements() {
        // 2 * 3 elements of 4 bytes each = 24 bytes
        let tensor = TensorData {
            dtype: "F32".to_string(),
            shape: vec![2, 3],
            data: Bytes::from(vec![0u8; 24]),
        };

        assert_eq!(tensor.num_elements(), 6);
        assert_eq!(tensor.size_bytes(), 24);
        assert_eq!(tensor.element_size(), 4);
    }
}