candle_core/
safetensors.rs

//! Module to load `safetensors` files into CPU/GPU memory.
//!
//! There are multiple ways to load tensors from safetensors files:
//! - `load` function for reading a file fully into memory and returning a HashMap of tensors
//! - `load_buffer` function for doing the same from an in-memory byte slice
//! - `MmapedSafetensors` for memory mapping one or more files and avoiding full allocation
//! - `SliceSafetensors` for working with borrowed in-memory buffers
//! - `BufferedSafetensors` for owning a buffer of data
//! - `MmapedFile` for memory mapping a single file and deserializing its header on demand
//!
//! Tensors can also be serialized to the safetensors format using the `save` function or the
//! `Tensor::save_safetensors` method.
//!
12use crate::op::BackpropOp;
13use crate::storage::Storage;
14use crate::tensor::from_storage;
15use crate::{DType, Device, Error, Result, Tensor, WithDType};
16use safetensors::tensor as st;
17use safetensors::tensor::SafeTensors;
18use std::borrow::Cow;
19use std::collections::HashMap;
20use std::path::Path;
21
22impl From<DType> for st::Dtype {
23    fn from(value: DType) -> Self {
24        match value {
25            DType::U8 => st::Dtype::U8,
26            DType::U32 => st::Dtype::U32,
27            DType::I16 => st::Dtype::I16,
28            DType::I32 => st::Dtype::I32,
29            DType::I64 => st::Dtype::I64,
30            DType::BF16 => st::Dtype::BF16,
31            DType::F16 => st::Dtype::F16,
32            DType::F32 => st::Dtype::F32,
33            DType::F64 => st::Dtype::F64,
34            DType::F8E4M3 => st::Dtype::F8_E4M3,
35            DType::F6E2M3 => st::Dtype::F6_E2M3,
36            DType::F6E3M2 => st::Dtype::F6_E3M2,
37            DType::F4 => st::Dtype::F4,
38            DType::F8E8M0 => st::Dtype::F8_E8M0,
39        }
40    }
41}
42
43impl TryFrom<st::Dtype> for DType {
44    type Error = Error;
45    fn try_from(value: st::Dtype) -> Result<Self> {
46        match value {
47            st::Dtype::U8 => Ok(DType::U8),
48            st::Dtype::U32 => Ok(DType::U32),
49            st::Dtype::I16 => Ok(DType::I16),
50            st::Dtype::I32 => Ok(DType::I32),
51            st::Dtype::I64 => Ok(DType::I64),
52            st::Dtype::BF16 => Ok(DType::BF16),
53            st::Dtype::F16 => Ok(DType::F16),
54            st::Dtype::F32 => Ok(DType::F32),
55            st::Dtype::F64 => Ok(DType::F64),
56            st::Dtype::F8_E4M3 => Ok(DType::F8E4M3),
57            st::Dtype::F6_E2M3 => Ok(DType::F6E2M3),
58            st::Dtype::F6_E3M2 => Ok(DType::F6E3M2),
59            st::Dtype::F4 => Ok(DType::F4),
60            st::Dtype::F8_E8M0 => Ok(DType::F8E8M0),
61            dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
62        }
63    }
64}
65
66impl st::View for Tensor {
67    fn dtype(&self) -> st::Dtype {
68        self.dtype().into()
69    }
70    fn shape(&self) -> &[usize] {
71        self.shape().dims()
72    }
73
74    fn data(&self) -> Cow<'_, [u8]> {
75        // This copies data from GPU to CPU.
76        // TODO: Avoid the unwrap here.
77        Cow::Owned(convert_back(self).unwrap())
78    }
79
80    fn data_len(&self) -> usize {
81        let n: usize = self.shape().elem_count();
82        let bytes_per_element = self.dtype().size_in_bytes();
83        n * bytes_per_element
84    }
85}
86
87impl st::View for &Tensor {
88    fn dtype(&self) -> st::Dtype {
89        (*self).dtype().into()
90    }
91    fn shape(&self) -> &[usize] {
92        self.dims()
93    }
94
95    fn data(&self) -> Cow<'_, [u8]> {
96        // This copies data from GPU to CPU.
97        // TODO: Avoid the unwrap here.
98        Cow::Owned(convert_back(self).unwrap())
99    }
100
101    fn data_len(&self) -> usize {
102        let n: usize = self.dims().iter().product();
103        let bytes_per_element = (*self).dtype().size_in_bytes();
104        n * bytes_per_element
105    }
106}
107
108impl Tensor {
109    pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
110        let data = [(name, self.clone())];
111        Ok(st::serialize_to_file(data, None, filename.as_ref())?)
112    }
113}
114
115fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
116    let size_in_bytes = T::DTYPE.size_in_bytes();
117    let elem_count = data.len() / size_in_bytes;
118    if (data.as_ptr() as usize).is_multiple_of(size_in_bytes) {
119        // SAFETY This is safe because we just checked that this
120        // was correctly aligned.
121        let data: &[T] =
122            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
123        Tensor::from_slice(data, shape, device)
124    } else {
125        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
126        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
127        let mut c: Vec<T> = Vec::with_capacity(elem_count);
128        // SAFETY: We just created c, so the allocated memory is necessarily
129        // contiguous and non overlapping with the view's data.
130        // We're downgrading the `c` pointer from T to u8, which removes alignment
131        // constraints.
132        unsafe {
133            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
134            c.set_len(elem_count)
135        }
136        Tensor::from_slice(&c, shape, device)
137    }
138}
139
140fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
141    data: &[u8],
142    shape: &[usize],
143    device: &Device,
144    conv: F,
145) -> Result<Tensor> {
146    let size_in_bytes = std::mem::size_of::<T>();
147    let elem_count = data.len() / size_in_bytes;
148    if (data.as_ptr() as usize).is_multiple_of(size_in_bytes) {
149        // SAFETY This is safe because we just checked that this
150        // was correctly aligned.
151        let data: &[T] =
152            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
153        let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
154        Tensor::from_vec(data, shape, device)
155    } else {
156        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
157        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
158        let mut c: Vec<T> = Vec::with_capacity(elem_count);
159        // SAFETY: We just created c, so the allocated memory is necessarily
160        // contiguous and non overlapping with the view's data.
161        // We're downgrading the `c` pointer from T to u8, which removes alignment
162        // constraints.
163        unsafe {
164            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
165            c.set_len(elem_count)
166        }
167        let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
168        Tensor::from_vec(c, shape, device)
169    }
170}
171
172fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
173    view: &st::TensorView<'_>,
174    device: &Device,
175    conv: F,
176) -> Result<Tensor> {
177    convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
178}
179
180fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
181    convert_slice::<T>(view.data(), view.shape(), device)
182}
183
184fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
185    let size_in_bytes = T::DTYPE.size_in_bytes();
186    let length = vs.len() * size_in_bytes;
187    let capacity = vs.capacity() * size_in_bytes;
188    let ptr = vs.as_mut_ptr() as *mut u8;
189    // Don't run the destructor for Vec<T>
190    std::mem::forget(vs);
191    // SAFETY:
192    //
193    // Every T is larger than u8, so there is no issue regarding alignment.
194    // This re-interpret the Vec<T> as a Vec<u8>.
195    unsafe { Vec::from_raw_parts(ptr, length, capacity) }
196}
197
198pub trait Load {
199    fn load(&self, device: &Device) -> Result<Tensor>;
200}
201
202impl Load for st::TensorView<'_> {
203    fn load(&self, device: &Device) -> Result<Tensor> {
204        convert(self, device)
205    }
206}
207
impl Tensor {
    /// Builds a tensor on `device` from a raw byte buffer holding elements of type `dtype`.
    ///
    /// Element-wise dtypes (`U8` … `F8E4M3`) go through `convert_slice`, which reinterprets
    /// `data` as a slice of the matching Rust type (`u8`, `u32`, …, `float8::F8E4M3`).
    ///
    /// The packed/MX dtypes (`F6E2M3`, `F6E3M2`, `F4`, `F8E8M0`) have no Rust element type
    /// here. `data` is copied verbatim into a backend storage of that dtype, and this path does
    /// not check that `data.len()` is consistent with `shape`.
    ///
    /// # Errors
    ///
    /// Returns an error if the target backend (CUDA/Metal) was not compiled in, or if the
    /// device allocation, host-to-device copy, or tensor construction fails.
    pub fn from_raw_buffer(
        data: &[u8],
        dtype: DType,
        shape: &[usize],
        device: &Device,
    ) -> Result<Self> {
        match dtype {
            DType::U8 => convert_slice::<u8>(data, shape, device),
            DType::U32 => convert_slice::<u32>(data, shape, device),
            DType::I16 => convert_slice::<i16>(data, shape, device),
            DType::I32 => convert_slice::<i32>(data, shape, device),
            DType::I64 => convert_slice::<i64>(data, shape, device),
            DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
            DType::F16 => convert_slice::<half::f16>(data, shape, device),
            DType::F32 => convert_slice::<f32>(data, shape, device),
            DType::F64 => convert_slice::<f64>(data, shape, device),
            DType::F8E4M3 => convert_slice::<float8::F8E4M3>(data, shape, device),
            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
                // Packed/MX ("dummy") dtypes: wrap the raw bytes, unchanged, in a storage
                // variant tagged with the requested dtype.
                let storage = match device {
                    Device::Cpu => {
                        let cpu_storage = match dtype {
                            DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
                            DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
                            DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()),
                            DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
                            _ => unreachable!(),
                        };
                        Storage::Cpu(cpu_storage)
                    }
                    #[cfg(feature = "cuda")]
                    Device::Cuda(device) => {
                        // NOTE(review): presumably `alloc` returns uninitialized device memory;
                        // it is filled by the `memcpy_htod` of all `data.len()` bytes just below.
                        let mut slice = unsafe { device.alloc::<u8>(data.len())? };
                        device.memcpy_htod(data, &mut slice)?;

                        let slice = match dtype {
                            DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice),
                            DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice),
                            DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice),
                            DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice),
                            _ => unreachable!(),
                        };
                        let storage = crate::cuda_backend::CudaStorage {
                            slice,
                            device: device.clone(),
                        };
                        Storage::Cuda(storage)
                    }
                    #[cfg(not(feature = "cuda"))]
                    Device::Cuda(_) => {
                        return Err(Error::Msg("CUDA support not compiled".to_string()));
                    }
                    #[cfg(feature = "metal")]
                    Device::Metal(device) => {
                        let buffer = device.new_buffer_with_data(data)?;

                        // The element count passed here is the byte length, since the data is
                        // kept as raw bytes.
                        let storage = crate::metal_backend::MetalStorage::new(
                            buffer,
                            device.clone(),
                            data.len(),
                            dtype,
                        );
                        Storage::Metal(storage)
                    }
                    #[cfg(not(feature = "metal"))]
                    Device::Metal(_) => {
                        return Err(Error::Msg("Metal support not compiled".to_string()));
                    }
                };

                // No backprop history. NOTE(review): the trailing `false` is presumably
                // `is_variable`; confirm against `from_storage`.
                let op = BackpropOp::none();
                Ok(from_storage(storage, shape, op, false))
            }
        }
    }
}
285
286fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
287    match view.dtype() {
288        st::Dtype::U8 => convert_::<u8>(view, device),
289        st::Dtype::U16 => {
290            let conv = |x| Ok(u32::from(x));
291            convert_with_cast_::<u16, u32, _>(view, device, conv)
292        }
293        st::Dtype::U32 => convert_::<u32>(view, device),
294        st::Dtype::I16 => convert_::<i16>(view, device),
295        st::Dtype::I32 => convert_::<i32>(view, device),
296        st::Dtype::I64 => convert_::<i64>(view, device),
297        st::Dtype::BF16 => convert_::<half::bf16>(view, device),
298        st::Dtype::F16 => convert_::<half::f16>(view, device),
299        st::Dtype::F32 => convert_::<f32>(view, device),
300        st::Dtype::F64 => convert_::<f64>(view, device),
301        st::Dtype::F8_E4M3 => convert_::<float8::F8E4M3>(view, device),
302        st::Dtype::F6_E2M3 | st::Dtype::F6_E3M2 | st::Dtype::F4 | st::Dtype::F8_E8M0 => {
303            // For dummy types, we need to handle loading by creating a dummy tensor
304            // Since these types don't have actual data representation, we'll create
305            // a tensor that indicates it's a dummy type
306            convert_dummy(view, device)
307        }
308        dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
309    }
310}
311
312fn convert_dummy(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
313    // For dummy types, we'll create the appropriate storage variant that preserves
314    // both the raw data and the correct dtype
315    let (dtype, _dtype_name) = match view.dtype() {
316        st::Dtype::F6_E2M3 => (DType::F6E2M3, "F6_E2M3 (MX6)"),
317        st::Dtype::F6_E3M2 => (DType::F6E3M2, "F6_E3M2 (MX6)"),
318        st::Dtype::F4 => (DType::F4, "F4 (MX4)"),
319        st::Dtype::F8_E8M0 => (DType::F8E8M0, "F8_E8M0"),
320        _ => unreachable!("convert_dummy called with non-dummy dtype"),
321    };
322
323    // Load the raw bytes
324    let data = view.data();
325    let shape = view.shape();
326
327    // Create storage with the appropriate dummy type variant
328    let storage = match device {
329        Device::Cpu => {
330            let cpu_storage = match dtype {
331                DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
332                DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
333                DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()),
334                DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
335                _ => unreachable!(),
336            };
337            Storage::Cpu(cpu_storage)
338        }
339        #[cfg(feature = "cuda")]
340        Device::Cuda(device) => {
341            let mut slice = unsafe { device.alloc::<u8>(data.len())? };
342            device.memcpy_htod(data, &mut slice)?;
343
344            let slice = match dtype {
345                DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice),
346                DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice),
347                DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice),
348                DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice),
349                _ => unreachable!(),
350            };
351            let storage = crate::cuda_backend::CudaStorage {
352                slice,
353                device: device.clone(),
354            };
355            Storage::Cuda(storage)
356        }
357        #[cfg(not(feature = "cuda"))]
358        Device::Cuda(_) => {
359            return Err(Error::Msg("CUDA support not compiled".to_string()));
360        }
361        #[cfg(feature = "metal")]
362        Device::Metal(device) => {
363            let buffer = device.new_buffer_with_data(data)?;
364
365            let storage =
366                crate::metal_backend::MetalStorage::new(buffer, device.clone(), data.len(), dtype);
367            Storage::Metal(storage)
368        }
369        #[cfg(not(feature = "metal"))]
370        Device::Metal(_) => {
371            return Err(Error::Msg("Metal support not compiled".to_string()));
372        }
373    };
374
375    // Create tensor with correct dtype
376    let op = BackpropOp::none();
377    Ok(from_storage(storage, shape, op, false))
378}
379
380fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
381    // TODO: This makes an unnecessary copy when the tensor is on the cpu.
382    let tensor = tensor.flatten_all()?;
383    match tensor.dtype() {
384        DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
385        DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
386        DType::I16 => Ok(convert_back_::<i16>(tensor.to_vec1()?)),
387        DType::I32 => Ok(convert_back_::<i32>(tensor.to_vec1()?)),
388        DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
389        DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
390        DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
391        DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
392        DType::F64 => Ok(convert_back_::<f64>(tensor.to_vec1()?)),
393        DType::F8E4M3 => Ok(convert_back_::<float8::F8E4M3>(tensor.to_vec1()?)),
394        DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
395            Err(Error::Msg("Internal error: dtype mismatch in storage".to_string()).bt())
396        }
397    }
398}
399
400pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
401    let data = std::fs::read(filename.as_ref())?;
402    load_buffer(&data[..], device)
403}
404
405pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
406    let st = safetensors::SafeTensors::deserialize(data)?;
407    st.tensors()
408        .into_iter()
409        .map(|(name, view)| Ok((name, view.load(device)?)))
410        .collect()
411}
412
413pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
414    tensors: &HashMap<K, Tensor>,
415    filename: P,
416) -> Result<()> {
417    Ok(st::serialize_to_file(tensors, None, filename.as_ref())?)
418}
419
/// Local newtype around [`SafeTensors`] so that `yoke::Yokeable` can be derived for it (a
/// derive only applies to local types). This lets a parsed header be stored next to the
/// buffer it borrows from, as `MmapedSafetensors` and `BufferedSafetensors` do.
#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);
422
423pub struct MmapedSafetensors {
424    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
425    routing: Option<HashMap<String, usize>>,
426}
427
428impl MmapedSafetensors {
429    /// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
430    ///
431    /// # Safety
432    ///
433    /// The unsafe is inherited from [`memmap2::MmapOptions`].
434    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
435        let p = p.as_ref();
436        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
437        let file = memmap2::MmapOptions::new()
438            .map(&file)
439            .map_err(|e| Error::from(e).with_path(p))?;
440        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
441            file,
442            |data: &[u8]| {
443                let st = safetensors::SafeTensors::deserialize(data)
444                    .map_err(|e| Error::from(e).with_path(p))?;
445                Ok::<_, Error>(SafeTensors_(st))
446            },
447        )?;
448        Ok(Self {
449            safetensors: vec![safetensors],
450            routing: None,
451        })
452    }
453
454    /// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
455    ///
456    /// If a tensor name appears in multiple files, the last entry is returned.
457    ///
458    /// # Safety
459    ///
460    /// The unsafe is inherited from [`memmap2::MmapOptions`].
461    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
462        let mut routing = HashMap::new();
463        let mut safetensors = vec![];
464        for (index, p) in paths.iter().enumerate() {
465            let p = p.as_ref();
466            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
467            let file = memmap2::MmapOptions::new()
468                .map(&file)
469                .map_err(|e| Error::from(e).with_path(p))?;
470            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
471                file,
472                |data: &[u8]| {
473                    let st = safetensors::SafeTensors::deserialize(data)
474                        .map_err(|e| Error::from(e).with_path(p))?;
475                    Ok::<_, Error>(SafeTensors_(st))
476                },
477            )?;
478            for k in data.get().0.names() {
479                routing.insert(k.to_string(), index);
480            }
481            safetensors.push(data)
482        }
483        Ok(Self {
484            safetensors,
485            routing: Some(routing),
486        })
487    }
488
489    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
490        self.get(name)?.load(dev)
491    }
492
493    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
494        let mut tensors = vec![];
495        for safetensors in self.safetensors.iter() {
496            tensors.push(safetensors.get().0.tensors())
497        }
498        tensors.into_iter().flatten().collect()
499    }
500
501    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
502        let index = match &self.routing {
503            None => 0,
504            Some(routing) => {
505                let index = routing.get(name).ok_or_else(|| {
506                    Error::CannotFindTensor {
507                        path: name.to_string(),
508                    }
509                    .bt()
510                })?;
511                *index
512            }
513        };
514        Ok(self.safetensors[index].get().0.tensor(name)?)
515    }
516}
517
518pub struct SliceSafetensors<'a> {
519    safetensors: SafeTensors<'a>,
520}
521
522impl<'a> SliceSafetensors<'a> {
523    /// Creates a wrapper around a binary buffer and deserialize the safetensors header.
524    pub fn new(buffer: &'a [u8]) -> Result<Self> {
525        let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
526        Ok(Self { safetensors })
527    }
528
529    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
530        self.safetensors.tensor(name)?.load(dev)
531    }
532
533    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
534        self.safetensors.tensors()
535    }
536
537    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
538        Ok(self.safetensors.tensor(name)?)
539    }
540}
541
542pub struct BufferedSafetensors {
543    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
544}
545
546impl BufferedSafetensors {
547    /// Creates a wrapper around a binary buffer and deserialize the safetensors header.
548    pub fn new(buffer: Vec<u8>) -> Result<Self> {
549        let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
550            buffer,
551            |data: &[u8]| {
552                let st = safetensors::SafeTensors::deserialize(data)?;
553                Ok::<_, Error>(SafeTensors_(st))
554            },
555        )?;
556        Ok(Self { safetensors })
557    }
558
559    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
560        self.get(name)?.load(dev)
561    }
562
563    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
564        self.safetensors.get().0.tensors()
565    }
566
567    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
568        Ok(self.safetensors.get().0.tensor(name)?)
569    }
570}
571
572pub struct MmapedFile {
573    path: std::path::PathBuf,
574    inner: memmap2::Mmap,
575}
576
577impl MmapedFile {
578    /// Creates a wrapper around a memory mapped file from which you can retrieve
579    /// tensors using [`MmapedFile::deserialize`]
580    ///
581    /// # Safety
582    ///
583    /// The unsafe is inherited from [`memmap2::MmapOptions`].
584    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
585        let p = p.as_ref();
586        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
587        let inner = memmap2::MmapOptions::new()
588            .map(&file)
589            .map_err(|e| Error::from(e).with_path(p))?;
590        Ok(Self {
591            inner,
592            path: p.to_path_buf(),
593        })
594    }
595
596    pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
597        let st = safetensors::SafeTensors::deserialize(&self.inner)
598            .map_err(|e| Error::from(e).with_path(&self.path))?;
599        Ok(st)
600    }
601}
602
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Returns a per-process path in the system temp directory, so test artifacts neither
    /// land in the crate directory nor collide with other concurrently running test binaries.
    fn tmp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("candle-{}-{name}", std::process::id()))
    }

    #[test]
    fn save_single_tensor() {
        let path = tmp_path("t.safetensors");
        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
        t.save_safetensors("t", &path).unwrap();
        let bytes = std::fs::read(&path).unwrap();
        // Clean up before asserting so a failure does not leave the file behind.
        std::fs::remove_file(&path).unwrap();
        assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}       \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
    }

    #[test]
    fn save_load_multiple_tensors() {
        let path = tmp_path("multi.safetensors");
        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
        let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
        let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
        save(&map, &path).unwrap();

        let weights = load(&path, &Device::Cpu);
        let bytes = std::fs::read(&path);
        std::fs::remove_file(&path).unwrap();

        let weights = weights.unwrap();
        assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
        assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]);
        let bytes = bytes.unwrap();
        assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}}      \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
    }

    #[test]
    fn load_u8() {
        let path = tmp_path("test_u8.safetensors");
        let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"U8\",\"shape\":[2],\"data_offsets\":[0,2]}}   \x01\x03";
        std::fs::write(&path, bytes).unwrap();
        let weights = load(&path, &Device::Cpu);
        std::fs::remove_file(&path).unwrap();

        let weights = weights.unwrap();
        let tensor = weights.get("x").unwrap();
        assert_eq!(tensor.dims(), &[2]);
        assert_eq!(tensor.dtype(), DType::U8);
        let data: Vec<u8> = tensor.to_vec1().unwrap();
        assert_eq!(data, vec![1, 3]);
    }
}