1use crate::VarMap;
7use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
8use safetensors::{slice::IndexOp, tensor::SafeTensors};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12pub struct VarBuilderArgs<'a, B: Backend> {
17 data: Arc<TensorData<B>>,
18 path: Vec<String>,
19 pub dtype: DType,
20 _phantom: std::marker::PhantomData<&'a B>,
21}
22
23impl<B: Backend> Clone for VarBuilderArgs<'_, B> {
24 fn clone(&self) -> Self {
25 Self {
26 data: self.data.clone(),
27 path: self.path.clone(),
28 dtype: self.dtype,
29 _phantom: self._phantom,
30 }
31 }
32}
33
34pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
37
38struct TensorData<B: Backend> {
39 backend: B,
40 pub device: Device,
41}
42
43pub trait Backend: Send + Sync {
50 type Hints: Default;
51
52 fn get(
54 &self,
55 s: Shape,
56 name: &str,
57 h: Self::Hints,
58 dtype: DType,
59 dev: &Device,
60 ) -> Result<Tensor>;
61
62 fn contains_tensor(&self, name: &str) -> bool;
63}
64
65pub trait SimpleBackend: Send + Sync {
66 fn get(
68 &self,
69 s: Shape,
70 name: &str,
71 h: crate::Init,
72 dtype: DType,
73 dev: &Device,
74 ) -> Result<Tensor>;
75
76 fn contains_tensor(&self, name: &str) -> bool;
77}
78
79impl Backend for Box<dyn SimpleBackend + '_> {
80 type Hints = crate::Init;
81 fn get(
82 &self,
83 s: Shape,
84 name: &str,
85 h: Self::Hints,
86 dtype: DType,
87 dev: &Device,
88 ) -> Result<Tensor> {
89 self.as_ref().get(s, name, h, dtype, dev)
90 }
91
92 fn contains_tensor(&self, name: &str) -> bool {
93 self.as_ref().contains_tensor(name)
94 }
95}
96
97impl<B: Backend> VarBuilderArgs<'_, B> {
98 pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
99 let data = TensorData {
100 backend,
101 device: dev.clone(),
102 };
103 Self {
104 data: Arc::new(data),
105 path: vec![],
106 dtype,
107 _phantom: std::marker::PhantomData,
108 }
109 }
110
111 pub fn prefix(&self) -> String {
113 self.path.join(".")
114 }
115
116 pub fn root(&self) -> Self {
118 Self {
119 data: self.data.clone(),
120 path: vec![],
121 dtype: self.dtype,
122 _phantom: std::marker::PhantomData,
123 }
124 }
125
126 pub fn set_prefix(&self, prefix: impl ToString) -> Self {
128 Self {
129 data: self.data.clone(),
130 path: vec![prefix.to_string()],
131 dtype: self.dtype,
132 _phantom: std::marker::PhantomData,
133 }
134 }
135
136 pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
139 let mut path = self.path.clone();
140 path.push(s.to_string());
141 Self {
142 data: self.data.clone(),
143 path,
144 dtype: self.dtype,
145 _phantom: std::marker::PhantomData,
146 }
147 }
148
149 pub fn pp<S: ToString>(&self, s: S) -> Self {
151 self.push_prefix(s)
152 }
153
154 pub fn device(&self) -> &Device {
156 &self.data.device
157 }
158
159 pub fn dtype(&self) -> DType {
161 self.dtype
162 }
163
164 pub fn to_dtype(&self, dtype: DType) -> Self {
166 Self {
167 data: self.data.clone(),
168 path: self.path.clone(),
169 dtype,
170 _phantom: std::marker::PhantomData,
171 }
172 }
173
174 fn path(&self, tensor_name: &str) -> String {
175 if self.path.is_empty() {
176 tensor_name.to_string()
177 } else {
178 [&self.path.join("."), tensor_name].join(".")
179 }
180 }
181
182 pub fn contains_tensor(&self, tensor_name: &str) -> bool {
186 let path = self.path(tensor_name);
187 self.data.backend.contains_tensor(&path)
188 }
189
190 pub fn get_with_hints<S: Into<Shape>>(
192 &self,
193 s: S,
194 name: &str,
195 hints: B::Hints,
196 ) -> Result<Tensor> {
197 self.get_with_hints_dtype(s, name, hints, self.dtype)
198 }
199
200 pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
202 self.get_with_hints(s, name, Default::default())
203 }
204
205 pub fn get_with_hints_dtype<S: Into<Shape>>(
207 &self,
208 s: S,
209 name: &str,
210 hints: B::Hints,
211 dtype: DType,
212 ) -> Result<Tensor> {
213 let path = self.path(name);
214 self.data
215 .backend
216 .get(s.into(), &path, hints, dtype, &self.data.device)
217 }
218}
219
220struct Zeros;
221
222impl SimpleBackend for Zeros {
223 fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
224 Tensor::zeros(s, dtype, dev)
225 }
226
227 fn contains_tensor(&self, _name: &str) -> bool {
228 true
229 }
230}
231
232impl SimpleBackend for HashMap<String, Tensor> {
233 fn get(
234 &self,
235 s: Shape,
236 name: &str,
237 _: crate::Init,
238 dtype: DType,
239 dev: &Device,
240 ) -> Result<Tensor> {
241 let tensor = self
242 .get(name)
243 .ok_or_else(|| {
244 Error::CannotFindTensor {
245 path: name.to_string(),
246 }
247 .bt()
248 })?
249 .clone();
250 if tensor.shape() != &s {
251 Err(candle::Error::UnexpectedShape {
252 msg: format!("shape mismatch for {name}"),
253 expected: s,
254 got: tensor.shape().clone(),
255 }
256 .bt())?
257 }
258 tensor.to_device(dev)?.to_dtype(dtype)
259 }
260
261 fn contains_tensor(&self, name: &str) -> bool {
262 self.contains_key(name)
263 }
264}
265
266impl SimpleBackend for VarMap {
267 fn get(
268 &self,
269 s: Shape,
270 name: &str,
271 h: crate::Init,
272 dtype: DType,
273 dev: &Device,
274 ) -> Result<Tensor> {
275 VarMap::get(self, s, name, h, dtype, dev)
276 }
277
278 fn contains_tensor(&self, name: &str) -> bool {
279 self.data().lock().unwrap().contains_key(name)
280 }
281}
282
283#[allow(dead_code)]
284pub struct SafeTensorWithRouting<'a> {
285 routing: HashMap<String, usize>,
286 safetensors: Vec<SafeTensors<'a>>,
287}
288
289impl SimpleBackend for SafeTensorWithRouting<'_> {
290 fn get(
291 &self,
292 s: Shape,
293 path: &str,
294 _: crate::Init,
295 dtype: DType,
296 dev: &Device,
297 ) -> Result<Tensor> {
298 let index = self.routing.get(path).ok_or_else(|| {
299 Error::CannotFindTensor {
300 path: path.to_string(),
301 }
302 .bt()
303 })?;
304 let tensor = self.safetensors[*index]
305 .tensor(path)?
306 .load(dev)?
307 .to_dtype(dtype)?;
308 if tensor.shape() != &s {
309 Err(candle::Error::UnexpectedShape {
310 msg: format!("shape mismatch for {path}"),
311 expected: s,
312 got: tensor.shape().clone(),
313 }
314 .bt())?
315 }
316 Ok(tensor)
317 }
318
319 fn contains_tensor(&self, name: &str) -> bool {
320 self.routing.contains_key(name)
321 }
322}
323
324impl SimpleBackend for candle::npy::NpzTensors {
325 fn get(
326 &self,
327 s: Shape,
328 path: &str,
329 _: crate::Init,
330 dtype: DType,
331 dev: &Device,
332 ) -> Result<Tensor> {
333 let tensor = match self.get(path)? {
334 None => Err(Error::CannotFindTensor {
335 path: path.to_string(),
336 }
337 .bt())?,
338 Some(tensor) => tensor,
339 };
340 let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
341 if tensor.shape() != &s {
342 Err(candle::Error::UnexpectedShape {
343 msg: format!("shape mismatch for {path}"),
344 expected: s,
345 got: tensor.shape().clone(),
346 }
347 .bt())?
348 }
349 Ok(tensor)
350 }
351
352 fn contains_tensor(&self, name: &str) -> bool {
353 self.get(name).is_ok_and(|v| v.is_some())
354 }
355}
356
357impl SimpleBackend for candle::pickle::PthTensors {
358 fn get(
359 &self,
360 s: Shape,
361 path: &str,
362 _: crate::Init,
363 dtype: DType,
364 dev: &Device,
365 ) -> Result<Tensor> {
366 let tensor = match self.get(path)? {
367 None => Err(Error::CannotFindTensor {
368 path: path.to_string(),
369 }
370 .bt())?,
371 Some(tensor) => tensor,
372 };
373 let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
374 if tensor.shape() != &s {
375 Err(candle::Error::UnexpectedShape {
376 msg: format!("shape mismatch for {path}"),
377 expected: s,
378 got: tensor.shape().clone(),
379 }
380 .bt())?
381 }
382 Ok(tensor)
383 }
384
385 fn contains_tensor(&self, name: &str) -> bool {
386 self.get(name).is_ok_and(|v| v.is_some())
387 }
388}
389
390impl SimpleBackend for candle::safetensors::MmapedSafetensors {
391 fn get(
392 &self,
393 s: Shape,
394 name: &str,
395 _: crate::Init,
396 dtype: DType,
397 dev: &Device,
398 ) -> Result<Tensor> {
399 let tensor = self.load(name, dev)?.to_dtype(dtype)?;
400 if tensor.shape() != &s {
401 Err(candle::Error::UnexpectedShape {
402 msg: format!("shape mismatch for {name}"),
403 expected: s,
404 got: tensor.shape().clone(),
405 }
406 .bt())?
407 }
408 Ok(tensor)
409 }
410
411 fn contains_tensor(&self, name: &str) -> bool {
412 self.get(name).is_ok()
413 }
414}
415
416impl SimpleBackend for candle::safetensors::BufferedSafetensors {
417 fn get(
418 &self,
419 s: Shape,
420 name: &str,
421 _: crate::Init,
422 dtype: DType,
423 dev: &Device,
424 ) -> Result<Tensor> {
425 let tensor = self.load(name, dev)?.to_dtype(dtype)?;
426 if tensor.shape() != &s {
427 Err(candle::Error::UnexpectedShape {
428 msg: format!("shape mismatch for {name}"),
429 expected: s,
430 got: tensor.shape().clone(),
431 }
432 .bt())?
433 }
434 Ok(tensor)
435 }
436
437 fn contains_tensor(&self, name: &str) -> bool {
438 self.get(name).is_ok()
439 }
440}
441
442impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
443 fn get(
444 &self,
445 s: Shape,
446 name: &str,
447 _: crate::Init,
448 dtype: DType,
449 dev: &Device,
450 ) -> Result<Tensor> {
451 let tensor = self.load(name, dev)?.to_dtype(dtype)?;
452 if tensor.shape() != &s {
453 Err(candle::Error::UnexpectedShape {
454 msg: format!("shape mismatch for {name}"),
455 expected: s,
456 got: tensor.shape().clone(),
457 }
458 .bt())?
459 }
460 Ok(tensor)
461 }
462
463 fn contains_tensor(&self, name: &str) -> bool {
464 self.get(name).is_ok()
465 }
466}
467
468impl<'a> VarBuilder<'a> {
469 pub fn from_backend(
475 backend: Box<dyn SimpleBackend + 'a>,
476 dtype: DType,
477 device: Device,
478 ) -> Self {
479 let data = TensorData { backend, device };
480 Self {
481 data: Arc::new(data),
482 path: vec![],
483 dtype,
484 _phantom: std::marker::PhantomData,
485 }
486 }
487
488 pub fn zeros(dtype: DType, dev: &Device) -> Self {
490 Self::from_backend(Box::new(Zeros), dtype, dev.clone())
491 }
492
493 pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
496 Self::from_backend(Box::new(ts), dtype, dev.clone())
497 }
498
499 pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
506 Self::from_backend(Box::new(varmap.clone()), dtype, dev.clone())
507 }
508
509 pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
516 paths: &[P],
517 dtype: DType,
518 dev: &Device,
519 ) -> Result<Self> {
520 let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
521 Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
522 }
523
524 pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
526 let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
527 Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
528 }
529
530 pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self> {
532 let tensors = candle::safetensors::SliceSafetensors::new(data)?;
533 Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
534 }
535
536 pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
538 let npz = candle::npy::NpzTensors::new(p)?;
539 Ok(Self::from_backend(Box::new(npz), dtype, dev.clone()))
540 }
541
542 pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
544 let pth = candle::pickle::PthTensors::new(p, None)?;
545 Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
546 }
547 pub fn from_pth_with_state<P: AsRef<std::path::Path>>(
550 p: P,
551 dtype: DType,
552 state_key: &str,
553 dev: &Device,
554 ) -> Result<Self> {
555 let pth = candle::pickle::PthTensors::new(p, Some(state_key))?;
556 Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
557 }
558 pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(self, f: F) -> Self {
583 let f: Box<dyn Fn(&str) -> String + Sync + Send + 'static> = Box::new(f);
584 self.rename(f)
585 }
586
587 pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self {
588 let dtype = self.dtype();
589 let device = self.device().clone();
590 let path = self.path.clone();
591 let backend = Rename::new(self, renamer);
592 let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
593 let data = TensorData { backend, device };
594 Self {
595 data: Arc::new(data),
596 dtype,
597 path,
598 _phantom: std::marker::PhantomData,
599 }
600 }
601}
602
603pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
604
605pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>;
606
607impl ShardedSafeTensors {
608 pub unsafe fn var_builder<P: AsRef<std::path::Path>>(
615 paths: &[P],
616 dtype: DType,
617 dev: &Device,
618 ) -> Result<ShardedVarBuilder<'static>> {
619 let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
620 let backend = ShardedSafeTensors(tensors);
621 Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
622 }
623}
624
/// Describes which slice of a tensor should be retrieved.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct Shard {
    /// Dimension along which the tensor is split.
    pub dim: usize,
    /// Index of the requested slice.
    pub rank: usize,
    /// Number of slices the dimension is split into.
    pub world_size: usize,
}

