1use crate::op::BackpropOp;
13use crate::storage::Storage;
14use crate::tensor::from_storage;
15use crate::{DType, Device, Error, Result, Tensor, WithDType};
16use safetensors::tensor as st;
17use safetensors::tensor::SafeTensors;
18use std::borrow::Cow;
19use std::collections::HashMap;
20use std::path::Path;
21
22impl From<DType> for st::Dtype {
23 fn from(value: DType) -> Self {
24 match value {
25 DType::U8 => st::Dtype::U8,
26 DType::U32 => st::Dtype::U32,
27 DType::I16 => st::Dtype::I16,
28 DType::I32 => st::Dtype::I32,
29 DType::I64 => st::Dtype::I64,
30 DType::BF16 => st::Dtype::BF16,
31 DType::F16 => st::Dtype::F16,
32 DType::F32 => st::Dtype::F32,
33 DType::F64 => st::Dtype::F64,
34 DType::F8E4M3 => st::Dtype::F8_E4M3,
35 DType::F6E2M3 => st::Dtype::F6_E2M3,
36 DType::F6E3M2 => st::Dtype::F6_E3M2,
37 DType::F4 => st::Dtype::F4,
38 DType::F8E8M0 => st::Dtype::F8_E8M0,
39 }
40 }
41}
42
43impl TryFrom<st::Dtype> for DType {
44 type Error = Error;
45 fn try_from(value: st::Dtype) -> Result<Self> {
46 match value {
47 st::Dtype::U8 => Ok(DType::U8),
48 st::Dtype::U32 => Ok(DType::U32),
49 st::Dtype::I16 => Ok(DType::I16),
50 st::Dtype::I32 => Ok(DType::I32),
51 st::Dtype::I64 => Ok(DType::I64),
52 st::Dtype::BF16 => Ok(DType::BF16),
53 st::Dtype::F16 => Ok(DType::F16),
54 st::Dtype::F32 => Ok(DType::F32),
55 st::Dtype::F64 => Ok(DType::F64),
56 st::Dtype::F8_E4M3 => Ok(DType::F8E4M3),
57 st::Dtype::F6_E2M3 => Ok(DType::F6E2M3),
58 st::Dtype::F6_E3M2 => Ok(DType::F6E3M2),
59 st::Dtype::F4 => Ok(DType::F4),
60 st::Dtype::F8_E8M0 => Ok(DType::F8E8M0),
61 dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
62 }
63 }
64}
65
66impl st::View for Tensor {
67 fn dtype(&self) -> st::Dtype {
68 self.dtype().into()
69 }
70 fn shape(&self) -> &[usize] {
71 self.shape().dims()
72 }
73
74 fn data(&self) -> Cow<'_, [u8]> {
75 Cow::Owned(convert_back(self).unwrap())
78 }
79
80 fn data_len(&self) -> usize {
81 let n: usize = self.shape().elem_count();
82 let bytes_per_element = self.dtype().size_in_bytes();
83 n * bytes_per_element
84 }
85}
86
87impl st::View for &Tensor {
88 fn dtype(&self) -> st::Dtype {
89 (*self).dtype().into()
90 }
91 fn shape(&self) -> &[usize] {
92 self.dims()
93 }
94
95 fn data(&self) -> Cow<'_, [u8]> {
96 Cow::Owned(convert_back(self).unwrap())
99 }
100
101 fn data_len(&self) -> usize {
102 let n: usize = self.dims().iter().product();
103 let bytes_per_element = (*self).dtype().size_in_bytes();
104 n * bytes_per_element
105 }
106}
107
108impl Tensor {
109 pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
110 let data = [(name, self.clone())];
111 Ok(st::serialize_to_file(data, None, filename.as_ref())?)
112 }
113}
114
115fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
116 let size_in_bytes = T::DTYPE.size_in_bytes();
117 let elem_count = data.len() / size_in_bytes;
118 if (data.as_ptr() as usize).is_multiple_of(size_in_bytes) {
119 let data: &[T] =
122 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
123 Tensor::from_slice(data, shape, device)
124 } else {
125 let mut c: Vec<T> = Vec::with_capacity(elem_count);
128 unsafe {
133 std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
134 c.set_len(elem_count)
135 }
136 Tensor::from_slice(&c, shape, device)
137 }
138}
139
140fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
141 data: &[u8],
142 shape: &[usize],
143 device: &Device,
144 conv: F,
145) -> Result<Tensor> {
146 let size_in_bytes = std::mem::size_of::<T>();
147 let elem_count = data.len() / size_in_bytes;
148 if (data.as_ptr() as usize).is_multiple_of(size_in_bytes) {
149 let data: &[T] =
152 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
153 let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
154 Tensor::from_vec(data, shape, device)
155 } else {
156 let mut c: Vec<T> = Vec::with_capacity(elem_count);
159 unsafe {
164 std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
165 c.set_len(elem_count)
166 }
167 let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
168 Tensor::from_vec(c, shape, device)
169 }
170}
171
172fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
173 view: &st::TensorView<'_>,
174 device: &Device,
175 conv: F,
176) -> Result<Tensor> {
177 convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
178}
179
180fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
181 convert_slice::<T>(view.data(), view.shape(), device)
182}
183
184fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
185 let size_in_bytes = T::DTYPE.size_in_bytes();
186 let length = vs.len() * size_in_bytes;
187 let capacity = vs.capacity() * size_in_bytes;
188 let ptr = vs.as_mut_ptr() as *mut u8;
189 std::mem::forget(vs);
191 unsafe { Vec::from_raw_parts(ptr, length, capacity) }
196}
197
198pub trait Load {
199 fn load(&self, device: &Device) -> Result<Tensor>;
200}
201
202impl Load for st::TensorView<'_> {
203 fn load(&self, device: &Device) -> Result<Tensor> {
204 convert(self, device)
205 }
206}
207
208impl Tensor {
209 pub fn from_raw_buffer(
210 data: &[u8],
211 dtype: DType,
212 shape: &[usize],
213 device: &Device,
214 ) -> Result<Self> {
215 match dtype {
216 DType::U8 => convert_slice::<u8>(data, shape, device),
217 DType::U32 => convert_slice::<u32>(data, shape, device),
218 DType::I16 => convert_slice::<i16>(data, shape, device),
219 DType::I32 => convert_slice::<i32>(data, shape, device),
220 DType::I64 => convert_slice::<i64>(data, shape, device),
221 DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
222 DType::F16 => convert_slice::<half::f16>(data, shape, device),
223 DType::F32 => convert_slice::<f32>(data, shape, device),
224 DType::F64 => convert_slice::<f64>(data, shape, device),
225 DType::F8E4M3 => convert_slice::<float8::F8E4M3>(data, shape, device),
226 DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
227 let storage = match device {
229 Device::Cpu => {
230 let cpu_storage = match dtype {
231 DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
232 DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
233 DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()),
234 DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
235 _ => unreachable!(),
236 };
237 Storage::Cpu(cpu_storage)
238 }
239 #[cfg(feature = "cuda")]
240 Device::Cuda(device) => {
241 let mut slice = unsafe { device.alloc::<u8>(data.len())? };
242 device.memcpy_htod(data, &mut slice)?;
243
244 let slice = match dtype {
245 DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice),
246 DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice),
247 DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice),
248 DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice),
249 _ => unreachable!(),
250 };
251 let storage = crate::cuda_backend::CudaStorage {
252 slice,
253 device: device.clone(),
254 };
255 Storage::Cuda(storage)
256 }
257 #[cfg(not(feature = "cuda"))]
258 Device::Cuda(_) => {
259 return Err(Error::Msg("CUDA support not compiled".to_string()));
260 }
261 #[cfg(feature = "metal")]
262 Device::Metal(device) => {
263 let buffer = device.new_buffer_with_data(data)?;
264
265 let storage = crate::metal_backend::MetalStorage::new(
266 buffer,
267 device.clone(),
268 data.len(),
269 dtype,
270 );
271 Storage::Metal(storage)
272 }
273 #[cfg(feature = "rocm")]
274 Device::Rocm(_) => crate::bail!("not supported on rocm yet"),
275 #[cfg(feature = "vulkan")]
276 Device::Vulkan(_) => crate::bail!("not supported on vulkan yet"),
277 #[cfg(not(feature = "metal"))]
278 Device::Metal(_) => {
279 return Err(Error::Msg("Metal support not compiled".to_string()));
280 }
281 };
282
283 let op = BackpropOp::none();
284 Ok(from_storage(storage, shape, op, false))
285 }
286 }
287 }
288}
289
290fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
291 match view.dtype() {
292 st::Dtype::U8 => convert_::<u8>(view, device),
293 st::Dtype::U16 => {
294 let conv = |x| Ok(u32::from(x));
295 convert_with_cast_::<u16, u32, _>(view, device, conv)
296 }
297 st::Dtype::U32 => convert_::<u32>(view, device),
298 st::Dtype::I16 => convert_::<i16>(view, device),
299 st::Dtype::I32 => convert_::<i32>(view, device),
300 st::Dtype::I64 => convert_::<i64>(view, device),
301 st::Dtype::BF16 => convert_::<half::bf16>(view, device),
302 st::Dtype::F16 => convert_::<half::f16>(view, device),
303 st::Dtype::F32 => convert_::<f32>(view, device),
304 st::Dtype::F64 => convert_::<f64>(view, device),
305 st::Dtype::F8_E4M3 => convert_::<float8::F8E4M3>(view, device),
306 st::Dtype::F6_E2M3 | st::Dtype::F6_E3M2 | st::Dtype::F4 | st::Dtype::F8_E8M0 => {
307 convert_dummy(view, device)
311 }
312 dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
313 }
314}
315
316fn convert_dummy(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
317 let (dtype, _dtype_name) = match view.dtype() {
320 st::Dtype::F6_E2M3 => (DType::F6E2M3, "F6_E2M3 (MX6)"),
321 st::Dtype::F6_E3M2 => (DType::F6E3M2, "F6_E3M2 (MX6)"),
322 st::Dtype::F4 => (DType::F4, "F4 (MX4)"),
323 st::Dtype::F8_E8M0 => (DType::F8E8M0, "F8_E8M0"),
324 _ => unreachable!("convert_dummy called with non-dummy dtype"),
325 };
326
327 let data = view.data();
329 let shape = view.shape();
330
331 let storage = match device {
333 Device::Cpu => {
334 let cpu_storage = match dtype {
335 DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
336 DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
337 DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()),
338 DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
339 _ => unreachable!(),
340 };
341 Storage::Cpu(cpu_storage)
342 }
343 #[cfg(feature = "cuda")]
344 Device::Cuda(device) => {
345 let mut slice = unsafe { device.alloc::<u8>(data.len())? };
346 device.memcpy_htod(data, &mut slice)?;
347
348 let slice = match dtype {
349 DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice),
350 DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice),
351 DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice),
352 DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice),
353 _ => unreachable!(),
354 };
355 let storage = crate::cuda_backend::CudaStorage {
356 slice,
357 device: device.clone(),
358 };
359 Storage::Cuda(storage)
360 }
361 #[cfg(not(feature = "cuda"))]
362 Device::Cuda(_) => {
363 return Err(Error::Msg("CUDA support not compiled".to_string()));
364 }
365 #[cfg(feature = "metal")]
366 Device::Metal(device) => {
367 let buffer = device.new_buffer_with_data(data)?;
368
369 let storage =
370 crate::metal_backend::MetalStorage::new(buffer, device.clone(), data.len(), dtype);
371 Storage::Metal(storage)
372 }
373 #[cfg(feature = "rocm")]
374 Device::Rocm(_) => crate::bail!("not supported on rocm yet"),
375 #[cfg(feature = "vulkan")]
376 Device::Vulkan(_) => crate::bail!("not supported on vulkan yet"),
377 #[cfg(not(feature = "metal"))]
378 Device::Metal(_) => {
379 return Err(Error::Msg("Metal support not compiled".to_string()));
380 }
381 };
382
383 let op = BackpropOp::none();
385 Ok(from_storage(storage, shape, op, false))
386}
387
388fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
389 let tensor = tensor.flatten_all()?;
391 match tensor.dtype() {
392 DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
393 DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
394 DType::I16 => Ok(convert_back_::<i16>(tensor.to_vec1()?)),
395 DType::I32 => Ok(convert_back_::<i32>(tensor.to_vec1()?)),
396 DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
397 DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
398 DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
399 DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
400 DType::F64 => Ok(convert_back_::<f64>(tensor.to_vec1()?)),
401 DType::F8E4M3 => Ok(convert_back_::<float8::F8E4M3>(tensor.to_vec1()?)),
402 DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
403 Err(Error::Msg("Internal error: dtype mismatch in storage".to_string()).bt())
404 }
405 }
406}
407
408pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
409 let data = std::fs::read(filename.as_ref())?;
410 load_buffer(&data[..], device)
411}
412
413pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
414 let st = safetensors::SafeTensors::deserialize(data)?;
415 st.tensors()
416 .into_iter()
417 .map(|(name, view)| Ok((name, view.load(device)?)))
418 .collect()
419}
420
421pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
422 tensors: &HashMap<K, Tensor>,
423 filename: P,
424) -> Result<()> {
425 Ok(st::serialize_to_file(tensors, None, filename.as_ref())?)
426}
427
428#[derive(yoke::Yokeable)]
429struct SafeTensors_<'a>(SafeTensors<'a>);
430
431pub struct MmapedSafetensors {
432 safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
433 routing: Option<HashMap<String, usize>>,
434}
435
436impl MmapedSafetensors {
437 pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
443 let p = p.as_ref();
444 let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
445 let file = memmap2::MmapOptions::new()
446 .map(&file)
447 .map_err(|e| Error::from(e).with_path(p))?;
448 let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
449 file,
450 |data: &[u8]| {
451 let st = safetensors::SafeTensors::deserialize(data)
452 .map_err(|e| Error::from(e).with_path(p))?;
453 Ok::<_, Error>(SafeTensors_(st))
454 },
455 )?;
456 Ok(Self {
457 safetensors: vec![safetensors],
458 routing: None,
459 })
460 }
461
462 pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
470 let mut routing = HashMap::new();
471 let mut safetensors = vec![];
472 for (index, p) in paths.iter().enumerate() {
473 let p = p.as_ref();
474 let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
475 let file = memmap2::MmapOptions::new()
476 .map(&file)
477 .map_err(|e| Error::from(e).with_path(p))?;
478 let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
479 file,
480 |data: &[u8]| {
481 let st = safetensors::SafeTensors::deserialize(data)
482 .map_err(|e| Error::from(e).with_path(p))?;
483 Ok::<_, Error>(SafeTensors_(st))
484 },
485 )?;
486 for k in data.get().0.names() {
487 routing.insert(k.to_string(), index);
488 }
489 safetensors.push(data)
490 }
491 Ok(Self {
492 safetensors,
493 routing: Some(routing),
494 })
495 }
496
497 pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
498 self.get(name)?.load(dev)
499 }
500
501 pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
502 let mut tensors = vec![];
503 for safetensors in self.safetensors.iter() {
504 tensors.push(safetensors.get().0.tensors())
505 }
506 tensors.into_iter().flatten().collect()
507 }
508
509 pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
510 let index = match &self.routing {
511 None => 0,
512 Some(routing) => {
513 let index = routing.get(name).ok_or_else(|| {
514 Error::CannotFindTensor {
515 path: name.to_string(),
516 }
517 .bt()
518 })?;
519 *index
520 }
521 };
522 Ok(self.safetensors[index].get().0.tensor(name)?)
523 }
524}
525
526pub struct SliceSafetensors<'a> {
527 safetensors: SafeTensors<'a>,
528}
529
530impl<'a> SliceSafetensors<'a> {
531 pub fn new(buffer: &'a [u8]) -> Result<Self> {
533 let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
534 Ok(Self { safetensors })
535 }
536
537 pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
538 self.safetensors.tensor(name)?.load(dev)
539 }
540
541 pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
542 self.safetensors.tensors()
543 }
544
545 pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
546 Ok(self.safetensors.tensor(name)?)
547 }
548}
549
550pub struct BufferedSafetensors {
551 safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
552}
553
554impl BufferedSafetensors {
555 pub fn new(buffer: Vec<u8>) -> Result<Self> {
557 let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
558 buffer,
559 |data: &[u8]| {
560 let st = safetensors::SafeTensors::deserialize(data)?;
561 Ok::<_, Error>(SafeTensors_(st))
562 },
563 )?;
564 Ok(Self { safetensors })
565 }
566
567 pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
568 self.get(name)?.load(dev)
569 }
570
571 pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
572 self.safetensors.get().0.tensors()
573 }
574
575 pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
576 Ok(self.safetensors.get().0.tensor(name)?)
577 }
578}
579
580pub struct MmapedFile {
581 path: std::path::PathBuf,
582 inner: memmap2::Mmap,
583}
584
585impl MmapedFile {
586 pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
593 let p = p.as_ref();
594 let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
595 let inner = memmap2::MmapOptions::new()
596 .map(&file)
597 .map_err(|e| Error::from(e).with_path(p))?;
598 Ok(Self {
599 inner,
600 path: p.to_path_buf(),
601 })
602 }
603
604 pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
605 let st = safetensors::SafeTensors::deserialize(&self.inner)
606 .map_err(|e| Error::from(e).with_path(&self.path))?;
607 Ok(st)
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::path::PathBuf;

    /// Scratch file location in the system temp directory, prefixed with the
    /// process id so the tests neither pollute the working directory nor
    /// collide with another test process.
    fn scratch_path(file_name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{}-{file_name}", std::process::id()))
    }

    /// Bytes of a safetensors file made of `header_json` and `payload`: an
    /// 8-byte little-endian header length, the JSON header padded with spaces
    /// to a multiple of 8 bytes, then the tensor data.
    ///
    /// Building the expectation this way keeps the padding explicit instead
    /// of relying on an exact run of spaces inside a byte-string literal.
    fn file_bytes(header_json: &str, payload: &[u8]) -> Vec<u8> {
        let header_len = header_json.len().next_multiple_of(8);
        let mut bytes = (header_len as u64).to_le_bytes().to_vec();
        bytes.extend_from_slice(header_json.as_bytes());
        bytes.resize(8 + header_len, b' ');
        bytes.extend_from_slice(payload);
        bytes
    }

    #[test]
    fn save_single_tensor() {
        let path = scratch_path("t.safetensors");
        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
        t.save_safetensors("t", &path).unwrap();
        let bytes = std::fs::read(&path).unwrap();
        // Clean up before asserting so a failure does not leave the file behind.
        std::fs::remove_file(&path).unwrap();

        let expected = file_bytes(
            r#"{"t":{"dtype":"F32","shape":[2,2],"data_offsets":[0,16]}}"#,
            &[0u8; 16],
        );
        // The padded header is 64 bytes long, i.e. the file starts with b'@'.
        assert_eq!(expected[0], b'@');
        assert_eq!(bytes, expected);
    }

    #[test]
    fn save_load_multiple_tensors() {
        let path = scratch_path("multi.safetensors");
        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
        let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
        let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
        save(&map, &path).unwrap();

        let weights = load(&path, &Device::Cpu).unwrap();
        let bytes = std::fs::read(&path).unwrap();
        // Clean up before asserting so a failure does not leave the file behind.
        std::fs::remove_file(&path).unwrap();

        assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
        assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]);
        let expected = file_bytes(
            r#"{"t":{"dtype":"F32","shape":[2,2],"data_offsets":[0,16]},"u":{"dtype":"F32","shape":[1,2],"data_offsets":[16,24]}}"#,
            &[0u8; 24],
        );
        // The padded header is 120 bytes long, i.e. the file starts with b'x'.
        assert_eq!(expected[0], b'x');
        assert_eq!(bytes, expected);
    }

    #[test]
    fn load_u8() {
        let path = scratch_path("test_u8.safetensors");
        let bytes = file_bytes(
            r#"{"x":{"dtype":"U8","shape":[2],"data_offsets":[0,2]}}"#,
            &[1, 3],
        );
        // The padded header is 56 bytes long, i.e. the file starts with b'8'.
        assert_eq!(bytes[0], b'8');
        std::fs::write(&path, &bytes).unwrap();
        let weights = load(&path, &Device::Cpu).unwrap();
        // Clean up before asserting so a failure does not leave the file behind.
        std::fs::remove_file(&path).unwrap();

        let tensor = weights.get("x").unwrap();
        assert_eq!(tensor.dims(), &[2]);
        assert_eq!(tensor.dtype(), DType::U8);
        let data: Vec<u8> = tensor.to_vec1().unwrap();
        assert_eq!(data, vec![1, 3]);
    }
}