Skip to main content

hanzo_ml/
safetensors.rs

1//! Module to load `safetensor` files into CPU/GPU memory.
2//!
3//! There are multiple ways to load tensors from safetensor files:
4//! - `load` function for loading directly into memory and returning a HashMap of tensors
5//! - `MmapedSafetensors` for memory mapping files and avoiding full allocation
6//! - `SliceSafetensors` for working with in-memory buffers
7//! - `BufferedSafetensors` for owning a buffer of data
8//!
9//! Tensors can also be serialized to safetensor format using the `save` function or
10//! `Tensor::save_safetensors` method.
11//!
12use crate::op::BackpropOp;
13use crate::storage::Storage;
14use crate::tensor::from_storage;
15use crate::{DType, Device, Error, Result, Tensor, WithDType};
16use safetensors::tensor as st;
17use safetensors::tensor::SafeTensors;
18use std::borrow::Cow;
19use std::collections::HashMap;
20use std::path::Path;
21
22impl From<DType> for st::Dtype {
23    fn from(value: DType) -> Self {
24        match value {
25            DType::U8 => st::Dtype::U8,
26            DType::U32 => st::Dtype::U32,
27            DType::I16 => st::Dtype::I16,
28            DType::I32 => st::Dtype::I32,
29            DType::I64 => st::Dtype::I64,
30            DType::BF16 => st::Dtype::BF16,
31            DType::F16 => st::Dtype::F16,
32            DType::F32 => st::Dtype::F32,
33            DType::F64 => st::Dtype::F64,
34            DType::F8E4M3 => st::Dtype::F8_E4M3,
35            DType::F6E2M3 => st::Dtype::F6_E2M3,
36            DType::F6E3M2 => st::Dtype::F6_E3M2,
37            DType::F4 => st::Dtype::F4,
38            DType::F8E8M0 => st::Dtype::F8_E8M0,
39        }
40    }
41}
42
43impl TryFrom<st::Dtype> for DType {
44    type Error = Error;
45    fn try_from(value: st::Dtype) -> Result<Self> {
46        match value {
47            st::Dtype::U8 => Ok(DType::U8),
48            st::Dtype::U32 => Ok(DType::U32),
49            st::Dtype::I16 => Ok(DType::I16),
50            st::Dtype::I32 => Ok(DType::I32),
51            st::Dtype::I64 => Ok(DType::I64),
52            st::Dtype::BF16 => Ok(DType::BF16),
53            st::Dtype::F16 => Ok(DType::F16),
54            st::Dtype::F32 => Ok(DType::F32),
55            st::Dtype::F64 => Ok(DType::F64),
56            st::Dtype::F8_E4M3 => Ok(DType::F8E4M3),
57            st::Dtype::F6_E2M3 => Ok(DType::F6E2M3),
58            st::Dtype::F6_E3M2 => Ok(DType::F6E3M2),
59            st::Dtype::F4 => Ok(DType::F4),
60            st::Dtype::F8_E8M0 => Ok(DType::F8E8M0),
61            dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
62        }
63    }
64}
65
66impl st::View for Tensor {
67    fn dtype(&self) -> st::Dtype {
68        self.dtype().into()
69    }
70    fn shape(&self) -> &[usize] {
71        self.shape().dims()
72    }
73
74    fn data(&self) -> Cow<'_, [u8]> {
75        // This copies data from GPU to CPU.
76        // TODO: Avoid the unwrap here.
77        Cow::Owned(convert_back(self).unwrap())
78    }
79
80    fn data_len(&self) -> usize {
81        let n: usize = self.shape().elem_count();
82        let bytes_per_element = self.dtype().size_in_bytes();
83        n * bytes_per_element
84    }
85}
86
87impl st::View for &Tensor {
88    fn dtype(&self) -> st::Dtype {
89        (*self).dtype().into()
90    }
91    fn shape(&self) -> &[usize] {
92        self.dims()
93    }
94
95    fn data(&self) -> Cow<'_, [u8]> {
96        // This copies data from GPU to CPU.
97        // TODO: Avoid the unwrap here.
98        Cow::Owned(convert_back(self).unwrap())
99    }
100
101    fn data_len(&self) -> usize {
102        let n: usize = self.dims().iter().product();
103        let bytes_per_element = (*self).dtype().size_in_bytes();
104        n * bytes_per_element
105    }
106}
107
108impl Tensor {
109    pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
110        let data = [(name, self.clone())];
111        Ok(st::serialize_to_file(data, None, filename.as_ref())?)
112    }
113}
114
115fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
116    let size_in_bytes = T::DTYPE.size_in_bytes();
117    let elem_count = data.len() / size_in_bytes;
118    if (data.as_ptr() as usize).is_multiple_of(size_in_bytes) {
119        // SAFETY This is safe because we just checked that this
120        // was correctly aligned.
121        let data: &[T] =
122            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
123        Tensor::from_slice(data, shape, device)
124    } else {
125        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
126        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
127        let mut c: Vec<T> = Vec::with_capacity(elem_count);
128        // SAFETY: We just created c, so the allocated memory is necessarily
129        // contiguous and non overlapping with the view's data.
130        // We're downgrading the `c` pointer from T to u8, which removes alignment
131        // constraints.
132        unsafe {
133            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
134            c.set_len(elem_count)
135        }
136        Tensor::from_slice(&c, shape, device)
137    }
138}
139
140fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
141    data: &[u8],
142    shape: &[usize],
143    device: &Device,
144    conv: F,
145) -> Result<Tensor> {
146    let size_in_bytes = std::mem::size_of::<T>();
147    let elem_count = data.len() / size_in_bytes;
148    if (data.as_ptr() as usize).is_multiple_of(size_in_bytes) {
149        // SAFETY This is safe because we just checked that this
150        // was correctly aligned.
151        let data: &[T] =
152            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
153        let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
154        Tensor::from_vec(data, shape, device)
155    } else {
156        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
157        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
158        let mut c: Vec<T> = Vec::with_capacity(elem_count);
159        // SAFETY: We just created c, so the allocated memory is necessarily
160        // contiguous and non overlapping with the view's data.
161        // We're downgrading the `c` pointer from T to u8, which removes alignment
162        // constraints.
163        unsafe {
164            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
165            c.set_len(elem_count)
166        }
167        let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
168        Tensor::from_vec(c, shape, device)
169    }
170}
171
172fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
173    view: &st::TensorView<'_>,
174    device: &Device,
175    conv: F,
176) -> Result<Tensor> {
177    convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
178}
179
180fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
181    convert_slice::<T>(view.data(), view.shape(), device)
182}
183
184fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
185    let size_in_bytes = T::DTYPE.size_in_bytes();
186    let length = vs.len() * size_in_bytes;
187    let capacity = vs.capacity() * size_in_bytes;
188    let ptr = vs.as_mut_ptr() as *mut u8;
189    // Don't run the destructor for Vec<T>
190    std::mem::forget(vs);
191    // SAFETY:
192    //
193    // Every T is larger than u8, so there is no issue regarding alignment.
194    // This re-interpret the Vec<T> as a Vec<u8>.
195    unsafe { Vec::from_raw_parts(ptr, length, capacity) }
196}
197
/// Conversion of a serialized tensor representation into a [`Tensor`] on a chosen device.
pub trait Load {
    /// Materializes `self` as a tensor on `device`.
    fn load(&self, device: &Device) -> Result<Tensor>;
}
201
// Loading a safetensors view dispatches on its stored dtype via `convert`.
impl Load for st::TensorView<'_> {
    fn load(&self, device: &Device) -> Result<Tensor> {
        convert(self, device)
    }
}
207
impl Tensor {
    /// Builds a tensor of `dtype` and `shape` on `device` from a raw byte buffer.
    ///
    /// For the numeric dtypes, `data` is reinterpreted as packed elements in native byte
    /// order (copied first if it is not suitably aligned). The elements are passed to
    /// `Tensor::from_slice` together with `shape`.
    ///
    /// For the "dummy" dtypes (`F6E2M3`, `F6E3M2`, `F4`, `F8E8M0`), the bytes are stored
    /// verbatim in the matching storage variant of the target device. On that path, this
    /// function itself does not compare `data.len()` with `shape`.
    ///
    /// # Errors
    ///
    /// For dummy dtypes, CUDA and Metal devices error when the corresponding cargo feature
    /// is not compiled in. ROCm and Vulkan devices are rejected as not supported yet.
    /// Device allocation and copy errors are propagated.
    pub fn from_raw_buffer(
        data: &[u8],
        dtype: DType,
        shape: &[usize],
        device: &Device,
    ) -> Result<Self> {
        match dtype {
            DType::U8 => convert_slice::<u8>(data, shape, device),
            DType::U32 => convert_slice::<u32>(data, shape, device),
            DType::I16 => convert_slice::<i16>(data, shape, device),
            DType::I32 => convert_slice::<i32>(data, shape, device),
            DType::I64 => convert_slice::<i64>(data, shape, device),
            DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
            DType::F16 => convert_slice::<half::f16>(data, shape, device),
            DType::F32 => convert_slice::<f32>(data, shape, device),
            DType::F64 => convert_slice::<f64>(data, shape, device),
            DType::F8E4M3 => convert_slice::<float8::F8E4M3>(data, shape, device),
            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
                // For dummy types, create storage with raw bytes
                let storage = match device {
                    Device::Cpu => {
                        // The CPU storage owns a copy of the raw bytes, tagged with the dtype variant.
                        let cpu_storage = match dtype {
                            DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
                            DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
                            DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()),
                            DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
                            _ => unreachable!(),
                        };
                        Storage::Cpu(cpu_storage)
                    }
                    #[cfg(feature = "cuda")]
                    Device::Cuda(device) => {
                        // NOTE(review): `alloc` presumably returns uninitialized device memory;
                        // it is fully overwritten by the `data.len()`-byte copy below — confirm.
                        let mut slice = unsafe { device.alloc::<u8>(data.len())? };
                        device.memcpy_htod(data, &mut slice)?;

                        let slice = match dtype {
                            DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice),
                            DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice),
                            DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice),
                            DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice),
                            _ => unreachable!(),
                        };
                        let storage = crate::cuda_backend::CudaStorage {
                            slice,
                            device: device.clone(),
                        };
                        Storage::Cuda(storage)
                    }
                    #[cfg(not(feature = "cuda"))]
                    Device::Cuda(_) => {
                        return Err(Error::Msg("CUDA support not compiled".to_string()));
                    }
                    #[cfg(feature = "metal")]
                    Device::Metal(device) => {
                        let buffer = device.new_buffer_with_data(data)?;

                        // The Metal storage records the byte count and the dummy dtype explicitly.
                        let storage = crate::metal_backend::MetalStorage::new(
                            buffer,
                            device.clone(),
                            data.len(),
                            dtype,
                        );
                        Storage::Metal(storage)
                    }
                    #[cfg(feature = "rocm")]
                    Device::Rocm(_) => crate::bail!("not supported on rocm yet"),
                    #[cfg(feature = "vulkan")]
                    Device::Vulkan(_) => crate::bail!("not supported on vulkan yet"),
                    #[cfg(not(feature = "metal"))]
                    Device::Metal(_) => {
                        return Err(Error::Msg("Metal support not compiled".to_string()));
                    }
                };

                // No backprop graph for freshly loaded data.
                // NOTE(review): the trailing `false` is presumably the "is variable" flag — confirm
                // against `from_storage`.
                let op = BackpropOp::none();
                Ok(from_storage(storage, shape, op, false))
            }
        }
    }
}
289
290fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
291    match view.dtype() {
292        st::Dtype::U8 => convert_::<u8>(view, device),
293        st::Dtype::U16 => {
294            let conv = |x| Ok(u32::from(x));
295            convert_with_cast_::<u16, u32, _>(view, device, conv)
296        }
297        st::Dtype::U32 => convert_::<u32>(view, device),
298        st::Dtype::I16 => convert_::<i16>(view, device),
299        st::Dtype::I32 => convert_::<i32>(view, device),
300        st::Dtype::I64 => convert_::<i64>(view, device),
301        st::Dtype::BF16 => convert_::<half::bf16>(view, device),
302        st::Dtype::F16 => convert_::<half::f16>(view, device),
303        st::Dtype::F32 => convert_::<f32>(view, device),
304        st::Dtype::F64 => convert_::<f64>(view, device),
305        st::Dtype::F8_E4M3 => convert_::<float8::F8E4M3>(view, device),
306        st::Dtype::F6_E2M3 | st::Dtype::F6_E3M2 | st::Dtype::F4 | st::Dtype::F8_E8M0 => {
307            // For dummy types, we need to handle loading by creating a dummy tensor
308            // Since these types don't have actual data representation, we'll create
309            // a tensor that indicates it's a dummy type
310            convert_dummy(view, device)
311        }
312        dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
313    }
314}
315
316fn convert_dummy(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
317    // For dummy types, we'll create the appropriate storage variant that preserves
318    // both the raw data and the correct dtype
319    let (dtype, _dtype_name) = match view.dtype() {
320        st::Dtype::F6_E2M3 => (DType::F6E2M3, "F6_E2M3 (MX6)"),
321        st::Dtype::F6_E3M2 => (DType::F6E3M2, "F6_E3M2 (MX6)"),
322        st::Dtype::F4 => (DType::F4, "F4 (MX4)"),
323        st::Dtype::F8_E8M0 => (DType::F8E8M0, "F8_E8M0"),
324        _ => unreachable!("convert_dummy called with non-dummy dtype"),
325    };
326
327    // Load the raw bytes
328    let data = view.data();
329    let shape = view.shape();
330
331    // Create storage with the appropriate dummy type variant
332    let storage = match device {
333        Device::Cpu => {
334            let cpu_storage = match dtype {
335                DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
336                DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
337                DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()),
338                DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
339                _ => unreachable!(),
340            };
341            Storage::Cpu(cpu_storage)
342        }
343        #[cfg(feature = "cuda")]
344        Device::Cuda(device) => {
345            let mut slice = unsafe { device.alloc::<u8>(data.len())? };
346            device.memcpy_htod(data, &mut slice)?;
347
348            let slice = match dtype {
349                DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice),
350                DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice),
351                DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice),
352                DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice),
353                _ => unreachable!(),
354            };
355            let storage = crate::cuda_backend::CudaStorage {
356                slice,
357                device: device.clone(),
358            };
359            Storage::Cuda(storage)
360        }
361        #[cfg(not(feature = "cuda"))]
362        Device::Cuda(_) => {
363            return Err(Error::Msg("CUDA support not compiled".to_string()));
364        }
365        #[cfg(feature = "metal")]
366        Device::Metal(device) => {
367            let buffer = device.new_buffer_with_data(data)?;
368
369            let storage =
370                crate::metal_backend::MetalStorage::new(buffer, device.clone(), data.len(), dtype);
371            Storage::Metal(storage)
372        }
373        #[cfg(feature = "rocm")]
374        Device::Rocm(_) => crate::bail!("not supported on rocm yet"),
375        #[cfg(feature = "vulkan")]
376        Device::Vulkan(_) => crate::bail!("not supported on vulkan yet"),
377        #[cfg(not(feature = "metal"))]
378        Device::Metal(_) => {
379            return Err(Error::Msg("Metal support not compiled".to_string()));
380        }
381    };
382
383    // Create tensor with correct dtype
384    let op = BackpropOp::none();
385    Ok(from_storage(storage, shape, op, false))
386}
387
388fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
389    // TODO: This makes an unnecessary copy when the tensor is on the cpu.
390    let tensor = tensor.flatten_all()?;
391    match tensor.dtype() {
392        DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
393        DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
394        DType::I16 => Ok(convert_back_::<i16>(tensor.to_vec1()?)),
395        DType::I32 => Ok(convert_back_::<i32>(tensor.to_vec1()?)),
396        DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
397        DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
398        DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
399        DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
400        DType::F64 => Ok(convert_back_::<f64>(tensor.to_vec1()?)),
401        DType::F8E4M3 => Ok(convert_back_::<float8::F8E4M3>(tensor.to_vec1()?)),
402        DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
403            Err(Error::Msg("Internal error: dtype mismatch in storage".to_string()).bt())
404        }
405    }
406}
407
408pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
409    let data = std::fs::read(filename.as_ref())?;
410    load_buffer(&data[..], device)
411}
412
413pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
414    let st = safetensors::SafeTensors::deserialize(data)?;
415    st.tensors()
416        .into_iter()
417        .map(|(name, view)| Ok((name, view.load(device)?)))
418        .collect()
419}
420
421pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
422    tensors: &HashMap<K, Tensor>,
423    filename: P,
424) -> Result<()> {
425    Ok(st::serialize_to_file(tensors, None, filename.as_ref())?)
426}
427
/// Local newtype around [`SafeTensors`] deriving `Yokeable`, so a parsed header can be stored
/// in a `yoke::Yoke` together with the buffer it borrows from (an mmap or a `Vec<u8>` below).
#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);
430
/// Tensors read lazily from one or more memory-mapped safetensors files.
pub struct MmapedSafetensors {
    // One parsed header per mapped file, each yoked to (borrowing from) its own mmap.
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    // Tensor name -> index into `safetensors`; `None` when built from a single file, in
    // which case lookups go to index 0.
    routing: Option<HashMap<String, usize>>,
}
435
436impl MmapedSafetensors {
437    /// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
438    ///
439    /// # Safety
440    ///
441    /// The unsafe is inherited from [`memmap2::MmapOptions`].
442    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
443        let p = p.as_ref();
444        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
445        let file = memmap2::MmapOptions::new()
446            .map(&file)
447            .map_err(|e| Error::from(e).with_path(p))?;
448        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
449            file,
450            |data: &[u8]| {
451                let st = safetensors::SafeTensors::deserialize(data)
452                    .map_err(|e| Error::from(e).with_path(p))?;
453                Ok::<_, Error>(SafeTensors_(st))
454            },
455        )?;
456        Ok(Self {
457            safetensors: vec![safetensors],
458            routing: None,
459        })
460    }
461
462    /// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
463    ///
464    /// If a tensor name appears in multiple files, the last entry is returned.
465    ///
466    /// # Safety
467    ///
468    /// The unsafe is inherited from [`memmap2::MmapOptions`].
469    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
470        let mut routing = HashMap::new();
471        let mut safetensors = vec![];
472        for (index, p) in paths.iter().enumerate() {
473            let p = p.as_ref();
474            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
475            let file = memmap2::MmapOptions::new()
476                .map(&file)
477                .map_err(|e| Error::from(e).with_path(p))?;
478            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
479                file,
480                |data: &[u8]| {
481                    let st = safetensors::SafeTensors::deserialize(data)
482                        .map_err(|e| Error::from(e).with_path(p))?;
483                    Ok::<_, Error>(SafeTensors_(st))
484                },
485            )?;
486            for k in data.get().0.names() {
487                routing.insert(k.to_string(), index);
488            }
489            safetensors.push(data)
490        }
491        Ok(Self {
492            safetensors,
493            routing: Some(routing),
494        })
495    }
496
497    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
498        self.get(name)?.load(dev)
499    }
500
501    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
502        let mut tensors = vec![];
503        for safetensors in self.safetensors.iter() {
504            tensors.push(safetensors.get().0.tensors())
505        }
506        tensors.into_iter().flatten().collect()
507    }
508
509    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
510        let index = match &self.routing {
511            None => 0,
512            Some(routing) => {
513                let index = routing.get(name).ok_or_else(|| {
514                    Error::CannotFindTensor {
515                        path: name.to_string(),
516                    }
517                    .bt()
518                })?;
519                *index
520            }
521        };
522        Ok(self.safetensors[index].get().0.tensor(name)?)
523    }
524}
525
/// Safetensors header parsed from a caller-owned byte slice; tensor views borrow from it.
pub struct SliceSafetensors<'a> {
    safetensors: SafeTensors<'a>,
}
529
530impl<'a> SliceSafetensors<'a> {
531    /// Creates a wrapper around a binary buffer and deserialize the safetensors header.
532    pub fn new(buffer: &'a [u8]) -> Result<Self> {
533        let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
534        Ok(Self { safetensors })
535    }
536
537    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
538        self.safetensors.tensor(name)?.load(dev)
539    }
540
541    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
542        self.safetensors.tensors()
543    }
544
545    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
546        Ok(self.safetensors.tensor(name)?)
547    }
548}
549
/// Safetensors data held in an owned `Vec<u8>`, with the parsed header yoked to that buffer.
pub struct BufferedSafetensors {
    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}
