use crate::op::BackpropOp;
use crate::storage::Storage;
use crate::tensor::from_storage;
use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
use safetensors::tensor::SafeTensors;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::Path;

impl From<DType> for st::Dtype {
    fn from(value: DType) -> Self {
        match value {
            DType::U8 => st::Dtype::U8,
            DType::U32 => st::Dtype::U32,
            DType::I16 => st::Dtype::I16,
            DType::I32 => st::Dtype::I32,
            DType::I64 => st::Dtype::I64,
            DType::BF16 => st::Dtype::BF16,
            DType::F16 => st::Dtype::F16,
            DType::F32 => st::Dtype::F32,
            DType::F64 => st::Dtype::F64,
            DType::F8E4M3 => st::Dtype::F8_E4M3,
            DType::F6E2M3 => st::Dtype::F6_E2M3,
            DType::F6E3M2 => st::Dtype::F6_E3M2,
            DType::F4 => st::Dtype::F4,
            DType::F8E8M0 => st::Dtype::F8_E8M0,
        }
    }
}

impl TryFrom<st::Dtype> for DType {
    type Error = Error;
    fn try_from(value: st::Dtype) -> Result<Self> {
        match value {
            st::Dtype::U8 => Ok(DType::U8),
            st::Dtype::U32 => Ok(DType::U32),
            st::Dtype::I16 => Ok(DType::I16),
            st::Dtype::I32 => Ok(DType::I32),
            st::Dtype::I64 => Ok(DType::I64),
            st::Dtype::BF16 => Ok(DType::BF16),
            st::Dtype::F16 => Ok(DType::F16),
            st::Dtype::F32 => Ok(DType::F32),
            st::Dtype::F64 => Ok(DType::F64),
            st::Dtype::F8_E4M3 => Ok(DType::F8E4M3),
            st::Dtype::F6_E2M3 => Ok(DType::F6E2M3),
            st::Dtype::F6_E3M2 => Ok(DType::F6E3M2),
            st::Dtype::F4 => Ok(DType::F4),
            st::Dtype::F8_E8M0 => Ok(DType::F8E8M0),
            dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
        }
    }
}

impl st::View for Tensor {
    fn dtype(&self) -> st::Dtype {
        self.dtype().into()
    }
    fn shape(&self) -> &[usize] {
        self.shape().dims()
    }

    fn data(&self) -> Cow<'_, [u8]> {
        // This copies the data: the copy is needed as the tensor may not be
        // contiguous or may not live in host memory.
        Cow::Owned(convert_back(self).unwrap())
    }

    fn data_len(&self) -> usize {
        let n: usize = self.shape().elem_count();
        let bytes_per_element = self.dtype().size_in_bytes();
        n * bytes_per_element
    }
}

impl st::View for &Tensor {
    fn dtype(&self) -> st::Dtype {
        (*self).dtype().into()
    }
    fn shape(&self) -> &[usize] {
        self.dims()
    }

    fn data(&self) -> Cow<'_, [u8]> {
        // Same as above: the data is copied into a freshly allocated buffer.
        Cow::Owned(convert_back(self).unwrap())
    }

    fn data_len(&self) -> usize {
        let n: usize = self.dims().iter().product();
        let bytes_per_element = (*self).dtype().size_in_bytes();
        n * bytes_per_element
    }
}

impl Tensor {
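    /// Saves this tensor under the key `name` in a new safetensors file at `filename`.
    ///
    /// A minimal usage sketch; the crate path in the `use` line is an assumption, adjust it to
    /// this crate's actual name:
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    ///
    /// let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?;
    /// t.save_safetensors("t", "t.safetensors")?;
    /// # Ok::<(), candle_core::Error>(())
    /// ```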
    pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
        let data = [(name, self.clone())];
        Ok(st::serialize_to_file(data, None, filename.as_ref())?)
    }
}

fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize).is_multiple_of(size_in_bytes) {
        // SAFETY: the pointer is properly aligned for `T` (checked above) and the slice
        // covers `elem_count` complete elements of the original byte buffer.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        Tensor::from_slice(data, shape, device)
    } else {
        // The byte buffer is not aligned for `T`, so copy it into a freshly allocated
        // (and therefore properly aligned) vector first.
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: `c` was just allocated with enough capacity, does not overlap with
        // `data`, and `set_len` only exposes elements that have been fully written.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count)
        }
        Tensor::from_slice(&c, shape, device)
    }
}

fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    data: &[u8],
    shape: &[usize],
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    let size_in_bytes = std::mem::size_of::<T>();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize).is_multiple_of(size_in_bytes) {
        // SAFETY: alignment has been checked and the slice stays within the buffer.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(data, shape, device)
    } else {
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: `c` was just allocated with enough capacity and does not overlap with
        // `data`.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count)
        }
        let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(c, shape, device)
    }
}

fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    view: &st::TensorView<'_>,
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
}

fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    convert_slice::<T>(view.data(), view.shape(), device)
}

fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let length = vs.len() * size_in_bytes;
    let capacity = vs.capacity() * size_in_bytes;
    let ptr = vs.as_mut_ptr() as *mut u8;
    // Don't run the destructor for Vec<T>: ownership of the allocation is handed over
    // to the Vec<u8> built below, which becomes responsible for freeing it.
    std::mem::forget(vs);
    // SAFETY: the pointer, length and capacity all come from the original vector,
    // rescaled from elements of `T` to bytes, and u8 has no alignment requirement.
    unsafe { Vec::from_raw_parts(ptr, length, capacity) }
}

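/// Helper trait for loading a safetensors [`st::TensorView`] into a [`Tensor`] on a given device.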
pub trait Load {
    fn load(&self, device: &Device) -> Result<Tensor>;
}

impl Load for st::TensorView<'_> {
    fn load(&self, device: &Device) -> Result<Tensor> {
        convert(self, device)
    }
}

impl Tensor {
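    /// Builds a tensor from a raw little-endian byte buffer, reinterpreting the bytes
    /// according to `dtype` and `shape`.
    ///
    /// A minimal sketch; the crate path is an assumption, adjust it to this crate's name:
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    ///
    /// // 4 f32 values, i.e. 16 bytes, interpreted as a 2x2 tensor.
    /// let bytes = [0u8; 16];
    /// let t = Tensor::from_raw_buffer(&bytes, DType::F32, &[2, 2], &Device::Cpu)?;
    /// assert_eq!(t.dims(), &[2, 2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```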
    pub fn from_raw_buffer(
        data: &[u8],
        dtype: DType,
        shape: &[usize],
        device: &Device,
    ) -> Result<Self> {
        match dtype {
            DType::U8 => convert_slice::<u8>(data, shape, device),
            DType::U32 => convert_slice::<u32>(data, shape, device),
            DType::I16 => convert_slice::<i16>(data, shape, device),
            DType::I32 => convert_slice::<i32>(data, shape, device),
            DType::I64 => convert_slice::<i64>(data, shape, device),
            DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
            DType::F16 => convert_slice::<half::f16>(data, shape, device),
            DType::F32 => convert_slice::<f32>(data, shape, device),
            DType::F64 => convert_slice::<f64>(data, shape, device),
            DType::F8E4M3 => convert_slice::<float8::F8E4M3>(data, shape, device),
            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
                // Packed sub-byte formats are kept as raw bytes in the backend storage.
                let storage = match device {
                    Device::Cpu => {
                        let cpu_storage = match dtype {
                            DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
                            DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
                            DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()),
                            DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
                            _ => unreachable!(),
                        };
                        Storage::Cpu(cpu_storage)
                    }
                    #[cfg(feature = "cuda")]
                    Device::Cuda(device) => {
                        let mut slice = unsafe { device.alloc::<u8>(data.len())? };
                        device.memcpy_htod(data, &mut slice)?;

                        let slice = match dtype {
                            DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice),
                            DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice),
                            DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice),
                            DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice),
                            _ => unreachable!(),
                        };
                        let storage = crate::cuda_backend::CudaStorage {
                            slice,
                            device: device.clone(),
                        };
                        Storage::Cuda(storage)
                    }
                    #[cfg(not(feature = "cuda"))]
                    Device::Cuda(_) => {
                        return Err(Error::Msg("CUDA support not compiled".to_string()));
                    }
                    #[cfg(feature = "metal")]
                    Device::Metal(device) => {
                        let buffer = device.new_buffer_with_data(data)?;

                        let storage = crate::metal_backend::MetalStorage::new(
                            buffer,
                            device.clone(),
                            data.len(),
                            dtype,
                        );
                        Storage::Metal(storage)
                    }
                    #[cfg(not(feature = "metal"))]
                    Device::Metal(_) => {
                        return Err(Error::Msg("Metal support not compiled".to_string()));
                    }
                };

                let op = BackpropOp::none();
                Ok(from_storage(storage, shape, op, false))
            }
        }
    }
}

fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    match view.dtype() {
        st::Dtype::U8 => convert_::<u8>(view, device),
        st::Dtype::U16 => {
            let conv = |x| Ok(u32::from(x));
            convert_with_cast_::<u16, u32, _>(view, device, conv)
        }
        st::Dtype::U32 => convert_::<u32>(view, device),
        st::Dtype::I16 => convert_::<i16>(view, device),
        st::Dtype::I32 => convert_::<i32>(view, device),
        st::Dtype::I64 => convert_::<i64>(view, device),
        st::Dtype::BF16 => convert_::<half::bf16>(view, device),
        st::Dtype::F16 => convert_::<half::f16>(view, device),
        st::Dtype::F32 => convert_::<f32>(view, device),
        st::Dtype::F64 => convert_::<f64>(view, device),
        st::Dtype::F8_E4M3 => convert_::<float8::F8E4M3>(view, device),
        st::Dtype::F6_E2M3 | st::Dtype::F6_E3M2 | st::Dtype::F4 | st::Dtype::F8_E8M0 => {
            convert_dummy(view, device)
        }
        dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
    }
}

fn convert_dummy(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    let (dtype, _dtype_name) = match view.dtype() {
        st::Dtype::F6_E2M3 => (DType::F6E2M3, "F6_E2M3 (MX6)"),
        st::Dtype::F6_E3M2 => (DType::F6E3M2, "F6_E3M2 (MX6)"),
        st::Dtype::F4 => (DType::F4, "F4 (MX4)"),
        st::Dtype::F8_E8M0 => (DType::F8E8M0, "F8_E8M0"),
        _ => unreachable!("convert_dummy called with non-dummy dtype"),
    };

    let data = view.data();
    let shape = view.shape();

    // The packed sub-byte formats have no per-element host representation, so the raw
    // bytes are copied straight into the backend storage for the target device.
    let storage = match device {
        Device::Cpu => {
            let cpu_storage = match dtype {
                DType::F6E2M3 => crate::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
                DType::F6E3M2 => crate::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
                DType::F4 => crate::cpu_backend::CpuStorage::F4(data.to_vec()),
                DType::F8E8M0 => crate::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
                _ => unreachable!(),
            };
            Storage::Cpu(cpu_storage)
        }
        #[cfg(feature = "cuda")]
        Device::Cuda(device) => {
            let mut slice = unsafe { device.alloc::<u8>(data.len())? };
            device.memcpy_htod(data, &mut slice)?;

            let slice = match dtype {
                DType::F6E2M3 => crate::cuda_backend::CudaStorageSlice::F6E2M3(slice),
                DType::F6E3M2 => crate::cuda_backend::CudaStorageSlice::F6E3M2(slice),
                DType::F4 => crate::cuda_backend::CudaStorageSlice::F4(slice),
                DType::F8E8M0 => crate::cuda_backend::CudaStorageSlice::F8E8M0(slice),
                _ => unreachable!(),
            };
            let storage = crate::cuda_backend::CudaStorage {
                slice,
                device: device.clone(),
            };
            Storage::Cuda(storage)
        }
        #[cfg(not(feature = "cuda"))]
        Device::Cuda(_) => {
            return Err(Error::Msg("CUDA support not compiled".to_string()));
        }
        #[cfg(feature = "metal")]
        Device::Metal(device) => {
            let buffer = device.new_buffer_with_data(data)?;

            let storage =
                crate::metal_backend::MetalStorage::new(buffer, device.clone(), data.len(), dtype);
            Storage::Metal(storage)
        }
        #[cfg(not(feature = "metal"))]
        Device::Metal(_) => {
            return Err(Error::Msg("Metal support not compiled".to_string()));
        }
    };

    let op = BackpropOp::none();
    Ok(from_storage(storage, shape, op, false))
}

fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
    // Make the tensor contiguous and rank 1 before extracting the data.
    let tensor = tensor.flatten_all()?;
    match tensor.dtype() {
        DType::U8 => Ok(convert_back_::<u8>(tensor.to_vec1()?)),
        DType::U32 => Ok(convert_back_::<u32>(tensor.to_vec1()?)),
        DType::I16 => Ok(convert_back_::<i16>(tensor.to_vec1()?)),
        DType::I32 => Ok(convert_back_::<i32>(tensor.to_vec1()?)),
        DType::I64 => Ok(convert_back_::<i64>(tensor.to_vec1()?)),
        DType::F16 => Ok(convert_back_::<half::f16>(tensor.to_vec1()?)),
        DType::BF16 => Ok(convert_back_::<half::bf16>(tensor.to_vec1()?)),
        DType::F32 => Ok(convert_back_::<f32>(tensor.to_vec1()?)),
        DType::F64 => Ok(convert_back_::<f64>(tensor.to_vec1()?)),
        DType::F8E4M3 => Ok(convert_back_::<float8::F8E4M3>(tensor.to_vec1()?)),
        DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
            Err(Error::Msg("Internal error: dtype mismatch in storage".to_string()).bt())
        }
    }
}

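/// Reads a safetensors file from disk and loads every tensor it contains onto `device`.
///
/// A minimal sketch; the crate and module paths are assumptions, adjust them to this crate's
/// actual names:
///
/// ```ignore
/// use candle_core::{safetensors, Device};
///
/// let tensors = safetensors::load("model.safetensors", &Device::Cpu)?;
/// for (name, tensor) in tensors.iter() {
///     println!("{name}: {:?}", tensor.dims());
/// }
/// # Ok::<(), candle_core::Error>(())
/// ```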
pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
    let data = std::fs::read(filename.as_ref())?;
    load_buffer(&data[..], device)
}

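/// Deserializes a safetensors buffer that is already in memory and loads every tensor it
/// contains onto `device`.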
pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
    let st = safetensors::SafeTensors::deserialize(data)?;
    st.tensors()
        .into_iter()
        .map(|(name, view)| Ok((name, view.load(device)?)))
        .collect()
}

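/// Saves a map of named tensors to a safetensors file at `filename`.
///
/// A minimal sketch; the crate and module paths are assumptions, adjust them to this crate's
/// actual names:
///
/// ```ignore
/// use candle_core::{safetensors, DType, Device, Tensor};
/// use std::collections::HashMap;
///
/// let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?;
/// let map: HashMap<_, _> = [("t", t)].into_iter().collect();
/// safetensors::save(&map, "t.safetensors")?;
/// # Ok::<(), candle_core::Error>(())
/// ```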
pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
    tensors: &HashMap<K, Tensor>,
    filename: P,
) -> Result<()> {
    Ok(st::serialize_to_file(tensors, None, filename.as_ref())?)
}

#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);

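/// A collection of one or more memory-mapped safetensors files, with tensor names routed to
/// the file that contains them.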
pub struct MmapedSafetensors {
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    routing: Option<HashMap<String, usize>>,
}

impl MmapedSafetensors {
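    /// Creates a wrapper around a memory-mapped file and deserializes the safetensors header.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`]: the underlying file must not be
    /// modified while it is mapped.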
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let file = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
            file,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)
                    .map_err(|e| Error::from(e).with_path(p))?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self {
            safetensors: vec![safetensors],
            routing: None,
        })
    }

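    /// Creates a wrapper around multiple memory-mapped files and deserializes all the
    /// safetensors headers. If a tensor name appears in several files, the entry from the last
    /// file wins.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`]: the underlying files must not be
    /// modified while they are mapped.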
    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
        let mut routing = HashMap::new();
        let mut safetensors = vec![];
        for (index, p) in paths.iter().enumerate() {
            let p = p.as_ref();
            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
            let file = memmap2::MmapOptions::new()
                .map(&file)
                .map_err(|e| Error::from(e).with_path(p))?;
            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
                file,
                |data: &[u8]| {
                    let st = safetensors::SafeTensors::deserialize(data)
                        .map_err(|e| Error::from(e).with_path(p))?;
                    Ok::<_, Error>(SafeTensors_(st))
                },
            )?;
            for k in data.get().0.names() {
                routing.insert(k.to_string(), index);
            }
            safetensors.push(data)
        }
        Ok(Self {
            safetensors,
            routing: Some(routing),
        })
    }

    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        self.get(name)?.load(dev)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        let mut tensors = vec![];
        for safetensors in self.safetensors.iter() {
            tensors.push(safetensors.get().0.tensors())
        }
        tensors.into_iter().flatten().collect()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        let index = match &self.routing {
            None => 0,
            Some(routing) => {
                let index = routing.get(name).ok_or_else(|| {
                    Error::CannotFindTensor {
                        path: name.to_string(),
                    }
                    .bt()
                })?;
                *index
            }
        };
        Ok(self.safetensors[index].get().0.tensor(name)?)
    }
}

pub struct SliceSafetensors<'a> {
    safetensors: SafeTensors<'a>,
}

impl<'a> SliceSafetensors<'a> {
    pub fn new(buffer: &'a [u8]) -> Result<Self> {
        let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
        Ok(Self { safetensors })
    }

    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        self.safetensors.tensor(name)?.load(dev)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        self.safetensors.tensors()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        Ok(self.safetensors.tensor(name)?)
    }
}

pub struct BufferedSafetensors {
    safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}

impl BufferedSafetensors {
    pub fn new(buffer: Vec<u8>) -> Result<Self> {
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
            buffer,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self { safetensors })
    }

    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        self.get(name)?.load(dev)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        self.safetensors.get().0.tensors()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        Ok(self.safetensors.get().0.tensor(name)?)
    }
}

pub struct MmapedFile {
    path: std::path::PathBuf,
    inner: memmap2::Mmap,
}

impl MmapedFile {
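    /// Creates a wrapper around a memory-mapped file from its path.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`]: the underlying file must not be
    /// modified while it is mapped.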
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let inner = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        Ok(Self {
            inner,
            path: p.to_path_buf(),
        })
    }

    pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
        let st = safetensors::SafeTensors::deserialize(&self.inner)
            .map_err(|e| Error::from(e).with_path(&self.path))?;
        Ok(st)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn save_single_tensor() {
        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
        t.save_safetensors("t", "t.safetensors").unwrap();
        let bytes = std::fs::read("t.safetensors").unwrap();
        assert_eq!(bytes, b"@\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
        std::fs::remove_file("t.safetensors").unwrap();
    }

    #[test]
    fn save_load_multiple_tensors() {
        let t = Tensor::zeros((2, 2), DType::F32, &Device::Cpu).unwrap();
        let u = Tensor::zeros((1, 2), DType::F32, &Device::Cpu).unwrap();
        let map: HashMap<_, _> = [("t", t), ("u", u)].into_iter().collect();
        save(&map, "multi.safetensors").unwrap();

        let weights = load("multi.safetensors", &Device::Cpu).unwrap();
        assert_eq!(weights.get("t").unwrap().dims(), &[2, 2]);
        assert_eq!(weights.get("u").unwrap().dims(), &[1, 2]);
        let bytes = std::fs::read("multi.safetensors").unwrap();
        assert_eq!(bytes, b"x\0\0\0\0\0\0\0{\"t\":{\"dtype\":\"F32\",\"shape\":[2,2],\"data_offsets\":[0,16]},\"u\":{\"dtype\":\"F32\",\"shape\":[1,2],\"data_offsets\":[16,24]}} \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
        std::fs::remove_file("multi.safetensors").unwrap();
    }

    #[test]
    fn load_u8() {
        let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"U8\",\"shape\":[2],\"data_offsets\":[0,2]}} \x01\x03";
        std::fs::write("test_u8.safetensors", bytes).unwrap();
        let weights = load("test_u8.safetensors", &Device::Cpu).unwrap();
        let tensor = weights.get("x").unwrap();
        assert_eq!(tensor.dims(), &[2]);
        assert_eq!(tensor.dtype(), DType::U8);
        let data: Vec<u8> = tensor.to_vec1().unwrap();
        assert_eq!(data, vec![1, 3]);
        std::fs::remove_file("test_u8.safetensors").unwrap();
    }
}