// candle_core/safetensors.rs

1//! Module to load `safetensor` files into CPU/GPU memory.
2//!
3//! There are multiple ways to load tensors from safetensor files:
4//! - `load` function for loading directly into memory and returning a HashMap of tensors
5//! - `MmapedSafetensors` for memory mapping files and avoiding full allocation
6//! - `SliceSafetensors` for working with in-memory buffers
7//! - `BufferedSafetensors` for owning a buffer of data
8//!
9//! Tensors can also be serialized to safetensor format using the `save` function or
10//! `Tensor::save_safetensors` method.
11//!
12use crate::{DType, Device, Error, Result, Tensor, WithDType};
13use safetensors::tensor as st;
14use safetensors::tensor::SafeTensors;
15use std::borrow::Cow;
16use std::collections::HashMap;
17use std::path::Path;
18
/// Maps a candle [`DType`] to its safetensors equivalent.
///
/// This conversion is total: every candle dtype has a direct safetensors
/// counterpart (the reverse direction, [`TryFrom`], is fallible).
impl From<DType> for st::Dtype {
    fn from(value: DType) -> Self {
        match value {
            DType::U8 => st::Dtype::U8,
            DType::U32 => st::Dtype::U32,
            DType::I64 => st::Dtype::I64,
            DType::BF16 => st::Dtype::BF16,
            DType::F16 => st::Dtype::F16,
            DType::F32 => st::Dtype::F32,
            DType::F64 => st::Dtype::F64,
        }
    }
}
32
/// Maps a safetensors dtype back to a candle [`DType`].
///
/// Fails with [`Error::UnsupportedSafeTensorDtype`] for dtypes candle has no
/// direct equivalent for. Note that the `convert` loader below still accepts
/// U16/I32 data by widening it rather than going through this conversion.
impl TryFrom<st::Dtype> for DType {
    type Error = Error;
    fn try_from(value: st::Dtype) -> Result<Self> {
        match value {
            st::Dtype::U8 => Ok(DType::U8),
            st::Dtype::U32 => Ok(DType::U32),
            st::Dtype::I64 => Ok(DType::I64),
            st::Dtype::BF16 => Ok(DType::BF16),
            st::Dtype::F16 => Ok(DType::F16),
            st::Dtype::F32 => Ok(DType::F32),
            st::Dtype::F64 => Ok(DType::F64),
            dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
        }
    }
}
48
49impl st::View for Tensor {
50    fn dtype(&self) -> st::Dtype {
51        self.dtype().into()
52    }
53    fn shape(&self) -> &[usize] {
54        self.shape().dims()
55    }
56
57    fn data(&self) -> Cow<[u8]> {
58        // This copies data from GPU to CPU.
59        // TODO: Avoid the unwrap here.
60        Cow::Owned(convert_back(self).unwrap())
61    }
62
63    fn data_len(&self) -> usize {
64        let n: usize = self.shape().elem_count();
65        let bytes_per_element = self.dtype().size_in_bytes();
66        n * bytes_per_element
67    }
68}
69
70impl st::View for &Tensor {
71    fn dtype(&self) -> st::Dtype {
72        (*self).dtype().into()
73    }
74    fn shape(&self) -> &[usize] {
75        self.dims()
76    }
77
78    fn data(&self) -> Cow<[u8]> {
79        // This copies data from GPU to CPU.
80        // TODO: Avoid the unwrap here.
81        Cow::Owned(convert_back(self).unwrap())
82    }
83
84    fn data_len(&self) -> usize {
85        let n: usize = self.dims().iter().product();
86        let bytes_per_element = (*self).dtype().size_in_bytes();
87        n * bytes_per_element
88    }
89}
90
91impl Tensor {
92    pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
93        let data = [(name, self.clone())];
94        Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
95    }
96}
97
98fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
99    let size_in_bytes = T::DTYPE.size_in_bytes();
100    let elem_count = data.len() / size_in_bytes;
101    if (data.as_ptr() as usize) % size_in_bytes == 0 {
102        // SAFETY This is safe because we just checked that this
103        // was correctly aligned.
104        let data: &[T] =
105            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
106        Tensor::from_slice(data, shape, device)
107    } else {
108        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
109        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
110        let mut c: Vec<T> = Vec::with_capacity(elem_count);
111        // SAFETY: We just created c, so the allocated memory is necessarily
112        // contiguous and non overlapping with the view's data.
113        // We're downgrading the `c` pointer from T to u8, which removes alignment
114        // constraints.
115        unsafe {
116            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
117            c.set_len(elem_count)
118        }
119        Tensor::from_slice(&c, shape, device)
120    }
121}
122
123fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
124    data: &[u8],
125    shape: &[usize],
126    device: &Device,
127    conv: F,
128) -> Result<Tensor> {
129    let size_in_bytes = std::mem::size_of::<T>();
130    let elem_count = data.len() / size_in_bytes;
131    if (data.as_ptr() as usize) % size_in_bytes == 0 {
132        // SAFETY This is safe because we just checked that this
133        // was correctly aligned.
134        let data: &[T] =
135            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
136        let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
137        Tensor::from_vec(data, shape, device)
138    } else {
139        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
140        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
141        let mut c: Vec<T> = Vec::with_capacity(elem_count);
142        // SAFETY: We just created c, so the allocated memory is necessarily
143        // contiguous and non overlapping with the view's data.
144        // We're downgrading the `c` pointer from T to u8, which removes alignment
145        // constraints.
146        unsafe {
147            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
148            c.set_len(elem_count)
149        }
150        let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
151        Tensor::from_vec(c, shape, device)
152    }
153}
154
/// Loads a safetensors view whose dtype candle does not support natively,
/// applying `conv` element-wise to map `T` values into a supported dtype `U`.
fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    view: &st::TensorView<'_>,
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
}
162
/// Loads a safetensors view whose dtype `T` candle supports directly.
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    convert_slice::<T>(view.data(), view.shape(), device)
}
166
/// Reinterprets an owned `Vec<T>` as a `Vec<u8>` over the same allocation,
/// without copying the data.
fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    // Both length and capacity are rescaled so the returned Vec describes the
    // exact same allocation it took over.
    let length = vs.len() * size_in_bytes;
    let capacity = vs.capacity() * size_in_bytes;
    let ptr = vs.as_mut_ptr() as *mut u8;
    // Don't run the destructor for Vec<T>: ownership of the allocation is
    // transferred to the Vec<u8> built below.
    std::mem::forget(vs);
    // SAFETY:
    //
    // Every T is larger than u8, so there is no issue regarding alignment.
    // This re-interpret the Vec<T> as a Vec<u8>.
    unsafe { Vec::from_raw_parts(ptr, length, capacity) }
}
180
/// Types whose data can be materialized as a [`Tensor`] on a given device.
pub trait Load {
    /// Creates a tensor on `device` from `self`'s data.
    fn load(&self, device: &Device) -> Result<Tensor>;
}
184
impl Load for st::TensorView<'_> {
    fn load(&self, device: &Device) -> Result<Tensor> {
        // Dispatches on the view's dtype, widening u16/i32 where needed.
        convert(self, device)
    }
}
190
impl Tensor {
    /// Builds a tensor from a raw byte buffer holding `dtype` elements.
    ///
    /// The buffer does not need to be aligned for `dtype`: unaligned data is
    /// copied into an aligned allocation first (see `convert_slice`).
    pub fn from_raw_buffer(
        data: &[u8],
        dtype: DType,
        shape: &[usize],
        device: &Device,
    ) -> Result<Self> {
        match dtype {
            DType::U8 => convert_slice::<u8>(data, shape, device),
            DType::U32 => convert_slice::<u32>(data, shape, device),
            DType::I64 => convert_slice::<i64>(data, shape, device),
            DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
            DType::F16 => convert_slice::<half::f16>(data, shape, device),
            DType::F32 => convert_slice::<f32>(data, shape, device),
            DType::F64 => convert_slice::<f64>(data, shape, device),
        }
    }
}
209
210fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
211    match view.dtype() {
212        st::Dtype::U8 => convert_::<u8>(view, device),
213        st::Dtype::U16 => {
214            let conv = |x| Ok(u32::from(x));
215            convert_with_cast_::<u16, u32, _>(view, device, conv)
216        }
217        st::Dtype::U32 => convert_::<u32>(view, device),
218        st::Dtype::I32 => {
219            let conv = |x| Ok(i64::from(x));
220            convert_with_cast_::<i32, i64, _>(view, device, conv)
221        }
222        st::Dtype::I64 => convert_::<i64>(view, device),
223        st::Dtype::BF16 => convert_::<half::bf16>(view, device),
224        st::Dtype::F16 => convert_::<half::f16>(view, device),
225        st::Dtype::F32 => convert_::<f32>(view, device),
226        st::Dtype::F64 => convert_::<f64>(view, device),
227        dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
228    }
229}
230
/// Flattens `tensor` and returns its elements as raw bytes, as needed by the
/// `st::View` impls above for serialization.
fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
    // TODO: This makes an unnecessary copy when the tensor is on the cpu.
    let tensor = tensor.flatten_all()?;
    match tensor.dtype() {
        DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
        DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
        DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
        DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
        DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
        DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
        DType::F64 => Ok(convert_back_::<f64>(tensor.to_vec1()?)),
    }
}
244
245pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
246    let data = std::fs::read(filename.as_ref())?;
247    load_buffer(&data[..], device)
248}
249
250pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
251    let st = safetensors::SafeTensors::deserialize(data)?;
252    st.tensors()
253        .into_iter()
254        .map(|(name, view)| Ok((name, view.load(device)?)))
255        .collect()
256}
257
/// Saves all tensors in `tensors` to `filename` in safetensors format.
///
/// The `Ord` bound on keys is required by `st::serialize_to_file`.
pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
    tensors: &HashMap<K, Tensor>,
    filename: P,
) -> Result<()> {
    Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
264
// Newtype so we can derive `yoke::Yokeable`, letting a borrowed `SafeTensors`
// view be stored alongside the buffer (mmap or Vec) it borrows from.
#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);

/// One or more memory-mapped safetensors files, with tensors loaded on demand.
pub struct MmapedSafetensors {
    // One entry per mapped file; each yoke owns its mmap plus the parsed header.
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    // Maps tensor name -> index into `safetensors`; `None` in the single-file
    // case, where all lookups go to index 0.
    routing: Option<HashMap<String, usize>>,
}
272
273impl MmapedSafetensors {
    /// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`]: per its contract,
    /// the underlying file must not be modified while the map is alive.
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let file = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        // `yoke` ties the lifetime of the parsed header to the mmap that owns
        // the bytes, so both can live in the same struct.
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
            file,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)
                    .map_err(|e| Error::from(e).with_path(p))?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self {
            safetensors: vec![safetensors],
            // Single file: no routing table needed, lookups use index 0.
            routing: None,
        })
    }
