1use crate::VarMap;
7use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
8use safetensors::{slice::IndexOp, tensor::SafeTensors};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12pub struct VarBuilderArgs<'a, B: Backend> {
17 data: Arc<TensorData<B>>,
18 path: Vec<String>,
19 pub dtype: DType,
20 _phantom: std::marker::PhantomData<&'a B>,
21}
22
23impl<B: Backend> Clone for VarBuilderArgs<'_, B> {
24 fn clone(&self) -> Self {
25 Self {
26 data: self.data.clone(),
27 path: self.path.clone(),
28 dtype: self.dtype,
29 _phantom: self._phantom,
30 }
31 }
32}
33
34pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
37
/// State shared (through an `Arc`) by a builder and the builders cloned or derived from it.
///
/// `set_device`, `set_dtype` and `rename` create a fresh `TensorData` instead of mutating
/// this one.
struct TensorData<B: Backend> {
    /// Storage backend that tensors are fetched from.
    backend: Arc<B>,
    /// Device passed to the backend for every retrieval, i.e. where tensors end up.
    pub device: Device,
    /// Dtype recorded at construction (or by `set_dtype`); each builder additionally
    /// carries its own `dtype` field, which `to_dtype` can override independently.
    pub dtype: DType,
}
43
/// A storage backend that a [`VarBuilderArgs`] fetches tensors from.
pub trait Backend: Send + Sync {
    /// Backend-specific retrieval hints. `Default` is required because
    /// `VarBuilderArgs::get` passes `Default::default()` when no hints are given.
    type Hints: Default;

