1use crate::{DType, Device, Error, Result, Tensor, WithDType};
13use safetensors::tensor as st;
14use safetensors::tensor::SafeTensors;
15use std::borrow::Cow;
16use std::collections::HashMap;
17use std::path::Path;
18
19impl From<DType> for st::Dtype {
20 fn from(value: DType) -> Self {
21 match value {
22 DType::U8 => st::Dtype::U8,
23 DType::U32 => st::Dtype::U32,
24 DType::I64 => st::Dtype::I64,
25 DType::BF16 => st::Dtype::BF16,
26 DType::F16 => st::Dtype::F16,
27 DType::F32 => st::Dtype::F32,
28 DType::F64 => st::Dtype::F64,
29 }
30 }
31}
32
33impl TryFrom<st::Dtype> for DType {
34 type Error = Error;
35 fn try_from(value: st::Dtype) -> Result<Self> {
36 match value {
37 st::Dtype::U8 => Ok(DType::U8),
38 st::Dtype::U32 => Ok(DType::U32),
39 st::Dtype::I64 => Ok(DType::I64),
40 st::Dtype::BF16 => Ok(DType::BF16),
41 st::Dtype::F16 => Ok(DType::F16),
42 st::Dtype::F32 => Ok(DType::F32),
43 st::Dtype::F64 => Ok(DType::F64),
44 dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
45 }
46 }
47}
48
49impl st::View for Tensor {
50 fn dtype(&self) -> st::Dtype {
51 self.dtype().into()
52 }
53 fn shape(&self) -> &[usize] {
54 self.shape().dims()
55 }
56
57 fn data(&self) -> Cow<[u8]> {
58 Cow::Owned(convert_back(self).unwrap())
61 }
62
63 fn data_len(&self) -> usize {
64 let n: usize = self.shape().elem_count();
65 let bytes_per_element = self.dtype().size_in_bytes();
66 n * bytes_per_element
67 }
68}
69
70impl st::View for &Tensor {
71 fn dtype(&self) -> st::Dtype {
72 (*self).dtype().into()
73 }
74 fn shape(&self) -> &[usize] {
75 self.dims()
76 }
77
78 fn data(&self) -> Cow<[u8]> {
79 Cow::Owned(convert_back(self).unwrap())
82 }
83
84 fn data_len(&self) -> usize {
85 let n: usize = self.dims().iter().product();
86 let bytes_per_element = (*self).dtype().size_in_bytes();
87 n * bytes_per_element
88 }
89}
90
91impl Tensor {
92 pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
93 let data = [(name, self.clone())];
94 Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
95 }
96}
97
98fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
99 let size_in_bytes = T::DTYPE.size_in_bytes();
100 let elem_count = data.len() / size_in_bytes;
101 if (data.as_ptr() as usize) % size_in_bytes == 0 {
102 let data: &[T] =
105 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
106 Tensor::from_slice(data, shape, device)
107 } else {
108 let mut c: Vec<T> = Vec::with_capacity(elem_count);
111 unsafe {
116 std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
117 c.set_len(elem_count)
118 }
119 Tensor::from_slice(&c, shape, device)
120 }
121}
122
123fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
124 data: &[u8],
125 shape: &[usize],
126 device: &Device,
127 conv: F,
128) -> Result<Tensor> {
129 let size_in_bytes = std::mem::size_of::<T>();
130 let elem_count = data.len() / size_in_bytes;
131 if (data.as_ptr() as usize) % size_in_bytes == 0 {
132 let data: &[T] =
135 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
136 let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
137 Tensor::from_vec(data, shape, device)
138 } else {
139 let mut c: Vec<T> = Vec::with_capacity(elem_count);
142 unsafe {
147 std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
148 c.set_len(elem_count)
149 }
150 let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
151 Tensor::from_vec(c, shape, device)
152 }
153}
154
155fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
156 view: &st::TensorView<'_>,
157 device: &Device,
158 conv: F,
159) -> Result<Tensor> {
160 convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
161}
162
163fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
164 convert_slice::<T>(view.data(), view.shape(), device)
165}
166
167fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
168 let size_in_bytes = T::DTYPE.size_in_bytes();
169 let length = vs.len() * size_in_bytes;
170 let capacity = vs.capacity() * size_in_bytes;
171 let ptr = vs.as_mut_ptr() as *mut u8;
172 std::mem::forget(vs);
174 unsafe { Vec::from_raw_parts(ptr, length, capacity) }
179}
180
181pub trait Load {
182 fn load(&self, device: &Device) -> Result<Tensor>;
183}
184
185impl Load for st::TensorView<'_> {
186 fn load(&self, device: &Device) -> Result<Tensor> {
187 convert(self, device)
188 }
189}
190
191impl Tensor {
192 pub fn from_raw_buffer(
193 data: &[u8],
194 dtype: DType,
195 shape: &[usize],
196 device: &Device,
197 ) -> Result<Self> {
198 match dtype {
199 DType::U8 => convert_slice::<u8>(data, shape, device),
200 DType::U32 => convert_slice::<u32>(data, shape, device),
201 DType::I64 => convert_slice::<i64>(data, shape, device),
202 DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
203 DType::F16 => convert_slice::<half::f16>(data, shape, device),
204 DType::F32 => convert_slice::<f32>(data, shape, device),
205 DType::F64 => convert_slice::<f64>(data, shape, device),
206 }
207 }
208}
209
210fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
211 match view.dtype() {
212 st::Dtype::U8 => convert_::<u8>(view, device),
213 st::Dtype::U16 => {
214 let conv = |x| Ok(u32::from(x));
215 convert_with_cast_::<u16, u32, _>(view, device, conv)
216 }
217 st::Dtype::U32 => convert_::<u32>(view, device),
218 st::Dtype::I32 => {
219 let conv = |x| Ok(i64::from(x));
220 convert_with_cast_::<i32, i64, _>(view, device, conv)
221 }
222 st::Dtype::I64 => convert_::<i64>(view, device),
223 st::Dtype::BF16 => convert_::<half::bf16>(view, device),
224 st::Dtype::F16 => convert_::<half::f16>(view, device),
225 st::Dtype::F32 => convert_::<f32>(view, device),
226 st::Dtype::F64 => convert_::<f64>(view, device),
227 dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
228 }
229}
230
231fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
232 let tensor = tensor.flatten_all()?;
234 match tensor.dtype() {
235 DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
236 DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
237 DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
238 DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
239 DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
240 DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
241 DType::F64 => Ok(convert_back_::<f64>(tensor.to_vec1()?)),
242 }
243}
244
245pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
246 let data = std::fs::read(filename.as_ref())?;
247 load_buffer(&data[..], device)
248}
249
250pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
251 let st = safetensors::SafeTensors::deserialize(data)?;
252 st.tensors()
253 .into_iter()
254 .map(|(name, view)| Ok((name, view.load(device)?)))
255 .collect()
256}
257
258pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
259 tensors: &HashMap<K, Tensor>,
260 filename: P,
261) -> Result<()> {
262 Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
263}
264
265#[derive(yoke::Yokeable)]
266struct SafeTensors_<'a>(SafeTensors<'a>);
267
268pub struct MmapedSafetensors {
269 safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
270 routing: Option<HashMap<String, usize>>,
271}
272
273impl MmapedSafetensors {
274 pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
280 let p = p.as_ref();
281 let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
282 let file = memmap2::MmapOptions::new()
283 .map(&file)
284 .map_err(|e| Error::from(e).with_path(p))?;
285 let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
286 file,
287 |data: &[u8]| {
288 let st = safetensors::SafeTensors::deserialize(data)
289 .map_err(|e| Error::from(e).with_path(p))?;
290 Ok::<_, Error>(SafeTensors_(st))
291 },
292 )?;
293 Ok(Self {
294 safetensors: vec![safetensors],
295 routing: None,
296 })
297 }
298
299 pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
307 let mut routing = HashMap::new();
308 let mut safetensors = vec![];
309 for (index, p) in paths.iter().enumerate() {
310 let p = p.as_ref();
311 let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
312 let file = memmap2::MmapOptions::new()
313 .map(&file)
314 .map_err(|e| Error::from(e).with_path(p))?;
315 let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
316 file,
317 |data: &[u8]| {
318 let st = safetensors::SafeTensors::deserialize(data)
319 .map_err(|e| Error::from(e).with_path(p))?;
320 Ok::<_, Error>(SafeTensors_(st))
321 },
322 )?;
323 for k in data.get().0.names() {
324 routing.insert(k.to_string(), index);
325 }
326 safetensors.push(data)
327 }
328 Ok(Self {
329 safetensors,
330 routing: Some(routing),
331 })
332 }
333
334 pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
335 self.get(name)?.load(dev)
336 }
337
338 pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
339 let mut tensors = vec![];
340 for safetensors in self.safetensors.iter() {
341 tensors.push(safetensors.get().0.tensors())
342 }
343 tensors.into_iter().flatten().collect()
344 }
345
346 pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
347 let index = match &self.routing {
348 None => 0,
349 Some(routing) => {
350 let index = routing.get(name).ok_or_else(|| {
351 Error::CannotFindTensor {
352 path: name.to_string(),
353 }
354 .bt()
355 })?;
356 *index
357 }
358 };
359 Ok(self.safetensors[index].get().0.tensor(name)?)
360 }
361}
362
363pub struct SliceSafetensors<'a> {
364 safetensors: SafeTensors<'a>,
365}
366
367impl<'a> SliceSafetensors<'a> {
368 pub fn new(buffer: &'a [u8]) -> Result<Self> {
370 let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
371 Ok(Self { safetensors })
372 }
373
374 pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
375 self.safetensors.tensor(name)?.load(dev)
376 }
377
378 pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
379 self.safetensors.tensors()
380 }
381
382 pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
383 Ok(self.safetensors.tensor(name)?)
384 }
385}
386
387pub struct BufferedSafetensors {
388 safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
389}
390
391impl BufferedSafetensors {
392 pub fn new(buffer: Vec<u8>) -> Result<Self> {
394 let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
395 buffer,
396 |data: &[u8]| {
397 let st = safetensors::SafeTensors::deserialize(data)?;
398 Ok::<_, Error>(SafeTensors_(st))
399 },
400 )?;
401 Ok(Self { safetensors })
402 }
403
404 pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
405 self.get(name)?.load(dev)
406 }
407
408 pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
409 self.safetensors.get().0.tensors()
410 }
411
412 pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
413 Ok(self.safetensors.get().0.tensor(name)?)
414 }
415}
416
417pub struct MmapedFile {
418 path: std::path::PathBuf,
419 inner: memmap2::Mmap,
420}
421
422impl MmapedFile {
423 pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
430 let p = p.as_ref();
431 let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
432 let inner = memmap2::MmapOptions::new()
433 .map(&file)
434 .map_err(|e| Error::from(e).with_path(p))?;
435 Ok(Self {
436 inner,
437 path: p.to_path_buf(),
438 })
439 }
440
441 pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
442 let st = safetensors::SafeTensors::deserialize(&self.inner)
443 .map_err(|e| Error::from(e).with_path(&self.path))?;
444 Ok(st)
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Returns a per-process path in the system temp directory. This keeps tests from writing
    /// into the working directory and from colliding with other concurrently running test
    /// binaries.
    fn temp_path(file_name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("candle-{}-{}", std::process::id(), file_name))
    }

    #[test]
    fn save_single_tensor() {
        let path = temp_path("t.safetensors");
        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
        t.save_safetensors("t", &path).unwrap();
        let bytes = std::fs::read(&path).unwrap();
        // Remove the file before asserting so a failing comparison does not leave it behind.
        std::fs::remove_file(&path).unwrap();
        assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
    }

    #[test]
    fn save_load_multiple_tensors() {
        let path = temp_path("multi.safetensors");
        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
        let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
        let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
        save(&map, &path).unwrap();

        let weights = load(&path, &Device::Cpu).unwrap();
        let bytes = std::fs::read(&path).unwrap();
        // Everything needed has been read; clean up before any assertion can fail.
        std::fs::remove_file(&path).unwrap();

        assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
        assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]);
        assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
    }
}