ipfrs_tensorlogic/
safetensors_support.rs

//! Safetensors file format support
//!
//! Provides parsing and writing of the safetensors format, including:
//! - Native safetensors reading with multi-dtype support (f32, f64, i32, i64)
//! - Chunked storage for large models
//! - Lazy loading via memory mapping
//! - Metadata extraction
//! - Conversion to Arrow tensors
//!
//! ## Supported Data Types
//!
//! The reader and writer support the following data types:
//! - **Float32** (f32) - Standard precision floating point
//! - **Float64** (f64) - Double precision floating point
//! - **Int32** (i32) - 32-bit signed integers
//! - **Int64** (i64) - 64-bit signed integers
//!
//! All supported types can be loaded as Arrow tensors. Raw tensor bytes are
//! available without copying via `SafetensorsReader::tensor_data`; the typed
//! loaders and the Arrow conversion decode into owned buffers.
19
20use crate::arrow::{ArrowTensor, ArrowTensorStore, TensorDtype};
21use bytes::Bytes;
22use memmap2::Mmap;
23use safetensors::tensor::{SafeTensorError, SafeTensors};
24use safetensors::{Dtype, View};
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::fs::File;
28use std::io::Write;
29use std::path::Path;
30
31/// Safetensors file reader with lazy loading support
32pub struct SafetensorsReader {
33    /// Memory-mapped file (for lazy loading)
34    mmap: Option<Mmap>,
35    /// Raw bytes (for in-memory loading)
36    bytes: Option<Bytes>,
37    /// Parsed tensor metadata
38    metadata: HashMap<String, TensorInfo>,
39    /// Global metadata from the file
40    global_metadata: HashMap<String, String>,
41}
42
43/// Information about a tensor in a safetensors file
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct TensorInfo {
46    /// Tensor name
47    pub name: String,
48    /// Data type
49    pub dtype: TensorDtype,
50    /// Shape dimensions
51    pub shape: Vec<usize>,
52    /// Byte offset in the file
53    pub data_offset: usize,
54    /// Size in bytes
55    pub data_size: usize,
56}
57
58impl SafetensorsReader {
59    /// Open a safetensors file with memory mapping (lazy loading)
60    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, SafetensorError> {
61        let file = File::open(path.as_ref()).map_err(SafetensorError::Io)?;
62        let mmap = unsafe { Mmap::map(&file).map_err(SafetensorError::Io)? };
63
64        Self::from_mmap(mmap)
65    }
66
67    /// Create from memory-mapped data
68    fn from_mmap(mmap: Mmap) -> Result<Self, SafetensorError> {
69        // Parse header to get metadata
70        let tensors = SafeTensors::deserialize(&mmap)?;
71
72        let mut metadata = HashMap::new();
73        let global_metadata = HashMap::new();
74
75        // Extract tensor info
76        for (name, view) in tensors.tensors() {
77            let dtype = convert_safetensor_dtype(view.dtype());
78            let shape = view.shape().to_vec();
79            let data = view.data();
80
81            let info = TensorInfo {
82                name: name.clone(),
83                dtype,
84                shape,
85                data_offset: data.as_ptr() as usize - mmap.as_ptr() as usize,
86                data_size: data.len(),
87            };
88            metadata.insert(name, info);
89        }
90
91        Ok(Self {
92            mmap: Some(mmap),
93            bytes: None,
94            metadata,
95            global_metadata,
96        })
97    }
98
99    /// Load from bytes
100    pub fn from_bytes(bytes: Bytes) -> Result<Self, SafetensorError> {
101        let tensors = SafeTensors::deserialize(&bytes)?;
102
103        let mut metadata = HashMap::new();
104        let global_metadata = HashMap::new();
105
106        for (name, view) in tensors.tensors() {
107            let dtype = convert_safetensor_dtype(view.dtype());
108            let shape = view.shape().to_vec();
109            let data = view.data();
110
111            let info = TensorInfo {
112                name: name.clone(),
113                dtype,
114                shape,
115                data_offset: data.as_ptr() as usize - bytes.as_ptr() as usize,
116                data_size: data.len(),
117            };
118            metadata.insert(name, info);
119        }
120
121        Ok(Self {
122            mmap: None,
123            bytes: Some(bytes),
124            metadata,
125            global_metadata,
126        })
127    }
128
129    /// Get all tensor names
130    pub fn tensor_names(&self) -> Vec<&str> {
131        self.metadata.keys().map(|s| s.as_str()).collect()
132    }
133
134    /// Get tensor info by name
135    pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> {
136        self.metadata.get(name)
137    }
138
139    /// Get global metadata
140    pub fn global_metadata(&self) -> &HashMap<String, String> {
141        &self.global_metadata
142    }
143
144    /// Get the number of tensors
145    pub fn len(&self) -> usize {
146        self.metadata.len()
147    }
148
149    /// Check if empty
150    pub fn is_empty(&self) -> bool {
151        self.metadata.is_empty()
152    }
153
154    /// Get raw data for a tensor (zero-copy)
155    pub fn tensor_data(&self, name: &str) -> Option<&[u8]> {
156        let info = self.metadata.get(name)?;
157        let data = self.get_data()?;
158        Some(&data[info.data_offset..info.data_offset + info.data_size])
159    }
160
161    /// Get the underlying data slice
162    fn get_data(&self) -> Option<&[u8]> {
163        if let Some(ref mmap) = self.mmap {
164            Some(mmap.as_ref())
165        } else if let Some(ref bytes) = self.bytes {
166            Some(bytes.as_ref())
167        } else {
168            None
169        }
170    }
171
172    /// Load a tensor as f32 slice
173    pub fn load_f32(&self, name: &str) -> Option<Vec<f32>> {
174        let info = self.tensor_info(name)?;
175        if info.dtype != TensorDtype::Float32 {
176            return None;
177        }
178
179        let data = self.tensor_data(name)?;
180        let f32_data: Vec<f32> = data
181            .chunks_exact(4)
182            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
183            .collect();
184        Some(f32_data)
185    }
186
187    /// Load a tensor as f64 slice
188    pub fn load_f64(&self, name: &str) -> Option<Vec<f64>> {
189        let info = self.tensor_info(name)?;
190        if info.dtype != TensorDtype::Float64 {
191            return None;
192        }
193
194        let data = self.tensor_data(name)?;
195        let f64_data: Vec<f64> = data
196            .chunks_exact(8)
197            .map(|chunk| {
198                f64::from_le_bytes([
199                    chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
200                ])
201            })
202            .collect();
203        Some(f64_data)
204    }
205
206    /// Load a tensor as i32 slice
207    pub fn load_i32(&self, name: &str) -> Option<Vec<i32>> {
208        let info = self.tensor_info(name)?;
209        if info.dtype != TensorDtype::Int32 {
210            return None;
211        }
212
213        let data = self.tensor_data(name)?;
214        let i32_data: Vec<i32> = data
215            .chunks_exact(4)
216            .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
217            .collect();
218        Some(i32_data)
219    }
220
221    /// Load a tensor as i64 slice
222    pub fn load_i64(&self, name: &str) -> Option<Vec<i64>> {
223        let info = self.tensor_info(name)?;
224        if info.dtype != TensorDtype::Int64 {
225            return None;
226        }
227
228        let data = self.tensor_data(name)?;
229        let i64_data: Vec<i64> = data
230            .chunks_exact(8)
231            .map(|chunk| {
232                i64::from_le_bytes([
233                    chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
234                ])
235            })
236            .collect();
237        Some(i64_data)
238    }
239
240    /// Load a tensor as ArrowTensor
241    pub fn load_as_arrow(&self, name: &str) -> Option<ArrowTensor> {
242        let info = self.tensor_info(name)?;
243
244        match info.dtype {
245            TensorDtype::Float32 => {
246                let data = self.load_f32(name)?;
247                Some(ArrowTensor::from_slice_f32(name, info.shape.clone(), &data))
248            }
249            TensorDtype::Float64 => {
250                let data = self.load_f64(name)?;
251                Some(ArrowTensor::from_slice_f64(name, info.shape.clone(), &data))
252            }
253            TensorDtype::Int32 => {
254                let data = self.load_i32(name)?;
255                Some(ArrowTensor::from_slice_i32(name, info.shape.clone(), &data))
256            }
257            TensorDtype::Int64 => {
258                let data = self.load_i64(name)?;
259                Some(ArrowTensor::from_slice_i64(name, info.shape.clone(), &data))
260            }
261            _ => None, // Other dtypes not yet supported in ArrowTensor
262        }
263    }
264
265    /// Load all tensors into an ArrowTensorStore
266    pub fn load_all_as_arrow(&self) -> ArrowTensorStore {
267        let mut store = ArrowTensorStore::new();
268
269        for name in self.tensor_names() {
270            if let Some(tensor) = self.load_as_arrow(name) {
271                store.insert(tensor);
272            }
273        }
274
275        store
276    }
277
278    /// Get total size of all tensors
279    pub fn total_size_bytes(&self) -> usize {
280        self.metadata.values().map(|info| info.data_size).sum()
281    }
282
283    /// Get a summary of the model
284    pub fn summary(&self) -> ModelSummary {
285        let mut dtype_counts: HashMap<TensorDtype, usize> = HashMap::new();
286        let mut total_params = 0usize;
287        let mut total_bytes = 0usize;
288
289        for info in self.metadata.values() {
290            *dtype_counts.entry(info.dtype).or_insert(0) += 1;
291            let numel: usize = info.shape.iter().product();
292            total_params += numel;
293            total_bytes += info.data_size;
294        }
295
296        ModelSummary {
297            num_tensors: self.metadata.len(),
298            total_params,
299            total_bytes,
300            dtype_distribution: dtype_counts,
301            metadata: self.global_metadata.clone(),
302        }
303    }
304}
305
306/// Summary of a model's structure
307#[derive(Debug, Clone, Serialize, Deserialize)]
308pub struct ModelSummary {
309    /// Number of tensors
310    pub num_tensors: usize,
311    /// Total number of parameters
312    pub total_params: usize,
313    /// Total size in bytes
314    pub total_bytes: usize,
315    /// Distribution of data types
316    pub dtype_distribution: HashMap<TensorDtype, usize>,
317    /// Global metadata
318    pub metadata: HashMap<String, String>,
319}
320
321/// Safetensors file writer
322pub struct SafetensorsWriter {
323    /// Tensors to write
324    tensors: Vec<(String, TensorData)>,
325    /// Global metadata
326    metadata: HashMap<String, String>,
327}
328
329/// Tensor data for writing
330struct TensorData {
331    dtype: Dtype,
332    shape: Vec<usize>,
333    data: Vec<u8>,
334}
335
336/// Reference wrapper for TensorData that implements View
337struct TensorDataRef<'a>(&'a TensorData);
338
339impl View for TensorDataRef<'_> {
340    fn dtype(&self) -> Dtype {
341        self.0.dtype
342    }
343
344    fn shape(&self) -> &[usize] {
345        &self.0.shape
346    }
347
348    fn data(&self) -> std::borrow::Cow<'_, [u8]> {
349        std::borrow::Cow::Borrowed(&self.0.data)
350    }
351
352    fn data_len(&self) -> usize {
353        self.0.data.len()
354    }
355}
356
357impl SafetensorsWriter {
358    /// Create a new writer
359    pub fn new() -> Self {
360        Self {
361            tensors: Vec::new(),
362            metadata: HashMap::new(),
363        }
364    }
365
366    /// Add global metadata
367    pub fn with_metadata(mut self, key: String, value: String) -> Self {
368        self.metadata.insert(key, value);
369        self
370    }
371
372    /// Add a f32 tensor
373    pub fn add_f32(&mut self, name: &str, shape: Vec<usize>, data: &[f32]) {
374        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
375        self.tensors.push((
376            name.to_string(),
377            TensorData {
378                dtype: Dtype::F32,
379                shape,
380                data: bytes,
381            },
382        ));
383    }
384
385    /// Add a f64 tensor
386    pub fn add_f64(&mut self, name: &str, shape: Vec<usize>, data: &[f64]) {
387        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
388        self.tensors.push((
389            name.to_string(),
390            TensorData {
391                dtype: Dtype::F64,
392                shape,
393                data: bytes,
394            },
395        ));
396    }
397
398    /// Add an i32 tensor
399    pub fn add_i32(&mut self, name: &str, shape: Vec<usize>, data: &[i32]) {
400        let bytes: Vec<u8> = data.iter().flat_map(|i| i.to_le_bytes()).collect();
401        self.tensors.push((
402            name.to_string(),
403            TensorData {
404                dtype: Dtype::I32,
405                shape,
406                data: bytes,
407            },
408        ));
409    }
410
411    /// Add an i64 tensor
412    pub fn add_i64(&mut self, name: &str, shape: Vec<usize>, data: &[i64]) {
413        let bytes: Vec<u8> = data.iter().flat_map(|i| i.to_le_bytes()).collect();
414        self.tensors.push((
415            name.to_string(),
416            TensorData {
417                dtype: Dtype::I64,
418                shape,
419                data: bytes,
420            },
421        ));
422    }
423
424    /// Add an ArrowTensor
425    pub fn add_arrow_tensor(&mut self, tensor: &ArrowTensor) {
426        match tensor.metadata.dtype {
427            TensorDtype::Float32 => {
428                if let Some(data) = tensor.as_slice_f32() {
429                    self.add_f32(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
430                }
431            }
432            TensorDtype::Float64 => {
433                if let Some(data) = tensor.as_slice_f64() {
434                    self.add_f64(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
435                }
436            }
437            TensorDtype::Int32 => {
438                if let Some(data) = tensor.as_slice_i32() {
439                    self.add_i32(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
440                }
441            }
442            TensorDtype::Int64 => {
443                if let Some(data) = tensor.as_slice_i64() {
444                    self.add_i64(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
445                }
446            }
447            _ => {} // Other dtypes not yet supported
448        }
449    }
450
451    /// Write to a file
452    pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), SafetensorError> {
453        let bytes = self.serialize()?;
454        let mut file = File::create(path).map_err(SafetensorError::Io)?;
455        file.write_all(&bytes).map_err(SafetensorError::Io)?;
456        Ok(())
457    }
458
459    /// Serialize to bytes
460    pub fn serialize(&self) -> Result<Vec<u8>, SafetensorError> {
461        let tensors: Vec<(&str, TensorDataRef)> = self
462            .tensors
463            .iter()
464            .map(|(name, data)| (name.as_str(), TensorDataRef(data)))
465            .collect();
466
467        let metadata = if self.metadata.is_empty() {
468            None
469        } else {
470            let meta: HashMap<String, String> = self.metadata.clone();
471            Some(meta)
472        };
473
474        Ok(safetensors::tensor::serialize(
475            tensors.into_iter(),
476            metadata,
477        )?)
478    }
479}
480
481impl Default for SafetensorsWriter {
482    fn default() -> Self {
483        Self::new()
484    }
485}
486
487/// Chunked model storage for large models
488pub struct ChunkedModelStorage {
489    /// Base path for chunks
490    base_path: std::path::PathBuf,
491    /// Chunk size in bytes
492    chunk_size: usize,
493    /// Chunk index
494    chunks: Vec<ChunkInfo>,
495}
496
497/// Information about a model chunk
498#[derive(Debug, Clone, Serialize, Deserialize)]
499pub struct ChunkInfo {
500    /// Chunk index
501    pub index: usize,
502    /// Path to chunk file
503    pub path: String,
504    /// Tensors in this chunk
505    pub tensors: Vec<String>,
506    /// Size in bytes
507    pub size_bytes: usize,
508}
509
510impl ChunkedModelStorage {
511    /// Create a new chunked storage
512    pub fn new<P: AsRef<Path>>(base_path: P, chunk_size: usize) -> Self {
513        Self {
514            base_path: base_path.as_ref().to_path_buf(),
515            chunk_size,
516            chunks: Vec::new(),
517        }
518    }
519
520    /// Write a model in chunks
521    #[allow(clippy::too_many_arguments)]
522    pub fn write_chunked(&mut self, store: &ArrowTensorStore) -> Result<(), SafetensorError> {
523        let mut current_chunk = SafetensorsWriter::new();
524        let mut current_size = 0usize;
525        let mut current_tensors = Vec::new();
526
527        for name in store.names() {
528            if let Some(tensor) = store.get(name) {
529                let tensor_size = tensor.metadata.size_bytes();
530
531                // Start new chunk if current would exceed limit
532                if current_size + tensor_size > self.chunk_size && !current_tensors.is_empty() {
533                    self.write_chunk(current_chunk, &current_tensors, current_size)?;
534                    current_chunk = SafetensorsWriter::new();
535                    current_tensors = Vec::new();
536                    current_size = 0;
537                }
538
539                current_chunk.add_arrow_tensor(tensor);
540                current_tensors.push(name.to_string());
541                current_size += tensor_size;
542            }
543        }
544
545        // Write final chunk
546        if !current_tensors.is_empty() {
547            self.write_chunk(current_chunk, &current_tensors, current_size)?;
548        }
549
550        Ok(())
551    }
552
553    fn write_chunk(
554        &mut self,
555        writer: SafetensorsWriter,
556        tensors: &[String],
557        size: usize,
558    ) -> Result<(), SafetensorError> {
559        let index = self.chunks.len();
560        let filename = format!("chunk_{:04}.safetensors", index);
561        let path = self.base_path.join(&filename);
562
563        writer.write_to_file(&path)?;
564
565        self.chunks.push(ChunkInfo {
566            index,
567            path: filename,
568            tensors: tensors.to_vec(),
569            size_bytes: size,
570        });
571
572        Ok(())
573    }
574
575    /// Write chunk index
576    pub fn write_index(&self) -> Result<(), std::io::Error> {
577        let index_path = self.base_path.join("model_index.json");
578        let json = serde_json::to_string_pretty(&self.chunks)?;
579        std::fs::write(index_path, json)?;
580        Ok(())
581    }
582
583    /// Load chunk index
584    pub fn load_index<P: AsRef<Path>>(path: P) -> Result<Vec<ChunkInfo>, std::io::Error> {
585        let index_path = path.as_ref().join("model_index.json");
586        let content = std::fs::read_to_string(index_path)?;
587        let chunks: Vec<ChunkInfo> = serde_json::from_str(&content)?;
588        Ok(chunks)
589    }
590
591    /// Get chunk containing a specific tensor
592    pub fn find_tensor_chunk(&self, tensor_name: &str) -> Option<&ChunkInfo> {
593        self.chunks
594            .iter()
595            .find(|chunk| chunk.tensors.contains(&tensor_name.to_string()))
596    }
597}
598
599/// Custom error type for safetensor operations
600#[derive(Debug)]
601pub enum SafetensorError {
602    /// IO error
603    Io(std::io::Error),
604    /// Parse error
605    Parse(String),
606    /// Safetensors library error
607    Safetensors(SafeTensorError),
608}
609
610impl From<SafeTensorError> for SafetensorError {
611    fn from(err: SafeTensorError) -> Self {
612        SafetensorError::Safetensors(err)
613    }
614}
615
616impl std::fmt::Display for SafetensorError {
617    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
618        match self {
619            SafetensorError::Io(e) => write!(f, "IO error: {}", e),
620            SafetensorError::Parse(s) => write!(f, "Parse error: {}", s),
621            SafetensorError::Safetensors(e) => write!(f, "Safetensors error: {:?}", e),
622        }
623    }
624}
625
/// Marks [`SafetensorError`] as a standard error type.
///
/// NOTE(review): `source()` keeps its default (`None`). The `Display` impl
/// already embeds the wrapped error's message, so also returning it from
/// `source()` would print it twice in error-chain reporters. If `source()` is
/// added later, drop the inner message from `Display`.
impl std::error::Error for SafetensorError {}
627
628/// Convert safetensors dtype to our dtype
629fn convert_safetensor_dtype(dtype: Dtype) -> TensorDtype {
630    match dtype {
631        Dtype::F32 => TensorDtype::Float32,
632        Dtype::F64 => TensorDtype::Float64,
633        Dtype::I8 => TensorDtype::Int8,
634        Dtype::I16 => TensorDtype::Int16,
635        Dtype::I32 => TensorDtype::Int32,
636        Dtype::I64 => TensorDtype::Int64,
637        Dtype::U8 => TensorDtype::UInt8,
638        Dtype::U16 => TensorDtype::UInt16,
639        Dtype::U32 => TensorDtype::UInt32,
640        Dtype::U64 => TensorDtype::UInt64,
641        Dtype::BF16 => TensorDtype::BFloat16,
642        Dtype::F16 => TensorDtype::Float16,
643        _ => TensorDtype::Float32, // Default fallback
644    }
645}
646
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Serialize `writer` in memory and parse the bytes back with `from_bytes`.
    fn reparse(writer: &SafetensorsWriter) -> SafetensorsReader {
        let encoded = writer.serialize().unwrap();
        SafetensorsReader::from_bytes(Bytes::from(encoded)).unwrap()
    }

    #[test]
    fn test_writer_and_reader() {
        let values: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let mut writer =
            SafetensorsWriter::new().with_metadata("format".to_string(), "test".to_string());
        writer.add_f32("test_tensor", vec![3, 4], &values);

        // Round-trip through a real file so the memory-mapped path is exercised.
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&writer.serialize().unwrap()).unwrap();
        file.flush().unwrap();

        let reader = SafetensorsReader::open(file.path()).unwrap();
        assert_eq!(reader.len(), 1);

        let info = reader.tensor_info("test_tensor").unwrap();
        assert_eq!(info.shape, vec![3, 4]);
        assert_eq!(info.dtype, TensorDtype::Float32);
        assert_eq!(reader.load_f32("test_tensor").unwrap(), values);
    }

    #[test]
    fn test_model_summary() {
        let mut writer = SafetensorsWriter::new();
        writer.add_f32("layer1", vec![10, 10], &[0.0; 100]);
        writer.add_f32("layer2", vec![10, 5], &[0.0; 50]);

        let summary = reparse(&writer).summary();
        assert_eq!(summary.num_tensors, 2);
        assert_eq!(summary.total_params, 150);
    }

    #[test]
    fn test_arrow_conversion() {
        let values: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut writer = SafetensorsWriter::new();
        writer.add_f32("weights", vec![2, 3], &values);

        let tensor = reparse(&writer).load_as_arrow("weights").unwrap();
        assert_eq!(tensor.metadata.name, "weights");
        assert_eq!(tensor.metadata.shape, vec![2, 3]);
        assert_eq!(tensor.as_slice_f32().unwrap(), &values);
    }

    #[test]
    fn test_f64_support() {
        let values: Vec<f64> = vec![1.5, 2.5, 3.5, 4.5];
        let mut writer = SafetensorsWriter::new();
        writer.add_f64("weights_f64", vec![2, 2], &values);
        let reader = reparse(&writer);

        // Typed accessor
        assert_eq!(reader.load_f64("weights_f64").unwrap(), values);

        // Arrow conversion
        let tensor = reader.load_as_arrow("weights_f64").unwrap();
        assert_eq!(tensor.metadata.name, "weights_f64");
        assert_eq!(tensor.metadata.dtype, TensorDtype::Float64);
        assert_eq!(tensor.as_slice_f64().unwrap(), &values);
    }

    #[test]
    fn test_i32_support() {
        let values: Vec<i32> = vec![-10, 20, -30, 40, 50, -60];
        let mut writer = SafetensorsWriter::new();
        writer.add_i32("indices", vec![2, 3], &values);
        let reader = reparse(&writer);

        // Typed accessor
        assert_eq!(reader.load_i32("indices").unwrap(), values);

        // Arrow conversion
        let tensor = reader.load_as_arrow("indices").unwrap();
        assert_eq!(tensor.metadata.name, "indices");
        assert_eq!(tensor.metadata.dtype, TensorDtype::Int32);
        assert_eq!(tensor.as_slice_i32().unwrap(), &values);
    }

    #[test]
    fn test_i64_support() {
        let values: Vec<i64> = vec![-1000000000, 2000000000, -3000000000, 4000000000];
        let mut writer = SafetensorsWriter::new();
        writer.add_i64("large_indices", vec![2, 2], &values);
        let reader = reparse(&writer);

        // Typed accessor
        assert_eq!(reader.load_i64("large_indices").unwrap(), values);

        // Arrow conversion
        let tensor = reader.load_as_arrow("large_indices").unwrap();
        assert_eq!(tensor.metadata.name, "large_indices");
        assert_eq!(tensor.metadata.dtype, TensorDtype::Int64);
        assert_eq!(tensor.as_slice_i64().unwrap(), &values);
    }

    #[test]
    fn test_mixed_dtypes() {
        let f32_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let f64_data: Vec<f64> = vec![5.5, 6.5];
        let i32_data: Vec<i32> = vec![10, 20, 30];
        let i64_data: Vec<i64> = vec![100, 200];

        let mut writer = SafetensorsWriter::new();
        writer.add_f32("layer1", vec![4], &f32_data);
        writer.add_f64("layer2", vec![2], &f64_data);
        writer.add_i32("layer3", vec![3], &i32_data);
        writer.add_i64("layer4", vec![2], &i64_data);

        let reader = reparse(&writer);
        assert_eq!(reader.len(), 4);

        // Every tensor decodes back to its original values...
        assert_eq!(reader.load_f32("layer1").unwrap(), f32_data);
        assert_eq!(reader.load_f64("layer2").unwrap(), f64_data);
        assert_eq!(reader.load_i32("layer3").unwrap(), i32_data);
        assert_eq!(reader.load_i64("layer4").unwrap(), i64_data);

        // ...and every tensor converts to Arrow.
        for name in ["layer1", "layer2", "layer3", "layer4"] {
            assert!(reader.load_as_arrow(name).is_some());
        }
    }

    #[test]
    fn test_arrow_tensor_roundtrip() {
        use crate::arrow::ArrowTensor;

        // f64
        let f64_tensor = ArrowTensor::from_slice_f64("test_f64", vec![2, 2], &[1.0, 2.0, 3.0, 4.0]);
        let mut writer = SafetensorsWriter::new();
        writer.add_arrow_tensor(&f64_tensor);
        let loaded = reparse(&writer).load_as_arrow("test_f64").unwrap();
        assert_eq!(
            loaded.as_slice_f64().unwrap(),
            f64_tensor.as_slice_f64().unwrap()
        );

        // i32
        let i32_tensor = ArrowTensor::from_slice_i32("test_i32", vec![3], &[10, 20, 30]);
        let mut writer = SafetensorsWriter::new();
        writer.add_arrow_tensor(&i32_tensor);
        let loaded = reparse(&writer).load_as_arrow("test_i32").unwrap();
        assert_eq!(
            loaded.as_slice_i32().unwrap(),
            i32_tensor.as_slice_i32().unwrap()
        );

        // i64
        let i64_tensor = ArrowTensor::from_slice_i64("test_i64", vec![2], &[100, 200]);
        let mut writer = SafetensorsWriter::new();
        writer.add_arrow_tensor(&i64_tensor);
        let loaded = reparse(&writer).load_as_arrow("test_i64").unwrap();
        assert_eq!(
            loaded.as_slice_i64().unwrap(),
            i64_tensor.as_slice_i64().unwrap()
        );
    }
}