    /// Retrieve the tensor called `name` with expected shape `s`, converted to `dtype` and
    /// placed on `dev`.
    ///
    /// `name` is the fully-qualified name (prefix already applied by the builder). Whether
    /// the tensor is loaded from storage or created, and whether `s` is checked, depends on
    /// the backend (e.g. `Zeros` builds a zero tensor of shape `s`).
    fn get(
        &self,
        s: Shape,
        name: &str,
        h: Self::Hints,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor>;

    /// Retrieve the tensor called `name` without any shape check, converted to `dtype` and
    /// placed on `dev`. Some backends (`Zeros`, `VarMap`, `ShardedSafeTensors`) do not
    /// support this and always return an error.
    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>;

    /// Whether a tensor called `name` (fully-qualified) is available from this backend.
    fn contains_tensor(&self, name: &str) -> bool;
}
68
/// An object-safe backend whose hints are always `crate::Init`.
///
/// Boxed as `Box<dyn SimpleBackend>`, it becomes the backend of a [`VarBuilder`].
pub trait SimpleBackend: Send + Sync {
    /// Retrieve the tensor called `name` with expected shape `s`, converted to `dtype` and
    /// placed on `dev`.
    ///
    /// `h` is an initialization hint. Backends that only read stored tensors ignore it;
    /// the `VarMap` backend forwards it to `VarMap::get`.
    fn get(
        &self,
        s: Shape,
        name: &str,
        h: crate::Init,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor>;

    /// Retrieve the tensor called `name` without a shape check, converted to `dtype` and
    /// placed on `dev`. Not every backend supports this.
    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>;

    /// Whether a tensor called `name` is available from this backend.
    fn contains_tensor(&self, name: &str) -> bool;
}
85
86impl Backend for Box<dyn SimpleBackend + '_> {
87 type Hints = crate::Init;
88 fn get(
89 &self,
90 s: Shape,
91 name: &str,
92 h: Self::Hints,
93 dtype: DType,
94 dev: &Device,
95 ) -> Result<Tensor> {
96 self.as_ref().get(s, name, h, dtype, dev)
97 }
98
99 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
100 self.as_ref().get_unchecked(name, dtype, dev)
101 }
102
103 fn contains_tensor(&self, name: &str) -> bool {
104 self.as_ref().contains_tensor(name)
105 }
106}
107
108impl<B: Backend> VarBuilderArgs<'_, B> {
109 pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
110 let data = TensorData {
111 backend: Arc::new(backend),
112 device: dev.clone(),
113 dtype,
114 };
115 Self {
116 data: Arc::new(data),
117 path: vec![],
118 dtype,
119 _phantom: std::marker::PhantomData,
120 }
121 }
122
123 pub fn prefix(&self) -> String {
125 self.path.join(".")
126 }
127
128 pub fn root(&self) -> Self {
130 Self {
131 data: self.data.clone(),
132 path: vec![],
133 dtype: self.dtype,
134 _phantom: std::marker::PhantomData,
135 }
136 }
137
138 pub fn set_prefix(&self, prefix: impl ToString) -> Self {
140 Self {
141 data: self.data.clone(),
142 path: vec![prefix.to_string()],
143 dtype: self.dtype,
144 _phantom: std::marker::PhantomData,
145 }
146 }
147
148 pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
151 let mut path = self.path.clone();
152 path.push(s.to_string());
153 Self {
154 data: self.data.clone(),
155 path,
156 dtype: self.dtype,
157 _phantom: std::marker::PhantomData,
158 }
159 }
160
161 pub fn pp<S: ToString>(&self, s: S) -> Self {
163 self.push_prefix(s)
164 }
165
166 pub fn device(&self) -> &Device {
168 &self.data.device
169 }
170
171 pub fn dtype(&self) -> DType {
173 self.dtype
174 }
175
176 pub fn to_dtype(&self, dtype: DType) -> Self {
178 Self {
179 data: self.data.clone(),
180 path: self.path.clone(),
181 dtype,
182 _phantom: std::marker::PhantomData,
183 }
184 }
185
186 fn path(&self, tensor_name: &str) -> String {
187 if self.path.is_empty() {
188 tensor_name.to_string()
189 } else {
190 [&self.path.join("."), tensor_name].join(".")
191 }
192 }
193
194 pub fn contains_tensor(&self, tensor_name: &str) -> bool {
198 let path = self.path(tensor_name);
199 self.data.backend.contains_tensor(&path)
200 }
201
202 pub fn get_with_hints<S: Into<Shape>>(
204 &self,
205 s: S,
206 name: &str,
207 hints: B::Hints,
208 ) -> Result<Tensor> {
209 self.get_with_hints_dtype(s, name, hints, self.dtype)
210 }
211
212 pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
214 self.get_with_hints(s, name, Default::default())
215 }
216
217 pub fn get_unchecked(&self, name: &str) -> Result<Tensor> {
219 self.get_unchecked_dtype(name, self.data.dtype)
220 }
221
222 pub fn get_unchecked_dtype(&self, name: &str, dtype: DType) -> Result<Tensor> {
224 let name = self.path(name);
225 self.data
226 .backend
227 .get_unchecked(&name, dtype, &self.data.device)
228 }
229
230 pub fn get_with_hints_dtype<S: Into<Shape>>(
232 &self,
233 s: S,
234 name: &str,
235 hints: B::Hints,
236 dtype: DType,
237 ) -> Result<Tensor> {
238 let path = self.path(name);
239 self.data
240 .backend
241 .get(s.into(), &path, hints, dtype, &self.data.device)
242 }
243
244 pub fn set_device(self, device: Device) -> Self {
246 Self {
247 data: Arc::new(TensorData {
248 backend: self.data.backend.clone(),
249 dtype: self.data.dtype,
250 device,
251 }),
252 ..self
253 }
254 }
255
256 pub fn set_dtype(self, dtype: DType) -> Self {
258 Self {
259 data: Arc::new(TensorData {
260 backend: self.data.backend.clone(),
261 dtype,
262 device: self.data.device.clone(),
263 }),
264 dtype,
265 ..self
266 }
267 }
268}
269
270struct Zeros;
271
272impl SimpleBackend for Zeros {
273 fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
274 Tensor::zeros(s, dtype, dev)
275 }
276
277 fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result<Tensor> {
278 candle::bail!(
279 "`Zeros` requires a shape for tensor retrieval, use `get` instead of `get_unchecked`"
280 )
281 }
282
283 fn contains_tensor(&self, _name: &str) -> bool {
284 true
285 }
286}
287
288impl SimpleBackend for HashMap<String, Tensor> {
289 fn get(
290 &self,
291 s: Shape,
292 name: &str,
293 _: crate::Init,
294 dtype: DType,
295 dev: &Device,
296 ) -> Result<Tensor> {
297 let tensor = self
298 .get(name)
299 .ok_or_else(|| {
300 Error::CannotFindTensor {
301 path: name.to_string(),
302 }
303 .bt()
304 })?
305 .clone();
306 if tensor.shape() != &s {
307 Err(candle::Error::UnexpectedShape {
308 msg: format!("shape mismatch for {name}"),
309 expected: s,
310 got: tensor.shape().clone(),
311 }
312 .bt())?
313 }
314 tensor.to_device(dev)?.to_dtype(dtype)
315 }
316
317 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
318 let tensor = self
319 .get(name)
320 .ok_or_else(|| {
321 Error::CannotFindTensor {
322 path: name.to_string(),
323 }
324 .bt()
325 })?
326 .clone();
327 tensor.to_device(dev)?.to_dtype(dtype)
328 }
329
330 fn contains_tensor(&self, name: &str) -> bool {
331 self.contains_key(name)
332 }
333}
334
335impl SimpleBackend for VarMap {
336 fn get(
337 &self,
338 s: Shape,
339 name: &str,
340 h: crate::Init,
341 dtype: DType,
342 dev: &Device,
343 ) -> Result<Tensor> {
344 VarMap::get(self, s, name, h, dtype, dev)
345 }
346
347 fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result<Tensor> {
348 candle::bail!("`get_unchecked` does not make sense for `VarMap`, use `get`.");
349 }
350
351 fn contains_tensor(&self, name: &str) -> bool {
352 self.data().lock().unwrap().contains_key(name)
353 }
354}
355
356#[allow(dead_code)]
357pub struct SafeTensorWithRouting<'a> {
358 routing: HashMap<String, usize>,
359 safetensors: Vec<SafeTensors<'a>>,
360}
361
362impl SimpleBackend for SafeTensorWithRouting<'_> {
363 fn get(
364 &self,
365 s: Shape,
366 path: &str,
367 _: crate::Init,
368 dtype: DType,
369 dev: &Device,
370 ) -> Result<Tensor> {
371 let index = self.routing.get(path).ok_or_else(|| {
372 Error::CannotFindTensor {
373 path: path.to_string(),
374 }
375 .bt()
376 })?;
377 let tensor = self.safetensors[*index]
378 .tensor(path)?
379 .load(dev)?
380 .to_dtype(dtype)?;
381 if tensor.shape() != &s {
382 Err(candle::Error::UnexpectedShape {
383 msg: format!("shape mismatch for {path}"),
384 expected: s,
385 got: tensor.shape().clone(),
386 }
387 .bt())?
388 }
389 Ok(tensor)
390 }
391
392 fn get_unchecked(&self, path: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
393 let index = self.routing.get(path).ok_or_else(|| {
394 Error::CannotFindTensor {
395 path: path.to_string(),
396 }
397 .bt()
398 })?;
399 let tensor = self.safetensors[*index]
400 .tensor(path)?
401 .load(dev)?
402 .to_dtype(dtype)?;
403 Ok(tensor)
404 }
405
406 fn contains_tensor(&self, name: &str) -> bool {
407 self.routing.contains_key(name)
408 }
409}
410
411impl SimpleBackend for candle::npy::NpzTensors {
412 fn get(
413 &self,
414 s: Shape,
415 path: &str,
416 _: crate::Init,
417 dtype: DType,
418 dev: &Device,
419 ) -> Result<Tensor> {
420 let tensor = match self.get(path)? {
421 None => Err(Error::CannotFindTensor {
422 path: path.to_string(),
423 }
424 .bt())?,
425 Some(tensor) => tensor,
426 };
427 let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
428 if tensor.shape() != &s {
429 Err(candle::Error::UnexpectedShape {
430 msg: format!("shape mismatch for {path}"),
431 expected: s,
432 got: tensor.shape().clone(),
433 }
434 .bt())?
435 }
436 Ok(tensor)
437 }
438
439 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
440 let tensor = match self.get(name)? {
441 None => Err(Error::CannotFindTensor {
442 path: name.to_string(),
443 }
444 .bt())?,
445 Some(tensor) => tensor,
446 };
447 let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
448 Ok(tensor)
449 }
450
451 fn contains_tensor(&self, name: &str) -> bool {
452 self.get(name).is_ok_and(|v| v.is_some())
453 }
454}
455
456impl SimpleBackend for candle::pickle::PthTensors {
457 fn get(
458 &self,
459 s: Shape,
460 path: &str,
461 _: crate::Init,
462 dtype: DType,
463 dev: &Device,
464 ) -> Result<Tensor> {
465 let tensor = match self.get(path)? {
466 None => Err(Error::CannotFindTensor {
467 path: path.to_string(),
468 }
469 .bt())?,
470 Some(tensor) => tensor,
471 };
472 let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
473 if tensor.shape() != &s {
474 Err(candle::Error::UnexpectedShape {
475 msg: format!("shape mismatch for {path}"),
476 expected: s,
477 got: tensor.shape().clone(),
478 }
479 .bt())?
480 }
481 Ok(tensor)
482 }
483
484 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
485 let tensor = match self.get(name)? {
486 None => Err(Error::CannotFindTensor {
487 path: name.to_string(),
488 }
489 .bt())?,
490 Some(tensor) => tensor,
491 };
492 let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
493 Ok(tensor)
494 }
495
496 fn contains_tensor(&self, name: &str) -> bool {
497 self.get(name).is_ok_and(|v| v.is_some())
498 }
499}
500
501impl SimpleBackend for candle::safetensors::MmapedSafetensors {
502 fn get(
503 &self,
504 s: Shape,
505 name: &str,
506 _: crate::Init,
507 dtype: DType,
508 dev: &Device,
509 ) -> Result<Tensor> {
510 let tensor = self.load(name, dev)?.to_dtype(dtype)?;
511 if tensor.shape() != &s {
512 Err(candle::Error::UnexpectedShape {
513 msg: format!("shape mismatch for {name}"),
514 expected: s,
515 got: tensor.shape().clone(),
516 }
517 .bt())?
518 }
519 Ok(tensor)
520 }
521
522 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
523 self.load(name, dev)?.to_dtype(dtype)
524 }
525
526 fn contains_tensor(&self, name: &str) -> bool {
527 self.get(name).is_ok()
528 }
529}
530
531impl SimpleBackend for candle::safetensors::BufferedSafetensors {
532 fn get(
533 &self,
534 s: Shape,
535 name: &str,
536 _: crate::Init,
537 dtype: DType,
538 dev: &Device,
539 ) -> Result<Tensor> {
540 let tensor = self.load(name, dev)?.to_dtype(dtype)?;
541 if tensor.shape() != &s {
542 Err(candle::Error::UnexpectedShape {
543 msg: format!("shape mismatch for {name}"),
544 expected: s,
545 got: tensor.shape().clone(),
546 }
547 .bt())?
548 }
549 Ok(tensor)
550 }
551
552 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
553 self.load(name, dev)?.to_dtype(dtype)
554 }
555
556 fn contains_tensor(&self, name: &str) -> bool {
557 self.get(name).is_ok()
558 }
559}
560
561impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
562 fn get(
563 &self,
564 s: Shape,
565 name: &str,
566 _: crate::Init,
567 dtype: DType,
568 dev: &Device,
569 ) -> Result<Tensor> {
570 let tensor = self.load(name, dev)?.to_dtype(dtype)?;
571 if tensor.shape() != &s {
572 Err(candle::Error::UnexpectedShape {
573 msg: format!("shape mismatch for {name}"),
574 expected: s,
575 got: tensor.shape().clone(),
576 }
577 .bt())?
578 }
579 Ok(tensor)
580 }
581
582 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
583 self.load(name, dev)?.to_dtype(dtype)
584 }
585
586 fn contains_tensor(&self, name: &str) -> bool {
587 self.get(name).is_ok()
588 }
589}
590
591impl<'a> VarBuilder<'a> {
592 pub fn from_backend(
598 backend: Box<dyn SimpleBackend + 'a>,
599 dtype: DType,
600 device: Device,
601 ) -> Self {
602 let data = TensorData {
603 backend: Arc::new(backend),
604 device,
605 dtype,
606 };
607 Self {
608 data: Arc::new(data),
609 path: vec![],
610 dtype,
611 _phantom: std::marker::PhantomData,
612 }
613 }
614
615 pub fn zeros(dtype: DType, dev: &Device) -> Self {
617 Self::from_backend(Box::new(Zeros), dtype, dev.clone())
618 }
619
620 pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
623 Self::from_backend(Box::new(ts), dtype, dev.clone())
624 }
625
626 pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
633 Self::from_backend(Box::new(varmap.clone()), dtype, dev.clone())
634 }
635
636 pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
643 paths: &[P],
644 dtype: DType,
645 dev: &Device,
646 ) -> Result<Self> {
647 let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
648 Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
649 }
650
651 pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
653 let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
654 Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
655 }
656
657 pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self> {
659 let tensors = candle::safetensors::SliceSafetensors::new(data)?;
660 Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
661 }
662
663 pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
665 let npz = candle::npy::NpzTensors::new(p)?;
666 Ok(Self::from_backend(Box::new(npz), dtype, dev.clone()))
667 }
668
669 pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
671 let pth = candle::pickle::PthTensors::new(p, None)?;
672 Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
673 }
674 pub fn from_pth_with_state<P: AsRef<std::path::Path>>(
677 p: P,
678 dtype: DType,
679 state_key: &str,
680 dev: &Device,
681 ) -> Result<Self> {
682 let pth = candle::pickle::PthTensors::new(p, Some(state_key))?;
683 Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
684 }
685 pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(self, f: F) -> Self {
710 let f: Box<dyn Fn(&str) -> String + Sync + Send + 'static> = Box::new(f);
711 self.rename(f)
712 }
713
714 pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self {
715 let dtype = self.dtype();
716 let device = self.device().clone();
717 let path = self.path.clone();
718 let backend = Rename::new(self, renamer);
719 let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
720 let data = TensorData {
721 backend: Arc::new(backend),
722 device,
723 dtype,
724 };
725 Self {
726 data: Arc::new(data),
727 dtype,
728 path,
729 _phantom: std::marker::PhantomData,
730 }
731 }
732}
733
734pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
735
736pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>;
737
738impl ShardedSafeTensors {
739 pub unsafe fn var_builder<P: AsRef<std::path::Path>>(
746 paths: &[P],
747 dtype: DType,
748 dev: &Device,
749 ) -> Result<ShardedVarBuilder<'static>> {
750 let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
751 let backend = ShardedSafeTensors(tensors);
752 Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
753 }
754}
755
/// Which slice of a tensor to load when a model is split across `world_size` processes.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct Shard {
    /// Dimension along which the tensor is split.
    pub dim: usize,
    /// Index of the block to load; expected to lie in `0..world_size`.
    pub rank: usize,
    /// Number of equal blocks the dimension is split into.
    pub world_size: usize,
}

