Skip to main content

candle_nn/
var_builder.rs

1//! A `VarBuilder` for variable retrieval from models
2//!
3//! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come
4//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized
5//! for training, e.g. using `VarBuilder::from_varmap`.
6use crate::VarMap;
7use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
8use safetensors::{slice::IndexOp, tensor::SafeTensors};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// A structure used to retrieve variables, these variables can either come from storage or be
13/// generated via some form of initialization.
14///
15/// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`.
16pub struct VarBuilderArgs<'a, B: Backend> {
17    data: Arc<TensorData<B>>,
18    path: Vec<String>,
19    pub dtype: DType,
20    _phantom: std::marker::PhantomData<&'a B>,
21}
22
23impl<B: Backend> Clone for VarBuilderArgs<'_, B> {
24    fn clone(&self) -> Self {
25        Self {
26            data: self.data.clone(),
27            path: self.path.clone(),
28            dtype: self.dtype,
29            _phantom: self._phantom,
30        }
31    }
32}
33
34/// A simple `VarBuilder`, this is less generic than `VarBuilderArgs` but should cover most common
35/// use cases.
36pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
37
38struct TensorData<B: Backend> {
39    backend: Arc<B>,
40    pub device: Device,
41    pub dtype: DType,
42}
43
44/// A trait that defines how tensor data is retrieved.
45///
46/// Typically this would use disk storage in some specific format, or random initialization.
47/// Note that there is a specialized version of this trait (`SimpleBackend`) that can be used most
48/// of the time. The main restriction is that it doesn't allow for specific args (besides
49/// initialization hints).
50pub trait Backend: Send + Sync {
51    type Hints: Default;
52
53    /// Retrieve a tensor with some target shape.
54    fn get(
55        &self,
56        s: Shape,
57        name: &str,
58        h: Self::Hints,
59        dtype: DType,
60        dev: &Device,
61    ) -> Result<Tensor>;
62
63    /// Retrieve a tensor based on the name.
64    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>;
65
66    fn contains_tensor(&self, name: &str) -> bool;
67}
68
69pub trait SimpleBackend: Send + Sync {
70    /// Retrieve a tensor based on a target name and shape.
71    fn get(
72        &self,
73        s: Shape,
74        name: &str,
75        h: crate::Init,
76        dtype: DType,
77        dev: &Device,
78    ) -> Result<Tensor>;
79
80    /// Retrieve a tensor based on the name.
81    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>;
82
83    fn contains_tensor(&self, name: &str) -> bool;
84}
85
86impl Backend for Box<dyn SimpleBackend + '_> {
87    type Hints = crate::Init;
88    fn get(
89        &self,
90        s: Shape,
91        name: &str,
92        h: Self::Hints,
93        dtype: DType,
94        dev: &Device,
95    ) -> Result<Tensor> {
96        self.as_ref().get(s, name, h, dtype, dev)
97    }
98
99    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
100        self.as_ref().get_unchecked(name, dtype, dev)
101    }
102
103    fn contains_tensor(&self, name: &str) -> bool {
104        self.as_ref().contains_tensor(name)
105    }
106}
107
108impl<B: Backend> VarBuilderArgs<'_, B> {
109    pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
110        let data = TensorData {
111            backend: Arc::new(backend),
112            device: dev.clone(),
113            dtype,
114        };
115        Self {
116            data: Arc::new(data),
117            path: vec![],
118            dtype,
119            _phantom: std::marker::PhantomData,
120        }
121    }
122
123    /// Returns the prefix of the `VarBuilder`.
124    pub fn prefix(&self) -> String {
125        self.path.join(".")
126    }
127
128    /// Returns a new `VarBuilder` using the root path.
129    pub fn root(&self) -> Self {
130        Self {
131            data: self.data.clone(),
132            path: vec![],
133            dtype: self.dtype,
134            _phantom: std::marker::PhantomData,
135        }
136    }
137
138    /// Returns a new `VarBuilder` with the prefix set to `prefix`.
139    pub fn set_prefix(&self, prefix: impl ToString) -> Self {
140        Self {
141            data: self.data.clone(),
142            path: vec![prefix.to_string()],
143            dtype: self.dtype,
144            _phantom: std::marker::PhantomData,
145        }
146    }
147
148    /// Return a new `VarBuilder` adding `s` to the current prefix. This can be think of as `cd`
149    /// into a directory.
150    pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
151        let mut path = self.path.clone();
152        path.push(s.to_string());
153        Self {
154            data: self.data.clone(),
155            path,
156            dtype: self.dtype,
157            _phantom: std::marker::PhantomData,
158        }
159    }
160
161    /// Short alias for `push_prefix`.
162    pub fn pp<S: ToString>(&self, s: S) -> Self {
163        self.push_prefix(s)
164    }
165
166    /// The device used by default.
167    pub fn device(&self) -> &Device {
168        &self.data.device
169    }
170
171    /// The dtype used by default.
172    pub fn dtype(&self) -> DType {
173        self.dtype
174    }
175
176    /// Clone the VarBuilder tweaking its dtype
177    pub fn to_dtype(&self, dtype: DType) -> Self {
178        Self {
179            data: self.data.clone(),
180            path: self.path.clone(),
181            dtype,
182            _phantom: std::marker::PhantomData,
183        }
184    }
185
186    fn path(&self, tensor_name: &str) -> String {
187        if self.path.is_empty() {
188            tensor_name.to_string()
189        } else {
190            [&self.path.join("."), tensor_name].join(".")
191        }
192    }
193
194    /// This returns true only if a tensor with the passed in name is available. E.g. when passed
195    /// `a`, true is returned if `prefix.a` exists but false is returned if only `prefix.a.b`
196    /// exists.
197    pub fn contains_tensor(&self, tensor_name: &str) -> bool {
198        let path = self.path(tensor_name);
199        self.data.backend.contains_tensor(&path)
200    }
201
202    /// Retrieve the tensor associated with the given name at the current path.
203    pub fn get_with_hints<S: Into<Shape>>(
204        &self,
205        s: S,
206        name: &str,
207        hints: B::Hints,
208    ) -> Result<Tensor> {
209        self.get_with_hints_dtype(s, name, hints, self.dtype)
210    }
211
212    /// Retrieve the tensor associated with the given name at the current path.
213    pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
214        self.get_with_hints(s, name, Default::default())
215    }
216
217    /// Retrieve the tensor associated with the given name at the current path.
218    pub fn get_unchecked(&self, name: &str) -> Result<Tensor> {
219        self.get_unchecked_dtype(name, self.data.dtype)
220    }
221
222    /// Retrieve the tensor associated with the given name & dtype at the current path.
223    pub fn get_unchecked_dtype(&self, name: &str, dtype: DType) -> Result<Tensor> {
224        let name = self.path(name);
225        self.data
226            .backend
227            .get_unchecked(&name, dtype, &self.data.device)
228    }
229
230    /// Retrieve the tensor associated with the given name & dtype at the current path.
231    pub fn get_with_hints_dtype<S: Into<Shape>>(
232        &self,
233        s: S,
234        name: &str,
235        hints: B::Hints,
236        dtype: DType,
237    ) -> Result<Tensor> {
238        let path = self.path(name);
239        self.data
240            .backend
241            .get(s.into(), &path, hints, dtype, &self.data.device)
242    }
243
244    /// Set the device of the VarBuilder.
245    pub fn set_device(self, device: Device) -> Self {
246        Self {
247            data: Arc::new(TensorData {
248                backend: self.data.backend.clone(),
249                dtype: self.data.dtype,
250                device,
251            }),
252            ..self
253        }
254    }
255
256    /// Set the dtype of the VarBuilder.
257    pub fn set_dtype(self, dtype: DType) -> Self {
258        Self {
259            data: Arc::new(TensorData {
260                backend: self.data.backend.clone(),
261                dtype,
262                device: self.data.device.clone(),
263            }),
264            dtype,
265            ..self
266        }
267    }
268}
269
270struct Zeros;
271
272impl SimpleBackend for Zeros {
273    fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
274        Tensor::zeros(s, dtype, dev)
275    }
276
277    fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result<Tensor> {
278        candle::bail!(
279            "`Zeros` requires a shape for tensor retrieval, use `get` instead of `get_unchecked`"
280        )
281    }
282
283    fn contains_tensor(&self, _name: &str) -> bool {
284        true
285    }
286}
287
288impl SimpleBackend for HashMap<String, Tensor> {
289    fn get(
290        &self,
291        s: Shape,
292        name: &str,
293        _: crate::Init,
294        dtype: DType,
295        dev: &Device,
296    ) -> Result<Tensor> {
297        let tensor = self
298            .get(name)
299            .ok_or_else(|| {
300                Error::CannotFindTensor {
301                    path: name.to_string(),
302                }
303                .bt()
304            })?
305            .clone();
306        if tensor.shape() != &s {
307            Err(candle::Error::UnexpectedShape {
308                msg: format!("shape mismatch for {name}"),
309                expected: s,
310                got: tensor.shape().clone(),
311            }
312            .bt())?
313        }
314        tensor.to_device(dev)?.to_dtype(dtype)
315    }
316
317    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
318        let tensor = self
319            .get(name)
320            .ok_or_else(|| {
321                Error::CannotFindTensor {
322                    path: name.to_string(),
323                }
324                .bt()
325            })?
326            .clone();
327        tensor.to_device(dev)?.to_dtype(dtype)
328    }
329
330    fn contains_tensor(&self, name: &str) -> bool {
331        self.contains_key(name)
332    }
333}
334
335impl SimpleBackend for VarMap {
336    fn get(
337        &self,
338        s: Shape,
339        name: &str,
340        h: crate::Init,
341        dtype: DType,
342        dev: &Device,
343    ) -> Result<Tensor> {
344        VarMap::get(self, s, name, h, dtype, dev)
345    }
346
347    fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result<Tensor> {
348        candle::bail!("`get_unchecked` does not make sense for `VarMap`, use `get`.");
349    }
350
351    fn contains_tensor(&self, name: &str) -> bool {
352        self.data().lock().unwrap().contains_key(name)
353    }
354}
355
356#[allow(dead_code)]
357pub struct SafeTensorWithRouting<'a> {
358    routing: HashMap<String, usize>,
359    safetensors: Vec<SafeTensors<'a>>,
360}
361
362impl SimpleBackend for SafeTensorWithRouting<'_> {
363    fn get(
364        &self,
365        s: Shape,
366        path: &str,
367        _: crate::Init,
368        dtype: DType,
369        dev: &Device,
370    ) -> Result<Tensor> {
371        let index = self.routing.get(path).ok_or_else(|| {
372            Error::CannotFindTensor {
373                path: path.to_string(),
374            }
375            .bt()
376        })?;
377        let tensor = self.safetensors[*index]
378            .tensor(path)?
379            .load(dev)?
380            .to_dtype(dtype)?;
381        if tensor.shape() != &s {
382            Err(candle::Error::UnexpectedShape {
383                msg: format!("shape mismatch for {path}"),
384                expected: s,
385                got: tensor.shape().clone(),
386            }
387            .bt())?
388        }
389        Ok(tensor)
390    }
391
392    fn get_unchecked(&self, path: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
393        let index = self.routing.get(path).ok_or_else(|| {
394            Error::CannotFindTensor {
395                path: path.to_string(),
396            }
397            .bt()
398        })?;
399        let tensor = self.safetensors[*index]
400            .tensor(path)?
401            .load(dev)?
402            .to_dtype(dtype)?;
403        Ok(tensor)
404    }
405
406    fn contains_tensor(&self, name: &str) -> bool {
407        self.routing.contains_key(name)
408    }
409}
410
411impl SimpleBackend for candle::npy::NpzTensors {
412    fn get(
413        &self,
414        s: Shape,
415        path: &str,
416        _: crate::Init,
417        dtype: DType,
418        dev: &Device,
419    ) -> Result<Tensor> {
420        let tensor = match self.get(path)? {
421            None => Err(Error::CannotFindTensor {
422                path: path.to_string(),
423            }
424            .bt())?,
425            Some(tensor) => tensor,
426        };
427        let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
428        if tensor.shape() != &s {
429            Err(candle::Error::UnexpectedShape {
430                msg: format!("shape mismatch for {path}"),
431                expected: s,
432                got: tensor.shape().clone(),
433            }
434            .bt())?
435        }
436        Ok(tensor)
437    }
438
439    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
440        let tensor = match self.get(name)? {
441            None => Err(Error::CannotFindTensor {
442                path: name.to_string(),
443            }
444            .bt())?,
445            Some(tensor) => tensor,
446        };
447        let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
448        Ok(tensor)
449    }
450
451    fn contains_tensor(&self, name: &str) -> bool {
452        self.get(name).is_ok_and(|v| v.is_some())
453    }
454}
455
456impl SimpleBackend for candle::pickle::PthTensors {
457    fn get(
458        &self,
459        s: Shape,
460        path: &str,
461        _: crate::Init,
462        dtype: DType,
463        dev: &Device,
464    ) -> Result<Tensor> {
465        let tensor = match self.get(path)? {
466            None => Err(Error::CannotFindTensor {
467                path: path.to_string(),
468            }
469            .bt())?,
470            Some(tensor) => tensor,
471        };
472        let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
473        if tensor.shape() != &s {
474            Err(candle::Error::UnexpectedShape {
475                msg: format!("shape mismatch for {path}"),
476                expected: s,
477                got: tensor.shape().clone(),
478            }
479            .bt())?
480        }
481        Ok(tensor)
482    }
483
484    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
485        let tensor = match self.get(name)? {
486            None => Err(Error::CannotFindTensor {
487                path: name.to_string(),
488            }
489            .bt())?,
490            Some(tensor) => tensor,
491        };
492        let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
493        Ok(tensor)
494    }
495
496    fn contains_tensor(&self, name: &str) -> bool {
497        self.get(name).is_ok_and(|v| v.is_some())
498    }
499}
500
501impl SimpleBackend for candle::safetensors::MmapedSafetensors {
502    fn get(
503        &self,
504        s: Shape,
505        name: &str,
506        _: crate::Init,
507        dtype: DType,
508        dev: &Device,
509    ) -> Result<Tensor> {
510        let tensor = self.load(name, dev)?.to_dtype(dtype)?;
511        if tensor.shape() != &s {
512            Err(candle::Error::UnexpectedShape {
513                msg: format!("shape mismatch for {name}"),
514                expected: s,
515                got: tensor.shape().clone(),
516            }
517            .bt())?
518        }
519        Ok(tensor)
520    }
521
522    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
523        self.load(name, dev)?.to_dtype(dtype)
524    }
525
526    fn contains_tensor(&self, name: &str) -> bool {
527        self.get(name).is_ok()
528    }
529}
530
531impl SimpleBackend for candle::safetensors::BufferedSafetensors {
532    fn get(
533        &self,
534        s: Shape,
535        name: &str,
536        _: crate::Init,
537        dtype: DType,
538        dev: &Device,
539    ) -> Result<Tensor> {
540        let tensor = self.load(name, dev)?.to_dtype(dtype)?;
541        if tensor.shape() != &s {
542            Err(candle::Error::UnexpectedShape {
543                msg: format!("shape mismatch for {name}"),
544                expected: s,
545                got: tensor.shape().clone(),
546            }
547            .bt())?
548        }
549        Ok(tensor)
550    }
551
552    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
553        self.load(name, dev)?.to_dtype(dtype)
554    }
555
556    fn contains_tensor(&self, name: &str) -> bool {
557        self.get(name).is_ok()
558    }
559}
560
561impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
562    fn get(
563        &self,
564        s: Shape,
565        name: &str,
566        _: crate::Init,
567        dtype: DType,
568        dev: &Device,
569    ) -> Result<Tensor> {
570        let tensor = self.load(name, dev)?.to_dtype(dtype)?;
571        if tensor.shape() != &s {
572            Err(candle::Error::UnexpectedShape {
573                msg: format!("shape mismatch for {name}"),
574                expected: s,
575                got: tensor.shape().clone(),
576            }
577            .bt())?
578        }
579        Ok(tensor)
580    }
581
582    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
583        self.load(name, dev)?.to_dtype(dtype)
584    }
585
586    fn contains_tensor(&self, name: &str) -> bool {
587        self.get(name).is_ok()
588    }
589}
590
591impl<'a> VarBuilder<'a> {
592    /// Initializes a `VarBuilder` using a custom backend.
593    ///
594    /// It is preferred to use one of the more specific constructors. This
595    /// constructor is provided to allow downstream users to define their own
596    /// backends.
597    pub fn from_backend(
598        backend: Box<dyn SimpleBackend + 'a>,
599        dtype: DType,
600        device: Device,
601    ) -> Self {
602        let data = TensorData {
603            backend: Arc::new(backend),
604            device,
605            dtype,
606        };
607        Self {
608            data: Arc::new(data),
609            path: vec![],
610            dtype,
611            _phantom: std::marker::PhantomData,
612        }
613    }
614
615    /// Initializes a `VarBuilder` that uses zeros for any tensor.
616    pub fn zeros(dtype: DType, dev: &Device) -> Self {
617        Self::from_backend(Box::new(Zeros), dtype, dev.clone())
618    }
619
620    /// Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is
621    /// returned if no tensor is available under the requested path or on shape mismatches.
622    pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
623        Self::from_backend(Box::new(ts), dtype, dev.clone())
624    }
625
626    /// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and
627    /// initialized on new paths, the same tensor is used if the same path is requested multiple
628    /// times. This is commonly used when initializing a model before training.
629    ///
630    /// Note that it is possible to load the tensor values after model creation using the `load`
631    /// method on `varmap`, this can be used to start model training from an existing checkpoint.
632    pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
633        Self::from_backend(Box::new(varmap.clone()), dtype, dev.clone())
634    }
635
636    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
637    /// files.
638    ///
639    /// # Safety
640    ///
641    /// The unsafe is inherited from [`memmap2::MmapOptions`].
642    pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
643        paths: &[P],
644        dtype: DType,
645        dev: &Device,
646    ) -> Result<Self> {
647        let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
648        Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
649    }
650
651    /// Initializes a `VarBuilder` from a binary buffer in the safetensor format.
652    pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
653        let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
654        Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
655    }
656
657    /// Initializes a `VarBuilder` from a binary slice in the safetensor format.
658    pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self> {
659        let tensors = candle::safetensors::SliceSafetensors::new(data)?;
660        Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
661    }
662
663    /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
664    pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
665        let npz = candle::npy::NpzTensors::new(p)?;
666        Ok(Self::from_backend(Box::new(npz), dtype, dev.clone()))
667    }
668
669    /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
670    pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
671        let pth = candle::pickle::PthTensors::new(p, None)?;
672        Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
673    }
674    /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
675    /// similar to [`from_pth`] but requires a `state_key`.
676    pub fn from_pth_with_state<P: AsRef<std::path::Path>>(
677        p: P,
678        dtype: DType,
679        state_key: &str,
680        dev: &Device,
681    ) -> Result<Self> {
682        let pth = candle::pickle::PthTensors::new(p, Some(state_key))?;
683        Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
684    }
685    /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
686    /// passing the new names to the inner VarBuilder.
687    ///
688    /// ```rust
689    /// use candle::{Tensor, DType, Device};
690    ///
691    /// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
692    /// let tensors: std::collections::HashMap<_, _> = [
693    ///     ("foo".to_string(), a),
694    /// ]
695    /// .into_iter()
696    /// .collect();
697    /// let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
698    /// assert!(vb.contains_tensor("foo"));
699    /// assert!(vb.get((2, 3), "foo").is_ok());
700    /// assert!(!vb.contains_tensor("bar"));
701    /// let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
702    /// assert!(vb.contains_tensor("bar"));
703    /// assert!(vb.contains_tensor("foo"));
704    /// assert!(vb.get((2, 3), "bar").is_ok());
705    /// assert!(vb.get((2, 3), "foo").is_ok());
706    /// assert!(!vb.contains_tensor("baz"));
707    /// # Ok::<(), candle::Error>(())
708    /// ```
709    pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(self, f: F) -> Self {
710        let f: Box<dyn Fn(&str) -> String + Sync + Send + 'static> = Box::new(f);
711        self.rename(f)
712    }
713
714    pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self {
715        let dtype = self.dtype();
716        let device = self.device().clone();
717        let path = self.path.clone();
718        let backend = Rename::new(self, renamer);
719        let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
720        let data = TensorData {
721            backend: Arc::new(backend),
722            device,
723            dtype,
724        };
725        Self {
726            data: Arc::new(data),
727            dtype,
728            path,
729            _phantom: std::marker::PhantomData,
730        }
731    }
732}
733
734pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
735
736pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>;
737
738impl ShardedSafeTensors {
739    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
740    /// files and make them usable in a sharded way.
741    ///
742    /// # Safety
743    ///
744    /// The unsafe is inherited from [`memmap2::MmapOptions`].
745    pub unsafe fn var_builder<P: AsRef<std::path::Path>>(
746        paths: &[P],
747        dtype: DType,
748        dev: &Device,
749    ) -> Result<ShardedVarBuilder<'static>> {
750        let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
751        let backend = ShardedSafeTensors(tensors);
752        Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
753    }
754}
755
/// Selects the part of a tensor a process loads, for tensor parallelism.
///
/// The tensor is split into `world_size` equal blocks along dimension `dim`, and block number
/// `rank` is the one returned.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct Shard {
    /// Dimension along which the tensor is split.
    pub dim: usize,
    /// Index of the block to load, expected in `0..world_size`.
    pub rank: usize,
    /// Number of blocks the tensor is split into.
    pub world_size: usize,
}