impl Default for Shard {
    /// A single slice (`world_size == 1`) at rank 0 along dimension 0.
    fn default() -> Self {
        Shard {
            dim: 0,
            rank: 0,
            world_size: 1,
        }
    }
}
641
642impl Backend for ShardedSafeTensors {
654 type Hints = Shard;
655
656 fn get(
657 &self,
658 target_shape: Shape, path: &str,
660 h: Self::Hints,
661 dtype: DType,
662 dev: &Device,
663 ) -> Result<Tensor> {
664 if h.world_size == 1 {
665 return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
668 }
669
670 let Shard {
671 dim,
672 rank,
673 world_size,
674 } = h;
675 let view = self.0.get(path)?;
676 let view_dtype = view.dtype();
677 let mut shape = view.shape().to_vec();
678 let size = shape[dim];
679
680 if size % world_size != 0 {
681 return Err(Error::ShapeMismatchSplit {
682 shape: shape.into(),
683 dim,
684 n_parts: world_size,
685 });
686 }
687 let block_size = size / world_size;
688 let start = rank * block_size;
689 let stop = (rank + 1) * block_size;
690
691 let iterator = if dim == 0 {
695 view.slice(start..stop).map_err(|_| {
696 Error::Msg(format!(
697 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
698 ))
699 })?
700 } else if dim == 1 {
701 view.slice((.., start..stop)).map_err(|_| {
702 Error::Msg(format!(
703 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
704 ))
705 })?
706 } else {
707 candle::bail!("Get sharded on dimensions != 0 or 1")
708 };
709
710 shape[dim] = block_size;
711
712 let view_dtype: DType = view_dtype.try_into()?;
713 let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
714 Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
715 }
716
717 fn contains_tensor(&self, name: &str) -> bool {
718 self.0.get(name).is_ok()
719 }
720}
721
/// Maps a requested tensor name to the name that should actually be looked up.
pub trait Renamer {
    /// Returns the replacement for `v`.
    ///
    /// The result is a `Cow` tied to `self`, so an implementation may hand out
    /// a string it already owns or allocate a new one.
    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str>;
}
729
730pub struct Rename<'a, R: Renamer> {
731 inner: VarBuilder<'a>,
732 renamer: R,
733}
734
735impl<R: Renamer + Sync + Send> SimpleBackend for Rename<'_, R> {
736 fn get(
737 &self,
738 s: Shape,
739 name: &str,
740 h: crate::Init,
741 dtype: DType,
742 dev: &Device,
743 ) -> Result<Tensor> {
744 let name = self.renamer.rename(name);
745 self.inner
746 .get_with_hints_dtype(s, &name, h, dtype)?
747 .to_device(dev)
748 }
749
750 fn contains_tensor(&self, name: &str) -> bool {
751 let name = self.renamer.rename(name);
752 self.inner.contains_tensor(&name)
753 }
754}
755
756impl<'a, R: Renamer> Rename<'a, R> {
757 pub fn new(inner: VarBuilder<'a>, renamer: R) -> Self {
758 Self { inner, renamer }
759 }
760}
761
762impl Renamer for Box<dyn Fn(&str) -> String + Sync + Send> {
763 fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> {
764 std::borrow::Cow::Owned(self(v))
765 }
766}