/// The default shard covers the whole tensor: a single block along dimension 0.
impl Default for Shard {
    fn default() -> Self {
        Self {
            world_size: 1,
            dim: 0,
            rank: 0,
        }
    }
}
772
773impl Backend for ShardedSafeTensors {
785 type Hints = Shard;
786
787 fn get(
788 &self,
789 target_shape: Shape, path: &str,
791 h: Self::Hints,
792 dtype: DType,
793 dev: &Device,
794 ) -> Result<Tensor> {
795 if h.world_size == 1 {
796 return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
799 }
800
801 let Shard {
802 dim,
803 rank,
804 world_size,
805 } = h;
806 let view = self.0.get(path)?;
807 let view_dtype = view.dtype();
808 let mut shape = view.shape().to_vec();
809 let size = shape[dim];
810
811 if size % world_size != 0 {
812 return Err(Error::ShapeMismatchSplit {
813 shape: shape.into(),
814 dim,
815 n_parts: world_size,
816 });
817 }
818 let block_size = size / world_size;
819 let start = rank * block_size;
820 let stop = (rank + 1) * block_size;
821
822 let iterator = if dim == 0 {
826 view.slice(start..stop).map_err(|_| {
827 Error::Msg(format!(
828 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
829 ))
830 })?
831 } else if dim == 1 {
832 view.slice((.., start..stop)).map_err(|_| {
833 Error::Msg(format!(
834 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
835 ))
836 })?
837 } else {
838 candle::bail!("Get sharded on dimensions != 0 or 1")
839 };
840
841 shape[dim] = block_size;
842
843 let view_dtype: DType = view_dtype.try_into()?;
844 let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
845 Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
846 }
847
848 fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result<Tensor> {
849 candle::bail!("`get_unchecked` does not make sense for `ShardedSafeTensors`, use `get`.");
850 }
851
852 fn contains_tensor(&self, name: &str) -> bool {
853 self.0.get(name).is_ok()
854 }
855}
856
/// Maps tensor names before they are looked up; used by [`Rename`].
pub trait Renamer {
    /// Map the requested name `v` to the name that should be looked up in the wrapped
    /// builder. Return `Cow::Borrowed` to avoid allocating when the name is unchanged.
    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str>;
}
864
865pub struct Rename<'a, R: Renamer> {
866 inner: VarBuilder<'a>,
867 renamer: R,
868}
869
870impl<R: Renamer + Sync + Send> SimpleBackend for Rename<'_, R> {
871 fn get(
872 &self,
873 s: Shape,
874 name: &str,
875 h: crate::Init,
876 dtype: DType,
877 dev: &Device,
878 ) -> Result<Tensor> {
879 let name = self.renamer.rename(name);
880 self.inner
881 .get_with_hints_dtype(s, &name, h, dtype)?
882 .to_device(dev)
883 }
884
885 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
886 let name = self.renamer.rename(name);
887 self.inner.get_unchecked_dtype(&name, dtype)?.to_device(dev)
888 }
889
890 fn contains_tensor(&self, name: &str) -> bool {
891 let name = self.renamer.rename(name);
892 self.inner.contains_tensor(&name)
893 }
894}
895
896impl<'a, R: Renamer> Rename<'a, R> {
897 pub fn new(inner: VarBuilder<'a>, renamer: R) -> Self {
898 Self { inner, renamer }
899 }
900}
901
902impl Renamer for Box<dyn Fn(&str) -> String + Sync + Send> {
903 fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> {
904 std::borrow::Cow::Owned(self(v))
905 }
906}