/// The default shard is the whole tensor: a single block (`world_size == 1`) along dim 0.
impl Default for Shard {
    fn default() -> Self {
        Shard {
            world_size: 1,
            rank: 0,
            dim: 0,
        }
    }
}
772
773/// Get part of a tensor, typically used to do Tensor Parallelism sharding.
774///
775/// If the tensor is of size (1024, 1024).
776///
777/// `dim` corresponds to the dimension to slice into
778/// `rank` is the rank of the current process
779/// `world_size` is the total number of ranks in the process group
780///
781/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
782/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
783/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
784impl Backend for ShardedSafeTensors {
785    type Hints = Shard;
786
787    fn get(
788        &self,
789        target_shape: Shape, // The size is only checked when the world size is 1.
790        path: &str,
791        h: Self::Hints,
792        dtype: DType,
793        dev: &Device,
794    ) -> Result<Tensor> {
795        if h.world_size == 1 {
796            // There is no sharding to be applied here so we use the default backend to speed
797            // things up.
798            return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
799        }
800
801        let Shard {
802            dim,
803            rank,
804            world_size,
805        } = h;
806        let view = self.0.get(path)?;
807        let view_dtype = view.dtype();
808        let mut shape = view.shape().to_vec();
809        let size = shape[dim];
810
811        if size % world_size != 0 {
812            return Err(Error::ShapeMismatchSplit {
813                shape: shape.into(),
814                dim,
815                n_parts: world_size,
816            });
817        }
818        let block_size = size / world_size;
819        let start = rank * block_size;
820        let stop = (rank + 1) * block_size;
821
822        // Everything is expressed in tensor dimension
823        // bytes offsets is handled automatically for safetensors.
824
825        let iterator = if dim == 0 {
826            view.slice(start..stop).map_err(|_| {
827                Error::Msg(format!(
828                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
829                ))
830            })?
831        } else if dim == 1 {
832            view.slice((.., start..stop)).map_err(|_| {
833                Error::Msg(format!(
834                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
835                ))
836            })?
837        } else {
838            candle::bail!("Get sharded on dimensions != 0 or 1")
839        };
840
841        shape[dim] = block_size;
842
843        let view_dtype: DType = view_dtype.try_into()?;
844        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
845        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
846    }
847
848    fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result<Tensor> {
849        candle::bail!("`get_unchecked` does not make sense for `ShardedSafeTensors`, use `get`.");
850    }
851
852    fn contains_tensor(&self, name: &str) -> bool {
853        self.0.get(name).is_ok()
854    }
855}
856
/// Maps the names a `VarBuilder` is queried with onto the names stored in an inner
/// `VarBuilder`.
pub trait Renamer {
    /// Returns the name that the inner `VarBuilder` should be queried with, in place of `v`.
    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str>;
}
864
865pub struct Rename<'a, R: Renamer> {
866    inner: VarBuilder<'a>,
867    renamer: R,
868}
869
870impl<R: Renamer + Sync + Send> SimpleBackend for Rename<'_, R> {
871    fn get(
872        &self,
873        s: Shape,
874        name: &str,
875        h: crate::Init,
876        dtype: DType,
877        dev: &Device,
878    ) -> Result<Tensor> {
879        let name = self.renamer.rename(name);
880        self.inner
881            .get_with_hints_dtype(s, &name, h, dtype)?
882            .to_device(dev)
883    }
884
885    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
886        let name = self.renamer.rename(name);
887        self.inner.get_unchecked_dtype(&name, dtype)?.to_device(dev)
888    }
889
890    fn contains_tensor(&self, name: &str) -> bool {
891        let name = self.renamer.rename(name);
892        self.inner.contains_tensor(&name)
893    }
894}
895
896impl<'a, R: Renamer> Rename<'a, R> {
897    pub fn new(inner: VarBuilder<'a>, renamer: R) -> Self {
898        Self { inner, renamer }
899    }
900}
901
902impl Renamer for Box<dyn Fn(&str) -> String + Sync + Send> {
903    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> {
904        std::borrow::Cow::Owned(self(v))
905    }
906}