ipfrs_storage/
safetensors.rs

1//! Safetensors format support for efficient model storage
2//!
3//! Provides native support for the Safetensors format:
4//! - Parse .safetensors files
5//! - Extract metadata and tensor information
6//! - Store tensors as content-addressed blocks
7//! - Chunked storage for large models (70B+ parameters)
8//! - Lazy loading of model weights
9//!
10//! # Example
11//!
12//! ```rust,ignore
13//! use ipfrs_storage::{SafetensorsStore, SledBlockStore, BlockStoreConfig};
14//! use std::sync::Arc;
15//! use std::path::PathBuf;
16//!
17//! # async fn example() -> ipfrs_core::Result<()> {
18//! // Create block store
19//! let store = Arc::new(SledBlockStore::new(BlockStoreConfig {
20//!     path: PathBuf::from(".ipfrs/models"),
21//!     cache_size: 1024 * 1024 * 1024, // 1GB cache
22//! })?);
23//!
24//! // Create safetensors store
25//! let safetensors_store = SafetensorsStore::new(store);
26//!
27//! // Load and store a safetensors file
28//! let model_cid = safetensors_store.import_file("model.safetensors").await?;
29//!
30//! // Lazy load a specific tensor
31//! let tensor_data = safetensors_store.load_tensor(&model_cid, "layer.0.weight").await?;
32//! # Ok(())
33//! # }
34//! ```
35
36use crate::traits::BlockStore;
37use bytes::Bytes;
38use ipfrs_core::{Block, Cid, Error, Result};
39use serde::{Deserialize, Serialize};
40use std::collections::HashMap;
41use std::str::FromStr;
42use std::sync::Arc;
43
/// Tensor data type, mirroring the dtype strings used in safetensors headers.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum DType {
    /// 32-bit float ("F32")
    F32,
    /// 64-bit float ("F64")
    F64,
    /// 16-bit IEEE half-precision float ("F16")
    F16,
    /// 16-bit bfloat16 ("BF16")
    BF16,
    /// Signed 8-bit integer ("I8")
    I8,
    /// Signed 16-bit integer ("I16")
    I16,
    /// Signed 32-bit integer ("I32")
    I32,
    /// Signed 64-bit integer ("I64")
    I64,
    /// Unsigned 8-bit integer ("U8")
    U8,
    /// Unsigned 16-bit integer ("U16")
    U16,
    /// Unsigned 32-bit integer ("U32")
    U32,
    /// Unsigned 64-bit integer ("U64")
    U64,
    /// Boolean, stored as one byte per element ("BOOL")
    Bool,
}
61
62impl DType {
63    /// Get size in bytes for this dtype
64    pub fn size(&self) -> usize {
65        match self {
66            DType::F32 | DType::I32 | DType::U32 => 4,
67            DType::F64 | DType::I64 | DType::U64 => 8,
68            DType::F16 | DType::BF16 | DType::I16 | DType::U16 => 2,
69            DType::I8 | DType::U8 | DType::Bool => 1,
70        }
71    }
72}
73
74impl FromStr for DType {
75    type Err = String;
76
77    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
78        match s {
79            "F32" => Ok(DType::F32),
80            "F64" => Ok(DType::F64),
81            "F16" => Ok(DType::F16),
82            "BF16" => Ok(DType::BF16),
83            "I8" => Ok(DType::I8),
84            "I16" => Ok(DType::I16),
85            "I32" => Ok(DType::I32),
86            "I64" => Ok(DType::I64),
87            "U8" => Ok(DType::U8),
88            "U16" => Ok(DType::U16),
89            "U32" => Ok(DType::U32),
90            "U64" => Ok(DType::U64),
91            "BOOL" => Ok(DType::Bool),
92            _ => Err(format!("Unknown dtype: {s}")),
93        }
94    }
95}
96
/// Tensor metadata from safetensors header
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TensorInfo {
    /// Data type of the tensor
    pub dtype: DType,
    /// Shape of the tensor
    pub shape: Vec<usize>,
    /// Byte range `(start, end)` of this tensor within the data section
    /// (end-exclusive, relative to the end of the header)
    pub data_offsets: (usize, usize),
}
107
108impl TensorInfo {
109    /// Calculate total number of elements
110    pub fn numel(&self) -> usize {
111        self.shape.iter().product()
112    }
113
114    /// Calculate total size in bytes
115    pub fn size_bytes(&self) -> usize {
116        self.numel() * self.dtype.size()
117    }
118}
119
/// Safetensors file header: the parsed form of the JSON block that
/// prefixes the binary data section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetensorsHeader {
    /// Tensor metadata by name
    pub tensors: HashMap<String, TensorInfo>,
    /// Additional metadata (string key/value pairs from the "__metadata__" entry)
    pub metadata: HashMap<String, String>,
}
128
/// Chunked tensor storage for large tensors: one tensor's data split into
/// content-addressed blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkedTensor {
    /// Tensor name
    pub name: String,
    /// Tensor metadata
    pub info: TensorInfo,
    /// CIDs of chunks (in order); concatenating the chunks reproduces the tensor bytes
    #[serde(
        serialize_with = "serialize_cid_vec",
        deserialize_with = "deserialize_cid_vec"
    )]
    pub chunk_cids: Vec<Cid>,
    /// Nominal size of each chunk in bytes (the final chunk may be shorter)
    pub chunk_size: usize,
}
145
146// Custom serialization for Vec<Cid>
147fn serialize_cid_vec<S>(cids: &[Cid], serializer: S) -> std::result::Result<S::Ok, S::Error>
148where
149    S: serde::Serializer,
150{
151    use serde::ser::SerializeSeq;
152    let mut seq = serializer.serialize_seq(Some(cids.len()))?;
153    for cid in cids {
154        seq.serialize_element(&cid.to_bytes())?;
155    }
156    seq.end()
157}
158
159fn deserialize_cid_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<Cid>, D::Error>
160where
161    D: serde::Deserializer<'de>,
162{
163    let bytes_vec: Vec<Vec<u8>> = Deserialize::deserialize(deserializer)?;
164    bytes_vec
165        .into_iter()
166        .map(|bytes| Cid::try_from(bytes).map_err(serde::de::Error::custom))
167        .collect()
168}
169
/// Safetensors model manifest: the root object stored per imported model,
/// linking the parsed header to the chunked tensor data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SafetensorsManifest {
    /// Model name
    pub name: String,
    /// Safetensors header (tensor metadata and file-level metadata)
    pub header: SafetensorsHeader,
    /// Chunked tensors, keyed by tensor name
    pub tensors: HashMap<String, ChunkedTensor>,
    /// Total model size in bytes (sum of all tensor payloads)
    pub total_size: u64,
}
182
/// Configuration for chunked storage
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Chunk size in bytes (default: 64MB)
    pub chunk_size: usize,
    /// Whether to compress chunks
    // NOTE(review): `compress` is not consulted anywhere in this file's
    // import path — confirm whether compression is implemented elsewhere.
    pub compress: bool,
}
191
192impl Default for ChunkConfig {
193    fn default() -> Self {
194        Self {
195            chunk_size: 64 * 1024 * 1024, // 64MB
196            compress: false,
197        }
198    }
199}
200
/// Safetensors store for managing model weights on top of any [`BlockStore`].
pub struct SafetensorsStore<S: BlockStore> {
    /// Underlying block store holding chunk and manifest blocks
    store: Arc<S>,
    /// Chunk configuration used when importing models
    chunk_config: ChunkConfig,
}
208
209impl<S: BlockStore> SafetensorsStore<S> {
210    /// Create a new safetensors store
211    pub fn new(store: Arc<S>) -> Self {
212        Self {
213            store,
214            chunk_config: ChunkConfig::default(),
215        }
216    }
217
218    /// Create with custom chunk configuration
219    pub fn with_config(store: Arc<S>, chunk_config: ChunkConfig) -> Self {
220        Self {
221            store,
222            chunk_config,
223        }
224    }
225
226    /// Parse safetensors header from bytes
227    pub fn parse_header(data: &[u8]) -> Result<(SafetensorsHeader, usize)> {
228        if data.len() < 8 {
229            return Err(Error::Storage(
230                "File too small to be safetensors".to_string(),
231            ));
232        }
233
234        // Read header size (8 bytes, little-endian u64)
235        let header_size = u64::from_le_bytes([
236            data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
237        ]) as usize;
238
239        if data.len() < 8 + header_size {
240            return Err(Error::Storage("Incomplete safetensors header".to_string()));
241        }
242
243        // Parse JSON header
244        let header_bytes = &data[8..8 + header_size];
245        let header_json: serde_json::Value = serde_json::from_slice(header_bytes)
246            .map_err(|e| Error::Serialization(format!("Failed to parse header JSON: {e}")))?;
247
248        let mut tensors = HashMap::new();
249        let mut metadata = HashMap::new();
250
251        // Parse tensors
252        if let Some(obj) = header_json.as_object() {
253            for (key, value) in obj {
254                if key == "__metadata__" {
255                    // Parse metadata
256                    if let Some(meta_obj) = value.as_object() {
257                        for (k, v) in meta_obj {
258                            if let Some(s) = v.as_str() {
259                                metadata.insert(k.clone(), s.to_string());
260                            }
261                        }
262                    }
263                } else {
264                    // Parse tensor info
265                    if let Some(tensor_obj) = value.as_object() {
266                        let dtype_str = tensor_obj
267                            .get("dtype")
268                            .and_then(|v| v.as_str())
269                            .ok_or_else(|| Error::Storage("Missing dtype".to_string()))?;
270
271                        let dtype = dtype_str.parse::<DType>().map_err(Error::Storage)?;
272
273                        let shape: Vec<usize> = tensor_obj
274                            .get("shape")
275                            .and_then(|v| v.as_array())
276                            .ok_or_else(|| Error::Storage("Missing shape".to_string()))?
277                            .iter()
278                            .filter_map(|v| v.as_u64().map(|n| n as usize))
279                            .collect();
280
281                        let data_offsets = tensor_obj
282                            .get("data_offsets")
283                            .and_then(|v| v.as_array())
284                            .ok_or_else(|| Error::Storage("Missing data_offsets".to_string()))?;
285
286                        let start = data_offsets[0].as_u64().ok_or_else(|| {
287                            Error::Storage("Invalid data_offsets start".to_string())
288                        })? as usize;
289                        let end = data_offsets[1]
290                            .as_u64()
291                            .ok_or_else(|| Error::Storage("Invalid data_offsets end".to_string()))?
292                            as usize;
293
294                        tensors.insert(
295                            key.clone(),
296                            TensorInfo {
297                                dtype,
298                                shape,
299                                data_offsets: (start, end),
300                            },
301                        );
302                    }
303                }
304            }
305        }
306
307        Ok((SafetensorsHeader { tensors, metadata }, 8 + header_size))
308    }
309
310    /// Import safetensors file and store as chunks
311    pub async fn import_from_bytes(&self, name: String, data: &[u8]) -> Result<Cid> {
312        // Parse header
313        let (header, data_offset) = Self::parse_header(data)?;
314
315        let data_section = &data[data_offset..];
316        let mut chunked_tensors = HashMap::new();
317        let mut total_size = 0u64;
318
319        // Process each tensor
320        for (tensor_name, tensor_info) in &header.tensors {
321            let (start, end) = tensor_info.data_offsets;
322            let tensor_data = &data_section[start..end];
323
324            // Chunk the tensor data
325            let mut chunk_cids = Vec::new();
326            for chunk in tensor_data.chunks(self.chunk_config.chunk_size) {
327                let block = Block::new(Bytes::from(chunk.to_vec()))?;
328                let cid = *block.cid();
329                self.store.put(&block).await?;
330                chunk_cids.push(cid);
331            }
332
333            chunked_tensors.insert(
334                tensor_name.clone(),
335                ChunkedTensor {
336                    name: tensor_name.clone(),
337                    info: tensor_info.clone(),
338                    chunk_cids,
339                    chunk_size: self.chunk_config.chunk_size,
340                },
341            );
342
343            total_size += tensor_data.len() as u64;
344        }
345
346        // Create manifest
347        let manifest = SafetensorsManifest {
348            name,
349            header,
350            tensors: chunked_tensors,
351            total_size,
352        };
353
354        // Store manifest
355        let manifest_bytes = oxicode::serde::encode_to_vec(&manifest, oxicode::config::standard())
356            .map_err(|e| Error::Serialization(format!("Failed to serialize manifest: {e}")))?;
357
358        let manifest_block = Block::new(Bytes::from(manifest_bytes))?;
359        let manifest_cid = *manifest_block.cid();
360        self.store.put(&manifest_block).await?;
361
362        Ok(manifest_cid)
363    }
364
365    /// Load safetensors manifest
366    pub async fn load_manifest(&self, manifest_cid: &Cid) -> Result<SafetensorsManifest> {
367        let block = self
368            .store
369            .get(manifest_cid)
370            .await?
371            .ok_or_else(|| Error::NotFound(format!("Manifest not found: {manifest_cid}")))?;
372
373        let manifest: SafetensorsManifest =
374            oxicode::serde::decode_owned_from_slice(block.data(), oxicode::config::standard())
375                .map(|(v, _)| v)
376                .map_err(|e| {
377                    Error::Serialization(format!("Failed to deserialize manifest: {e}"))
378                })?;
379
380        Ok(manifest)
381    }
382
383    /// Load a specific tensor (lazy loading)
384    pub async fn load_tensor(&self, manifest_cid: &Cid, tensor_name: &str) -> Result<Vec<u8>> {
385        let manifest = self.load_manifest(manifest_cid).await?;
386
387        let chunked_tensor = manifest
388            .tensors
389            .get(tensor_name)
390            .ok_or_else(|| Error::NotFound(format!("Tensor not found: {tensor_name}")))?;
391
392        // Load all chunks
393        let mut tensor_data = Vec::with_capacity(chunked_tensor.info.size_bytes());
394
395        for chunk_cid in &chunked_tensor.chunk_cids {
396            let chunk_block = self
397                .store
398                .get(chunk_cid)
399                .await?
400                .ok_or_else(|| Error::NotFound(format!("Chunk not found: {chunk_cid}")))?;
401
402            tensor_data.extend_from_slice(chunk_block.data());
403        }
404
405        Ok(tensor_data)
406    }
407
408    /// Load multiple tensors (batch loading for efficiency)
409    pub async fn load_tensors(
410        &self,
411        manifest_cid: &Cid,
412        tensor_names: &[&str],
413    ) -> Result<HashMap<String, Vec<u8>>> {
414        let _manifest = self.load_manifest(manifest_cid).await?;
415        let mut result = HashMap::new();
416
417        for &tensor_name in tensor_names {
418            let tensor_data = self.load_tensor(manifest_cid, tensor_name).await?;
419            result.insert(tensor_name.to_string(), tensor_data);
420        }
421
422        Ok(result)
423    }
424
425    /// Get tensor metadata without loading data
426    pub async fn get_tensor_info(
427        &self,
428        manifest_cid: &Cid,
429        tensor_name: &str,
430    ) -> Result<TensorInfo> {
431        let manifest = self.load_manifest(manifest_cid).await?;
432
433        manifest
434            .tensors
435            .get(tensor_name)
436            .map(|ct| ct.info.clone())
437            .ok_or_else(|| Error::NotFound(format!("Tensor not found: {tensor_name}")))
438    }
439
440    /// List all tensors in the model
441    pub async fn list_tensors(&self, manifest_cid: &Cid) -> Result<Vec<String>> {
442        let manifest = self.load_manifest(manifest_cid).await?;
443        Ok(manifest.tensors.keys().cloned().collect())
444    }
445
446    /// Get model statistics
447    pub async fn get_model_stats(&self, manifest_cid: &Cid) -> Result<ModelStats> {
448        let manifest = self.load_manifest(manifest_cid).await?;
449
450        let tensor_count = manifest.tensors.len();
451        let total_parameters: usize = manifest.tensors.values().map(|ct| ct.info.numel()).sum();
452
453        let chunk_count: usize = manifest
454            .tensors
455            .values()
456            .map(|ct| ct.chunk_cids.len())
457            .sum();
458
459        Ok(ModelStats {
460            name: manifest.name,
461            tensor_count,
462            total_parameters,
463            total_size_bytes: manifest.total_size,
464            chunk_count,
465            avg_chunk_size: if chunk_count > 0 {
466                manifest.total_size / chunk_count as u64
467            } else {
468                0
469            },
470        })
471    }
472}
473
/// Model statistics, aggregated from a [`SafetensorsManifest`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelStats {
    /// Model name
    pub name: String,
    /// Number of tensors
    pub tensor_count: usize,
    /// Total number of parameters (sum of element counts over all tensors)
    pub total_parameters: usize,
    /// Total size in bytes
    pub total_size_bytes: u64,
    /// Number of chunks across all tensors
    pub chunk_count: usize,
    /// Average chunk size (total size / chunk count; 0 when there are no chunks)
    pub avg_chunk_size: u64,
}
490
#[cfg(test)]
mod tests {
    use super::*;
    use crate::blockstore::{BlockStoreConfig, SledBlockStore};
    use std::path::PathBuf;