553
554impl BufferedSafetensors {
555    /// Creates a wrapper around a binary buffer and deserialize the safetensors header.
556    pub fn new(buffer: Vec<u8>) -> Result<Self> {
557        let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
558            buffer,
559            |data: &[u8]| {
560                let st = safetensors::SafeTensors::deserialize(data)?;
561                Ok::<_, Error>(SafeTensors_(st))
562            },
563        )?;
564        Ok(Self { safetensors })
565    }
566
567    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
568        self.get(name)?.load(dev)
569    }
570
571    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
572        self.safetensors.get().0.tensors()
573    }
574
575    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
576        Ok(self.safetensors.get().0.tensor(name)?)
577    }
578}
579
/// A memory-mapped file whose safetensors header is parsed on demand with
/// [`MmapedFile::deserialize`].
pub struct MmapedFile {
    // Path the file was opened from; used to tag deserialization errors.
    path: std::path::PathBuf,
    // The mapped file contents.
    inner: memmap2::Mmap,
}
584
585impl MmapedFile {
586    /// Creates a wrapper around a memory mapped file from which you can retrieve
587    /// tensors using [`MmapedFile::deserialize`]
588    ///
589    /// # Safety
590    ///
591    /// The unsafe is inherited from [`memmap2::MmapOptions`].
592    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
593        let p = p.as_ref();
594        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
595        let inner = memmap2::MmapOptions::new()
596            .map(&file)
597            .map_err(|e| Error::from(e).with_path(p))?;
598        Ok(Self {
599            inner,
600            path: p.to_path_buf(),
601        })
602    }
603
604    pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
605        let st = safetensors::SafeTensors::deserialize(&self.inner)
606            .map_err(|e| Error::from(e).with_path(&self.path))?;
607        Ok(st)
608    }
609}
610
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Per-process path in the system temp directory, so tests neither write into the working
    /// directory nor collide with another concurrently running copy of the test binary.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("hanzo_ml_{}_{name}", std::process::id()))
    }

    #[test]
    fn save_single_tensor() {
        let path = temp_path("t.safetensors");
        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
        t.save_safetensors("t", &path).unwrap();
        let bytes = std::fs::read(&path).unwrap();
        // Clean up before asserting so a failing assertion does not leak the file.
        std::fs::remove_file(&path).unwrap();
        assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}       \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
    }

    #[test]
    fn save_load_multiple_tensors() {
        let path = temp_path("multi.safetensors");
        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
        let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
        let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
        save(&map, &path).unwrap();

        let weights = load(&path, &Device::Cpu).unwrap();
        let bytes = std::fs::read(&path).unwrap();
        std::fs::remove_file(&path).unwrap();

        assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
        assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]);
        assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}}      \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
    }

    #[test]
    fn load_u8() {
        let path = temp_path("test_u8.safetensors");
        let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"U8\",\"shape\":[2],\"data_offsets\":[0,2]}}   \x01\x03";
        std::fs::write(&path, bytes).unwrap();
        let weights = load(&path, &Device::Cpu);
        std::fs::remove_file(&path).unwrap();

        let weights = weights.unwrap();
        let tensor = weights.get("x").unwrap();
        assert_eq!(tensor.dims(), &[2]);
        assert_eq!(tensor.dtype(), DType::U8);
        let data: Vec<u8> = tensor.to_vec1().unwrap();
        assert_eq!(data, vec![1, 3]);
    }
}