use crate::arrow::{ArrowTensor, ArrowTensorStore, TensorDtype};
use bytes::Bytes;
use memmap2::Mmap;
use safetensors::tensor::{SafeTensorError, SafeTensors};
use safetensors::{Dtype, View};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::Path;

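/// Reader for safetensors files, backed either by a memory-mapped file
/// (see [`SafetensorsReader::open`]) or by an in-memory [`Bytes`] buffer
/// (see [`SafetensorsReader::from_bytes`]). Tensor metadata is indexed up
/// front; raw tensor bytes are sliced out of the backing buffer on demand.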
pub struct SafetensorsReader {
    mmap: Option<Mmap>,
    bytes: Option<Bytes>,
    metadata: HashMap<String, TensorInfo>,
    global_metadata: HashMap<String, String>,
}

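/// Metadata describing a single tensor: its name, dtype, shape, and the byte
/// offset and length of its data within the backing buffer.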
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInfo {
    pub name: String,
    pub dtype: TensorDtype,
    pub shape: Vec<usize>,
    pub data_offset: usize,
    pub data_size: usize,
}

impl SafetensorsReader {
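    /// Opens a safetensors file by memory-mapping it from disk.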
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, SafetensorError> {
        let file = File::open(path.as_ref()).map_err(SafetensorError::Io)?;
        let mmap = unsafe { Mmap::map(&file).map_err(SafetensorError::Io)? };

        Self::from_mmap(mmap)
    }

    fn from_mmap(mmap: Mmap) -> Result<Self, SafetensorError> {
        let tensors = SafeTensors::deserialize(&mmap)?;

        let mut metadata = HashMap::new();
        let global_metadata = HashMap::new();

        for (name, view) in tensors.tensors() {
            let dtype = convert_safetensor_dtype(view.dtype());
            let shape = view.shape().to_vec();
            let data = view.data();

            let info = TensorInfo {
                name: name.clone(),
                dtype,
                shape,
                data_offset: data.as_ptr() as usize - mmap.as_ptr() as usize,
                data_size: data.len(),
            };
            metadata.insert(name, info);
        }

        Ok(Self {
            mmap: Some(mmap),
            bytes: None,
            metadata,
            global_metadata,
        })
    }

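    /// Builds a reader over an in-memory safetensors buffer.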
    pub fn from_bytes(bytes: Bytes) -> Result<Self, SafetensorError> {
        let tensors = SafeTensors::deserialize(&bytes)?;

        let mut metadata = HashMap::new();
        let global_metadata = HashMap::new();

        for (name, view) in tensors.tensors() {
            let dtype = convert_safetensor_dtype(view.dtype());
            let shape = view.shape().to_vec();
            let data = view.data();

            let info = TensorInfo {
                name: name.clone(),
                dtype,
                shape,
                data_offset: data.as_ptr() as usize - bytes.as_ptr() as usize,
                data_size: data.len(),
            };
            metadata.insert(name, info);
        }

        Ok(Self {
            mmap: None,
            bytes: Some(bytes),
            metadata,
            global_metadata,
        })
    }

    pub fn tensor_names(&self) -> Vec<&str> {
        self.metadata.keys().map(|s| s.as_str()).collect()
    }

    pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> {
        self.metadata.get(name)
    }

    pub fn global_metadata(&self) -> &HashMap<String, String> {
        &self.global_metadata
    }

    pub fn len(&self) -> usize {
        self.metadata.len()
    }

    pub fn is_empty(&self) -> bool {
        self.metadata.is_empty()
    }

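    /// Returns the raw bytes of a tensor (little-endian, per the safetensors
    /// format) as a slice into the backing buffer, or `None` if the tensor is
    /// unknown.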
    pub fn tensor_data(&self, name: &str) -> Option<&[u8]> {
        let info = self.metadata.get(name)?;
        let data = self.get_data()?;
        Some(&data[info.data_offset..info.data_offset + info.data_size])
    }

    fn get_data(&self) -> Option<&[u8]> {
        if let Some(ref mmap) = self.mmap {
            Some(mmap.as_ref())
        } else if let Some(ref bytes) = self.bytes {
            Some(bytes.as_ref())
        } else {
            None
        }
    }

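    /// Decodes a tensor into a `Vec<f32>`. Returns `None` if the tensor is
    /// missing or its dtype is not `Float32`. The `load_f64`, `load_i32`, and
    /// `load_i64` methods below follow the same pattern.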
    pub fn load_f32(&self, name: &str) -> Option<Vec<f32>> {
        let info = self.tensor_info(name)?;
        if info.dtype != TensorDtype::Float32 {
            return None;
        }

        let data = self.tensor_data(name)?;
        let f32_data: Vec<f32> = data
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();
        Some(f32_data)
    }

    pub fn load_f64(&self, name: &str) -> Option<Vec<f64>> {
        let info = self.tensor_info(name)?;
        if info.dtype != TensorDtype::Float64 {
            return None;
        }

        let data = self.tensor_data(name)?;
        let f64_data: Vec<f64> = data
            .chunks_exact(8)
            .map(|chunk| {
                f64::from_le_bytes([
                    chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
                ])
            })
            .collect();
        Some(f64_data)
    }

    pub fn load_i32(&self, name: &str) -> Option<Vec<i32>> {
        let info = self.tensor_info(name)?;
        if info.dtype != TensorDtype::Int32 {
            return None;
        }

        let data = self.tensor_data(name)?;
        let i32_data: Vec<i32> = data
            .chunks_exact(4)
            .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();
        Some(i32_data)
    }

    pub fn load_i64(&self, name: &str) -> Option<Vec<i64>> {
        let info = self.tensor_info(name)?;
        if info.dtype != TensorDtype::Int64 {
            return None;
        }

        let data = self.tensor_data(name)?;
        let i64_data: Vec<i64> = data
            .chunks_exact(8)
            .map(|chunk| {
                i64::from_le_bytes([
                    chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
                ])
            })
            .collect();
        Some(i64_data)
    }

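    /// Loads a tensor into an [`ArrowTensor`]. Only `Float32`, `Float64`,
    /// `Int32`, and `Int64` tensors are supported; other dtypes return `None`.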
    pub fn load_as_arrow(&self, name: &str) -> Option<ArrowTensor> {
        let info = self.tensor_info(name)?;

        match info.dtype {
            TensorDtype::Float32 => {
                let data = self.load_f32(name)?;
                Some(ArrowTensor::from_slice_f32(name, info.shape.clone(), &data))
            }
            TensorDtype::Float64 => {
                let data = self.load_f64(name)?;
                Some(ArrowTensor::from_slice_f64(name, info.shape.clone(), &data))
            }
            TensorDtype::Int32 => {
                let data = self.load_i32(name)?;
                Some(ArrowTensor::from_slice_i32(name, info.shape.clone(), &data))
            }
            TensorDtype::Int64 => {
                let data = self.load_i64(name)?;
                Some(ArrowTensor::from_slice_i64(name, info.shape.clone(), &data))
            }
            _ => None,
        }
    }

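    /// Loads every convertible tensor into an [`ArrowTensorStore`], silently
    /// skipping tensors with unsupported dtypes.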
    pub fn load_all_as_arrow(&self) -> ArrowTensorStore {
        let mut store = ArrowTensorStore::new();

        for name in self.tensor_names() {
            if let Some(tensor) = self.load_as_arrow(name) {
                store.insert(tensor);
            }
        }

        store
    }

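    /// Total size of all tensor data in bytes (excluding the header).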
    pub fn total_size_bytes(&self) -> usize {
        self.metadata.values().map(|info| info.data_size).sum()
    }

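    /// Computes an aggregate [`ModelSummary`] over all tensors in the file.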
    pub fn summary(&self) -> ModelSummary {
        let mut dtype_counts: HashMap<TensorDtype, usize> = HashMap::new();
        let mut total_params = 0usize;
        let mut total_bytes = 0usize;

        for info in self.metadata.values() {
            *dtype_counts.entry(info.dtype).or_insert(0) += 1;
            let numel: usize = info.shape.iter().product();
            total_params += numel;
            total_bytes += info.data_size;
        }

        ModelSummary {
            num_tensors: self.metadata.len(),
            total_params,
            total_bytes,
            dtype_distribution: dtype_counts,
            metadata: self.global_metadata.clone(),
        }
    }
}

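/// Aggregate statistics for a safetensors file: tensor count, total parameter
/// count, total data size, and the distribution of dtypes.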
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSummary {
    pub num_tensors: usize,
    pub total_params: usize,
    pub total_bytes: usize,
    pub dtype_distribution: HashMap<TensorDtype, usize>,
    pub metadata: HashMap<String, String>,
}

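/// Builder for writing tensors to the safetensors format.
///
/// A minimal sketch of the intended flow (the crate path and file name are
/// assumptions for illustration, so the example is marked `ignore` and is not
/// compiled as a doctest):
///
/// ```ignore
/// let mut writer = SafetensorsWriter::new()
///     .with_metadata("format".to_string(), "example".to_string());
/// writer.add_f32("weights", vec![2, 3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// writer.write_to_file("model.safetensors")?;
/// ```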
pub struct SafetensorsWriter {
    tensors: Vec<(String, TensorData)>,
    metadata: HashMap<String, String>,
}

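/// Owned tensor payload (dtype, shape, and little-endian bytes) staged for
/// serialization.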
struct TensorData {
    dtype: Dtype,
    shape: Vec<usize>,
    data: Vec<u8>,
}

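/// Borrowing adapter so a staged [`TensorData`] can be handed to `safetensors`
/// serialization through the [`View`] trait without copying its bytes.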
struct TensorDataRef<'a>(&'a TensorData);

impl View for TensorDataRef<'_> {
    fn dtype(&self) -> Dtype {
        self.0.dtype
    }

    fn shape(&self) -> &[usize] {
        &self.0.shape
    }

    fn data(&self) -> std::borrow::Cow<'_, [u8]> {
        std::borrow::Cow::Borrowed(&self.0.data)
    }

    fn data_len(&self) -> usize {
        self.0.data.len()
    }
}

impl SafetensorsWriter {
    pub fn new() -> Self {
        Self {
            tensors: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }

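    /// Adds an `f32` tensor, encoding its data as little-endian bytes. The
    /// `add_f64`, `add_i32`, and `add_i64` methods below work the same way.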
    pub fn add_f32(&mut self, name: &str, shape: Vec<usize>, data: &[f32]) {
        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
        self.tensors.push((
            name.to_string(),
            TensorData {
                dtype: Dtype::F32,
                shape,
                data: bytes,
            },
        ));
    }

    pub fn add_f64(&mut self, name: &str, shape: Vec<usize>, data: &[f64]) {
        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
        self.tensors.push((
            name.to_string(),
            TensorData {
                dtype: Dtype::F64,
                shape,
                data: bytes,
            },
        ));
    }

    pub fn add_i32(&mut self, name: &str, shape: Vec<usize>, data: &[i32]) {
        let bytes: Vec<u8> = data.iter().flat_map(|i| i.to_le_bytes()).collect();
        self.tensors.push((
            name.to_string(),
            TensorData {
                dtype: Dtype::I32,
                shape,
                data: bytes,
            },
        ));
    }

    pub fn add_i64(&mut self, name: &str, shape: Vec<usize>, data: &[i64]) {
        let bytes: Vec<u8> = data.iter().flat_map(|i| i.to_le_bytes()).collect();
        self.tensors.push((
            name.to_string(),
            TensorData {
                dtype: Dtype::I64,
                shape,
                data: bytes,
            },
        ));
    }

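    /// Adds an [`ArrowTensor`], dispatching on its dtype. Tensors with
    /// unsupported dtypes are silently ignored.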
    pub fn add_arrow_tensor(&mut self, tensor: &ArrowTensor) {
        match tensor.metadata.dtype {
            TensorDtype::Float32 => {
                if let Some(data) = tensor.as_slice_f32() {
                    self.add_f32(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
                }
            }
            TensorDtype::Float64 => {
                if let Some(data) = tensor.as_slice_f64() {
                    self.add_f64(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
                }
            }
            TensorDtype::Int32 => {
                if let Some(data) = tensor.as_slice_i32() {
                    self.add_i32(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
                }
            }
            TensorDtype::Int64 => {
                if let Some(data) = tensor.as_slice_i64() {
                    self.add_i64(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
                }
            }
            _ => {}
        }
    }

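    /// Serializes all staged tensors and writes them to `path`.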
    pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), SafetensorError> {
        let bytes = self.serialize()?;
        let mut file = File::create(path).map_err(SafetensorError::Io)?;
        file.write_all(&bytes).map_err(SafetensorError::Io)?;
        Ok(())
    }

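    /// Serializes all staged tensors (plus any global metadata) into an
    /// in-memory safetensors buffer.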
    pub fn serialize(&self) -> Result<Vec<u8>, SafetensorError> {
        let tensors: Vec<(&str, TensorDataRef)> = self
            .tensors
            .iter()
            .map(|(name, data)| (name.as_str(), TensorDataRef(data)))
            .collect();

        let metadata = if self.metadata.is_empty() {
            None
        } else {
            Some(self.metadata.clone())
        };

        Ok(safetensors::tensor::serialize(
            tensors.into_iter(),
            metadata,
        )?)
    }
}

impl Default for SafetensorsWriter {
    fn default() -> Self {
        Self::new()
    }
}

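/// Splits a model across multiple safetensors files ("chunks"), targeting
/// roughly `chunk_size` bytes of tensor data per chunk (a single tensor larger
/// than the limit still gets its own chunk), plus a JSON index mapping tensors
/// to chunks.
///
/// A minimal sketch of the intended flow (the output directory and tensor
/// store are assumptions for illustration; marked `ignore` so it is not
/// compiled as a doctest):
///
/// ```ignore
/// let mut storage = ChunkedModelStorage::new("out/model", 1 << 30);
/// storage.write_chunked(&store)?;
/// storage.write_index()?;
/// let chunks = ChunkedModelStorage::load_index("out/model")?;
/// ```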
pub struct ChunkedModelStorage {
    base_path: std::path::PathBuf,
    chunk_size: usize,
    chunks: Vec<ChunkInfo>,
}

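/// Index entry for one chunk file: its position in the sequence, file name,
/// contained tensor names, and total tensor-data size in bytes.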
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkInfo {
    pub index: usize,
    pub path: String,
    pub tensors: Vec<String>,
    pub size_bytes: usize,
}

impl ChunkedModelStorage {
    pub fn new<P: AsRef<Path>>(base_path: P, chunk_size: usize) -> Self {
        Self {
            base_path: base_path.as_ref().to_path_buf(),
            chunk_size,
            chunks: Vec::new(),
        }
    }

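    /// Writes all tensors from `store` into sequentially numbered chunk files,
    /// starting a new chunk whenever adding a tensor would exceed `chunk_size`.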
    pub fn write_chunked(&mut self, store: &ArrowTensorStore) -> Result<(), SafetensorError> {
        let mut current_chunk = SafetensorsWriter::new();
        let mut current_size = 0usize;
        let mut current_tensors = Vec::new();

        for name in store.names() {
            if let Some(tensor) = store.get(name) {
                let tensor_size = tensor.metadata.size_bytes();

                if current_size + tensor_size > self.chunk_size && !current_tensors.is_empty() {
                    self.write_chunk(current_chunk, &current_tensors, current_size)?;
                    current_chunk = SafetensorsWriter::new();
                    current_tensors = Vec::new();
                    current_size = 0;
                }

                current_chunk.add_arrow_tensor(tensor);
                current_tensors.push(name.to_string());
                current_size += tensor_size;
            }
        }

        if !current_tensors.is_empty() {
            self.write_chunk(current_chunk, &current_tensors, current_size)?;
        }

        Ok(())
    }

    fn write_chunk(
        &mut self,
        writer: SafetensorsWriter,
        tensors: &[String],
        size: usize,
    ) -> Result<(), SafetensorError> {
        let index = self.chunks.len();
        let filename = format!("chunk_{:04}.safetensors", index);
        let path = self.base_path.join(&filename);

        writer.write_to_file(&path)?;

        self.chunks.push(ChunkInfo {
            index,
            path: filename,
            tensors: tensors.to_vec(),
            size_bytes: size,
        });

        Ok(())
    }

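    /// Writes the chunk index as pretty-printed JSON to `model_index.json`
    /// under the base path.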
    pub fn write_index(&self) -> Result<(), std::io::Error> {
        let index_path = self.base_path.join("model_index.json");
        let json = serde_json::to_string_pretty(&self.chunks)?;
        std::fs::write(index_path, json)?;
        Ok(())
    }

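    /// Loads a previously written chunk index from `model_index.json` in `path`.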
    pub fn load_index<P: AsRef<Path>>(path: P) -> Result<Vec<ChunkInfo>, std::io::Error> {
        let index_path = path.as_ref().join("model_index.json");
        let content = std::fs::read_to_string(index_path)?;
        let chunks: Vec<ChunkInfo> = serde_json::from_str(&content)?;
        Ok(chunks)
    }

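    /// Finds the chunk that contains a tensor with the given name, if any.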
    pub fn find_tensor_chunk(&self, tensor_name: &str) -> Option<&ChunkInfo> {
        self.chunks
            .iter()
            .find(|chunk| chunk.tensors.iter().any(|t| t.as_str() == tensor_name))
    }
}

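/// Errors produced by this module: I/O failures, parse failures, or errors
/// bubbled up from the `safetensors` crate.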
#[derive(Debug)]
pub enum SafetensorError {
    Io(std::io::Error),
    Parse(String),
    Safetensors(SafeTensorError),
}

impl From<SafeTensorError> for SafetensorError {
    fn from(err: SafeTensorError) -> Self {
        SafetensorError::Safetensors(err)
    }
}

impl std::fmt::Display for SafetensorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SafetensorError::Io(e) => write!(f, "IO error: {}", e),
            SafetensorError::Parse(s) => write!(f, "Parse error: {}", s),
            SafetensorError::Safetensors(e) => write!(f, "Safetensors error: {:?}", e),
        }
    }
}

impl std::error::Error for SafetensorError {}

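/// Maps a `safetensors` [`Dtype`] to the crate's [`TensorDtype`]. Dtypes
/// without a direct mapping currently fall back to `Float32`.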
fn convert_safetensor_dtype(dtype: Dtype) -> TensorDtype {
    match dtype {
        Dtype::F32 => TensorDtype::Float32,
        Dtype::F64 => TensorDtype::Float64,
        Dtype::I8 => TensorDtype::Int8,
        Dtype::I16 => TensorDtype::Int16,
        Dtype::I32 => TensorDtype::Int32,
        Dtype::I64 => TensorDtype::Int64,
        Dtype::U8 => TensorDtype::UInt8,
        Dtype::U16 => TensorDtype::UInt16,
        Dtype::U32 => TensorDtype::UInt32,
        Dtype::U64 => TensorDtype::UInt64,
        Dtype::BF16 => TensorDtype::BFloat16,
        Dtype::F16 => TensorDtype::Float16,
        _ => TensorDtype::Float32,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_writer_and_reader() {
        let mut writer =
            SafetensorsWriter::new().with_metadata("format".to_string(), "test".to_string());

        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        writer.add_f32("test_tensor", vec![3, 4], &data);

        let mut temp_file = NamedTempFile::new().unwrap();
        let bytes = writer.serialize().unwrap();
        temp_file.write_all(&bytes).unwrap();
        temp_file.flush().unwrap();

        let reader = SafetensorsReader::open(temp_file.path()).unwrap();

        assert_eq!(reader.len(), 1);
        assert!(reader.tensor_info("test_tensor").is_some());

        let info = reader.tensor_info("test_tensor").unwrap();
        assert_eq!(info.shape, vec![3, 4]);
        assert_eq!(info.dtype, TensorDtype::Float32);

        let loaded = reader.load_f32("test_tensor").unwrap();
        assert_eq!(loaded, data);
    }

    #[test]
    fn test_model_summary() {
        let mut writer = SafetensorsWriter::new();
        writer.add_f32("layer1", vec![10, 10], &[0.0; 100]);
        writer.add_f32("layer2", vec![10, 5], &[0.0; 50]);

        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();

        let summary = reader.summary();
        assert_eq!(summary.num_tensors, 2);
        assert_eq!(summary.total_params, 150);
    }

    #[test]
    fn test_arrow_conversion() {
        let mut writer = SafetensorsWriter::new();
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        writer.add_f32("weights", vec![2, 3], &data);

        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();

        let tensor = reader.load_as_arrow("weights").unwrap();
        assert_eq!(tensor.metadata.name, "weights");
        assert_eq!(tensor.metadata.shape, vec![2, 3]);
        assert_eq!(tensor.as_slice_f32().unwrap(), &data);
    }

    #[test]
    fn test_f64_support() {
        let mut writer = SafetensorsWriter::new();
        let data: Vec<f64> = vec![1.5, 2.5, 3.5, 4.5];
        writer.add_f64("weights_f64", vec![2, 2], &data);

        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();

        let loaded = reader.load_f64("weights_f64").unwrap();
        assert_eq!(loaded, data);

        let tensor = reader.load_as_arrow("weights_f64").unwrap();
        assert_eq!(tensor.metadata.name, "weights_f64");
        assert_eq!(tensor.metadata.dtype, TensorDtype::Float64);
        assert_eq!(tensor.as_slice_f64().unwrap(), &data);
    }

    #[test]
    fn test_i32_support() {
        let mut writer = SafetensorsWriter::new();
        let data: Vec<i32> = vec![-10, 20, -30, 40, 50, -60];
        writer.add_i32("indices", vec![2, 3], &data);

        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();

        let loaded = reader.load_i32("indices").unwrap();
        assert_eq!(loaded, data);

        let tensor = reader.load_as_arrow("indices").unwrap();
        assert_eq!(tensor.metadata.name, "indices");
        assert_eq!(tensor.metadata.dtype, TensorDtype::Int32);
        assert_eq!(tensor.as_slice_i32().unwrap(), &data);
    }

    #[test]
    fn test_i64_support() {
        let mut writer = SafetensorsWriter::new();
        let data: Vec<i64> = vec![-1000000000, 2000000000, -3000000000, 4000000000];
        writer.add_i64("large_indices", vec![2, 2], &data);

        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();

        let loaded = reader.load_i64("large_indices").unwrap();
        assert_eq!(loaded, data);

        let tensor = reader.load_as_arrow("large_indices").unwrap();
        assert_eq!(tensor.metadata.name, "large_indices");
        assert_eq!(tensor.metadata.dtype, TensorDtype::Int64);
        assert_eq!(tensor.as_slice_i64().unwrap(), &data);
    }

    #[test]
    fn test_mixed_dtypes() {
        let mut writer = SafetensorsWriter::new();

        let f32_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let f64_data: Vec<f64> = vec![5.5, 6.5];
        let i32_data: Vec<i32> = vec![10, 20, 30];
        let i64_data: Vec<i64> = vec![100, 200];

        writer.add_f32("layer1", vec![4], &f32_data);
        writer.add_f64("layer2", vec![2], &f64_data);
        writer.add_i32("layer3", vec![3], &i32_data);
        writer.add_i64("layer4", vec![2], &i64_data);

        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();

        assert_eq!(reader.len(), 4);

        assert_eq!(reader.load_f32("layer1").unwrap(), f32_data);
        assert_eq!(reader.load_f64("layer2").unwrap(), f64_data);
        assert_eq!(reader.load_i32("layer3").unwrap(), i32_data);
        assert_eq!(reader.load_i64("layer4").unwrap(), i64_data);

        assert!(reader.load_as_arrow("layer1").is_some());
        assert!(reader.load_as_arrow("layer2").is_some());
        assert!(reader.load_as_arrow("layer3").is_some());
        assert!(reader.load_as_arrow("layer4").is_some());
    }

    #[test]
    fn test_arrow_tensor_roundtrip() {
        use crate::arrow::ArrowTensor;

        let f64_tensor =
            ArrowTensor::from_slice_f64("test_f64", vec![2, 2], &[1.0, 2.0, 3.0, 4.0]);
        let mut writer = SafetensorsWriter::new();
        writer.add_arrow_tensor(&f64_tensor);

        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        let loaded = reader.load_as_arrow("test_f64").unwrap();
        assert_eq!(
            loaded.as_slice_f64().unwrap(),
            f64_tensor.as_slice_f64().unwrap()
        );

        let i32_tensor = ArrowTensor::from_slice_i32("test_i32", vec![3], &[10, 20, 30]);
        let mut writer = SafetensorsWriter::new();
        writer.add_arrow_tensor(&i32_tensor);

        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        let loaded = reader.load_as_arrow("test_i32").unwrap();
        assert_eq!(
            loaded.as_slice_i32().unwrap(),
            i32_tensor.as_slice_i32().unwrap()
        );

        let i64_tensor = ArrowTensor::from_slice_i64("test_i64", vec![2], &[100, 200]);
        let mut writer = SafetensorsWriter::new();
        writer.add_arrow_tensor(&i64_tensor);

        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        let loaded = reader.load_as_arrow("test_i64").unwrap();
        assert_eq!(
            loaded.as_slice_i64().unwrap(),
            i64_tensor.as_slice_i64().unwrap()
        );
    }
}