298
    /// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
    ///
    /// If a tensor name appears in multiple files, the last entry is returned.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`]: per its contract,
    /// the mapped files must not be modified while the maps are alive.
    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
        let mut routing = HashMap::new();
        let mut safetensors = vec![];
        for (index, p) in paths.iter().enumerate() {
            let p = p.as_ref();
            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
            let file = memmap2::MmapOptions::new()
                .map(&file)
                .map_err(|e| Error::from(e).with_path(p))?;
            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
                file,
                |data: &[u8]| {
                    let st = safetensors::SafeTensors::deserialize(data)
                        .map_err(|e| Error::from(e).with_path(p))?;
                    Ok::<_, Error>(SafeTensors_(st))
                },
            )?;
            // Insertion overwrites earlier entries, so when a name appears in
            // several files the last file wins (matching the doc above).
            for k in data.get().0.names() {
                routing.insert(k.to_string(), index);
            }
            safetensors.push(data)
        }
        Ok(Self {
            safetensors,
            routing: Some(routing),
        })
    }
333
    /// Loads the tensor named `name` onto device `dev`.
    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        self.get(name)?.load(dev)
    }
337
338    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
339        let mut tensors = vec![];
340        for safetensors in self.safetensors.iter() {
341            tensors.push(safetensors.get().0.tensors())
342        }
343        tensors.into_iter().flatten().collect()
344    }
345
346    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
347        let index = match &self.routing {
348            None => 0,
349            Some(routing) => {
350                let index = routing.get(name).ok_or_else(|| {
351                    Error::CannotFindTensor {
352                        path: name.to_string(),
353                    }
354                    .bt()
355                })?;
356                *index
357            }
358        };
359        Ok(self.safetensors[index].get().0.tensor(name)?)
360    }
361}
362
/// Safetensors data borrowed from an in-memory byte slice (no copy, no mmap).
pub struct SliceSafetensors<'a> {
    safetensors: SafeTensors<'a>,
}
366
367impl<'a> SliceSafetensors<'a> {
368    /// Creates a wrapper around a binary buffer and deserialize the safetensors header.
369    pub fn new(buffer: &'a [u8]) -> Result<Self> {
370        let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
371        Ok(Self { safetensors })
372    }
373
374    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
375        self.safetensors.tensor(name)?.load(dev)
376    }
377
378    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
379        self.safetensors.tensors()
380    }
381
382    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
383        Ok(self.safetensors.tensor(name)?)
384    }
385}
386
/// Safetensors data that owns its backing byte buffer.
pub struct BufferedSafetensors {
    // The yoke stores the buffer together with the header view borrowing it.
    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}