    // Element widths for one dtype of each size class (4, 8, 2, 1 bytes).
    #[test]
    fn test_dtype_size() {
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::F64.size(), 8);
        assert_eq!(DType::F16.size(), 2);
        assert_eq!(DType::I8.size(), 1);
    }

    // numel is the product of dimensions; size_bytes multiplies by dtype width.
    #[test]
    fn test_tensor_info_numel() {
        let info = TensorInfo {
            dtype: DType::F32,
            shape: vec![2, 3, 4],
            data_offsets: (0, 96),
        };

        assert_eq!(info.numel(), 24); // 2 * 3 * 4
        assert_eq!(info.size_bytes(), 96); // 24 elements * 4 bytes
    }

    // End-to-end: build a minimal in-memory safetensors file, import it,
    // reload the manifest, and check aggregate stats.
    #[tokio::test]
    async fn test_safetensors_store() {
        let config = BlockStoreConfig {
            path: PathBuf::from("/tmp/ipfrs-safetensors-test"),
            cache_size: 100 * 1024 * 1024,
        };
        // Start from a clean directory; ignore failure if it doesn't exist.
        let _ = std::fs::remove_dir_all(&config.path);

        let store = Arc::new(SledBlockStore::new(config).unwrap());
        let safetensors_store = SafetensorsStore::new(store);

        // Create a minimal safetensors file: 8-byte LE length prefix,
        // JSON header, then the raw tensor payload.
        let header = r#"{"tensor1":{"dtype":"F32","shape":[2,2],"data_offsets":[0,16]}}"#;
        let header_size = header.len() as u64;
        let mut data = Vec::new();
        data.extend_from_slice(&header_size.to_le_bytes());
        data.extend_from_slice(header.as_bytes());
        // Add tensor data (2x2 f32 = 16 bytes)
        data.extend_from_slice(&[0u8; 16]);

        let manifest_cid = safetensors_store
            .import_from_bytes("test_model".to_string(), &data)
            .await
            .unwrap();

        // Load manifest
        let manifest = safetensors_store
            .load_manifest(&manifest_cid)
            .await
            .unwrap();
        assert_eq!(manifest.name, "test_model");
        assert_eq!(manifest.tensors.len(), 1);

        // Get stats
        let stats = safetensors_store
            .get_model_stats(&manifest_cid)
            .await
            .unwrap();
        assert_eq!(stats.tensor_count, 1);
        assert_eq!(stats.total_parameters, 4);
    }

    // parse_header alone: offset accounting and tensor metadata extraction,
    // with no block store involved.
    #[test]
    fn test_parse_header() {
        let header = r#"{"tensor1":{"dtype":"F32","shape":[2,2],"data_offsets":[0,16]}}"#;
        let header_size = header.len() as u64;
        let mut data = Vec::new();
        data.extend_from_slice(&header_size.to_le_bytes());
        data.extend_from_slice(header.as_bytes());

        let (parsed, offset) = SafetensorsStore::<SledBlockStore>::parse_header(&data).unwrap();
        assert_eq!(offset, 8 + header.len());
        assert_eq!(parsed.tensors.len(), 1);
        assert!(parsed.tensors.contains_key("tensor1"));

        let tensor_info = &parsed.tensors["tensor1"];
        assert_eq!(tensor_info.dtype, DType::F32);
        assert_eq!(tensor_info.shape, vec![2, 2]);
        assert_eq!(tensor_info.data_offsets, (0, 16));
    }
}