390
391impl BufferedSafetensors {
392    /// Creates a wrapper around a binary buffer and deserialize the safetensors header.
393    pub fn new(buffer: Vec<u8>) -> Result<Self> {
394        let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
395            buffer,
396            |data: &[u8]| {
397                let st = safetensors::SafeTensors::deserialize(data)?;
398                Ok::<_, Error>(SafeTensors_(st))
399            },
400        )?;
401        Ok(Self { safetensors })
402    }
403
404    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
405        self.get(name)?.load(dev)
406    }
407
408    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
409        self.safetensors.get().0.tensors()
410    }
411
412    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
413        Ok(self.safetensors.get().0.tensor(name)?)
414    }
415}
416
/// A memory-mapped safetensors file whose header is parsed on demand.
pub struct MmapedFile {
    // Kept so error messages can include the originating path.
    path: std::path::PathBuf,
    inner: memmap2::Mmap,
}
421
422impl MmapedFile {
423    /// Creates a wrapper around a memory mapped file from which you can retrieve
424    /// tensors using [`MmapedFile::deserialize`]
425    ///
426    /// # Safety
427    ///
428    /// The unsafe is inherited from [`memmap2::MmapOptions`].
429    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
430        let p = p.as_ref();
431        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
432        let inner = memmap2::MmapOptions::new()
433            .map(&file)
434            .map_err(|e| Error::from(e).with_path(p))?;
435        Ok(Self {
436            inner,
437            path: p.to_path_buf(),
438        })
439    }
440
441    pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
442        let st = safetensors::SafeTensors::deserialize(&self.inner)
443            .map_err(|e| Error::from(e).with_path(&self.path))?;
444        Ok(st)
445    }
446}
447
448#[cfg(test)]
449mod tests {
450    use super::*;
451    use std::collections::HashMap;
452
453    #[test]
454    fn save_single_tensor() {
455        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
456        t.save_safetensors("t", "t.safetensors").unwrap();
457        let bytes = std::fs::read("t.safetensors").unwrap();
458        assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}       \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
459        std::fs::remove_file("t.safetensors").unwrap();
460    }
461
462    #[test]
463    fn save_load_multiple_tensors() {
464        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
465        let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
466        let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
467        save(&map, "multi.safetensors").unwrap();
468
469        let weights = load("multi.safetensors", &Device::Cpu).unwrap();
470        assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
471        assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]);
472        let bytes = std::fs::read("multi.safetensors").unwrap();
473        assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}}      \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
474        std::fs::remove_file("multi.safetensors").unwrap();
475    }
476}