candle_nn/
var_builder.rs

1//! A `VarBuilder` for variable retrieval from models
2//!
3//! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come
4//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized
5//! for training, e.g. using `VarBuilder::from_varmap`.
6use crate::VarMap;
7use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
8use safetensors::{slice::IndexOp, tensor::SafeTensors};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// A structure used to retrieve variables, these variables can either come from storage or be
13/// generated via some form of initialization.
14///
15/// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`.
16pub struct VarBuilderArgs<'a, B: Backend> {
17    data: Arc<TensorData<B>>,
18    path: Vec<String>,
19    pub dtype: DType,
20    _phantom: std::marker::PhantomData<&'a B>,
21}
22
23impl<B: Backend> Clone for VarBuilderArgs<'_, B> {
24    fn clone(&self) -> Self {
25        Self {
26            data: self.data.clone(),
27            path: self.path.clone(),
28            dtype: self.dtype,
29            _phantom: self._phantom,
30        }
31    }
32}
33
34/// A simple `VarBuilder`, this is less generic than `VarBuilderArgs` but should cover most common
35/// use cases.
36pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
37
38struct TensorData<B: Backend> {
39    backend: B,
40    pub device: Device,
41}
42
43/// A trait that defines how tensor data is retrieved.
44///
45/// Typically this would use disk storage in some specific format, or random initialization.
46/// Note that there is a specialized version of this trait (`SimpleBackend`) that can be used most
47/// of the time. The main restriction is that it doesn't allow for specific args (besides
48/// initialization hints).
49pub trait Backend: Send + Sync {
50    type Hints: Default;
51
52    /// Retrieve a tensor with some target shape.
53    fn get(
54        &self,
55        s: Shape,
56        name: &str,
57        h: Self::Hints,
58        dtype: DType,
59        dev: &Device,
60    ) -> Result<Tensor>;
61
62    fn contains_tensor(&self, name: &str) -> bool;
63}
64
65pub trait SimpleBackend: Send + Sync {
66    /// Retrieve a tensor based on a target name and shape.
67    fn get(
68        &self,
69        s: Shape,
70        name: &str,
71        h: crate::Init,
72        dtype: DType,
73        dev: &Device,
74    ) -> Result<Tensor>;
75
76    fn contains_tensor(&self, name: &str) -> bool;
77}
78
79impl Backend for Box<dyn SimpleBackend + '_> {
80    type Hints = crate::Init;
81    fn get(
82        &self,
83        s: Shape,
84        name: &str,
85        h: Self::Hints,
86        dtype: DType,
87        dev: &Device,
88    ) -> Result<Tensor> {
89        self.as_ref().get(s, name, h, dtype, dev)
90    }
91
92    fn contains_tensor(&self, name: &str) -> bool {
93        self.as_ref().contains_tensor(name)
94    }
95}
96
97impl<B: Backend> VarBuilderArgs<'_, B> {
98    pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
99        let data = TensorData {
100            backend,
101            device: dev.clone(),
102        };
103        Self {
104            data: Arc::new(data),
105            path: vec![],
106            dtype,
107            _phantom: std::marker::PhantomData,
108        }
109    }
110
111    /// Returns the prefix of the `VarBuilder`.
112    pub fn prefix(&self) -> String {
113        self.path.join(".")
114    }
115
116    /// Returns a new `VarBuilder` using the root path.
117    pub fn root(&self) -> Self {
118        Self {
119            data: self.data.clone(),
120            path: vec![],
121            dtype: self.dtype,
122            _phantom: std::marker::PhantomData,
123        }
124    }
125
126    /// Returns a new `VarBuilder` with the prefix set to `prefix`.
127    pub fn set_prefix(&self, prefix: impl ToString) -> Self {
128        Self {
129            data: self.data.clone(),
130            path: vec![prefix.to_string()],
131            dtype: self.dtype,
132            _phantom: std::marker::PhantomData,
133        }
134    }
135
136    /// Return a new `VarBuilder` adding `s` to the current prefix. This can be think of as `cd`
137    /// into a directory.
138    pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
139        let mut path = self.path.clone();
140        path.push(s.to_string());
141        Self {
142            data: self.data.clone(),
143            path,
144            dtype: self.dtype,
145            _phantom: std::marker::PhantomData,
146        }
147    }
148
149    /// Short alias for `push_prefix`.
150    pub fn pp<S: ToString>(&self, s: S) -> Self {
151        self.push_prefix(s)
152    }
153
154    /// The device used by default.
155    pub fn device(&self) -> &Device {
156        &self.data.device
157    }
158
159    /// The dtype used by default.
160    pub fn dtype(&self) -> DType {
161        self.dtype
162    }
163
164    /// Clone the VarBuilder tweaking its dtype
165    pub fn to_dtype(&self, dtype: DType) -> Self {
166        Self {
167            data: self.data.clone(),
168            path: self.path.clone(),
169            dtype,
170            _phantom: std::marker::PhantomData,
171        }
172    }
173
174    fn path(&self, tensor_name: &str) -> String {
175        if self.path.is_empty() {
176            tensor_name.to_string()
177        } else {
178            [&self.path.join("."), tensor_name].join(".")
179        }
180    }
181
182    /// This returns true only if a tensor with the passed in name is available. E.g. when passed
183    /// `a`, true is returned if `prefix.a` exists but false is returned if only `prefix.a.b`
184    /// exists.
185    pub fn contains_tensor(&self, tensor_name: &str) -> bool {
186        let path = self.path(tensor_name);
187        self.data.backend.contains_tensor(&path)
188    }
189
190    /// Retrieve the tensor associated with the given name at the current path.
191    pub fn get_with_hints<S: Into<Shape>>(
192        &self,
193        s: S,
194        name: &str,
195        hints: B::Hints,
196    ) -> Result<Tensor> {
197        self.get_with_hints_dtype(s, name, hints, self.dtype)
198    }
199
200    /// Retrieve the tensor associated with the given name at the current path.
201    pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
202        self.get_with_hints(s, name, Default::default())
203    }
204
205    /// Retrieve the tensor associated with the given name & dtype at the current path.
206    pub fn get_with_hints_dtype<S: Into<Shape>>(
207        &self,
208        s: S,
209        name: &str,
210        hints: B::Hints,
211        dtype: DType,
212    ) -> Result<Tensor> {
213        let path = self.path(name);
214        self.data
215            .backend
216            .get(s.into(), &path, hints, dtype, &self.data.device)
217    }
218}
219
220struct Zeros;
221
222impl SimpleBackend for Zeros {
223    fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
224        Tensor::zeros(s, dtype, dev)
225    }
226
227    fn contains_tensor(&self, _name: &str) -> bool {
228        true
229    }
230}
231
232impl SimpleBackend for HashMap<String, Tensor> {
233    fn get(
234        &self,
235        s: Shape,
236        name: &str,
237        _: crate::Init,
238        dtype: DType,
239        dev: &Device,
240    ) -> Result<Tensor> {
241        let tensor = self
242            .get(name)
243            .ok_or_else(|| {
244                Error::CannotFindTensor {
245                    path: name.to_string(),
246                }
247                .bt()
248            })?
249            .clone();
250        if tensor.shape() != &s {
251            Err(candle::Error::UnexpectedShape {
252                msg: format!("shape mismatch for {name}"),
253                expected: s,
254                got: tensor.shape().clone(),
255            }
256            .bt())?
257        }
258        tensor.to_device(dev)?.to_dtype(dtype)
259    }
260
261    fn contains_tensor(&self, name: &str) -> bool {
262        self.contains_key(name)
263    }
264}
265
266impl SimpleBackend for VarMap {
267    fn get(
268        &self,
269        s: Shape,
270        name: &str,
271        h: crate::Init,
272        dtype: DType,
273        dev: &Device,
274    ) -> Result<Tensor> {
275        VarMap::get(self, s, name, h, dtype, dev)
276    }
277
278    fn contains_tensor(&self, name: &str) -> bool {
279        self.data().lock().unwrap().contains_key(name)
280    }
281}
282
283#[allow(dead_code)]
284pub struct SafeTensorWithRouting<'a> {
285    routing: HashMap<String, usize>,
286    safetensors: Vec<SafeTensors<'a>>,
287}
288
289impl SimpleBackend for SafeTensorWithRouting<'_> {
290    fn get(
291        &self,
292        s: Shape,
293        path: &str,
294        _: crate::Init,
295        dtype: DType,
296        dev: &Device,
297    ) -> Result<Tensor> {
298        let index = self.routing.get(path).ok_or_else(|| {
299            Error::CannotFindTensor {
300                path: path.to_string(),
301            }
302            .bt()
303        })?;
304        let tensor = self.safetensors[*index]
305            .tensor(path)?
306            .load(dev)?
307            .to_dtype(dtype)?;
308        if tensor.shape() != &s {
309            Err(candle::Error::UnexpectedShape {
310                msg: format!("shape mismatch for {path}"),
311                expected: s,
312                got: tensor.shape().clone(),
313            }
314            .bt())?
315        }
316        Ok(tensor)
317    }
318
319    fn contains_tensor(&self, name: &str) -> bool {
320        self.routing.contains_key(name)
321    }
322}
323
324impl SimpleBackend for candle::npy::NpzTensors {
325    fn get(
326        &self,
327        s: Shape,
328        path: &str,
329        _: crate::Init,
330        dtype: DType,
331        dev: &Device,
332    ) -> Result<Tensor> {
333        let tensor = match self.get(path)? {
334            None => Err(Error::CannotFindTensor {
335                path: path.to_string(),
336            }
337            .bt())?,
338            Some(tensor) => tensor,
339        };
340        let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
341        if tensor.shape() != &s {
342            Err(candle::Error::UnexpectedShape {
343                msg: format!("shape mismatch for {path}"),
344                expected: s,
345                got: tensor.shape().clone(),
346            }
347            .bt())?
348        }
349        Ok(tensor)
350    }
351
352    fn contains_tensor(&self, name: &str) -> bool {
353        self.get(name).is_ok_and(|v| v.is_some())
354    }
355}
356
357impl SimpleBackend for candle::pickle::PthTensors {
358    fn get(
359        &self,
360        s: Shape,
361        path: &str,
362        _: crate::Init,
363        dtype: DType,
364        dev: &Device,
365    ) -> Result<Tensor> {
366        let tensor = match self.get(path)? {
367            None => Err(Error::CannotFindTensor {
368                path: path.to_string(),
369            }
370            .bt())?,
371            Some(tensor) => tensor,
372        };
373        let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
374        if tensor.shape() != &s {
375            Err(candle::Error::UnexpectedShape {
376                msg: format!("shape mismatch for {path}"),
377                expected: s,
378                got: tensor.shape().clone(),
379            }
380            .bt())?
381        }
382        Ok(tensor)
383    }
384
385    fn contains_tensor(&self, name: &str) -> bool {
386        self.get(name).is_ok_and(|v| v.is_some())
387    }
388}
389
390impl SimpleBackend for candle::safetensors::MmapedSafetensors {
391    fn get(
392        &self,
393        s: Shape,
394        name: &str,
395        _: crate::Init,
396        dtype: DType,
397        dev: &Device,
398    ) -> Result<Tensor> {
399        let tensor = self.load(name, dev)?.to_dtype(dtype)?;
400        if tensor.shape() != &s {
401            Err(candle::Error::UnexpectedShape {
402                msg: format!("shape mismatch for {name}"),
403                expected: s,
404                got: tensor.shape().clone(),
405            }
406            .bt())?
407        }
408        Ok(tensor)
409    }
410
411    fn contains_tensor(&self, name: &str) -> bool {
412        self.get(name).is_ok()
413    }
414}
415
416impl SimpleBackend for candle::safetensors::BufferedSafetensors {
417    fn get(
418        &self,
419        s: Shape,
420        name: &str,
421        _: crate::Init,
422        dtype: DType,
423        dev: &Device,
424    ) -> Result<Tensor> {
425        let tensor = self.load(name, dev)?.to_dtype(dtype)?;
426        if tensor.shape() != &s {
427            Err(candle::Error::UnexpectedShape {
428                msg: format!("shape mismatch for {name}"),
429                expected: s,
430                got: tensor.shape().clone(),
431            }
432            .bt())?
433        }
434        Ok(tensor)
435    }
436
437    fn contains_tensor(&self, name: &str) -> bool {
438        self.get(name).is_ok()
439    }
440}
441
442impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
443    fn get(
444        &self,
445        s: Shape,
446        name: &str,
447        _: crate::Init,
448        dtype: DType,
449        dev: &Device,
450    ) -> Result<Tensor> {
451        let tensor = self.load(name, dev)?.to_dtype(dtype)?;
452        if tensor.shape() != &s {
453            Err(candle::Error::UnexpectedShape {
454                msg: format!("shape mismatch for {name}"),
455                expected: s,
456                got: tensor.shape().clone(),
457            }
458            .bt())?
459        }
460        Ok(tensor)
461    }
462
463    fn contains_tensor(&self, name: &str) -> bool {
464        self.get(name).is_ok()
465    }
466}
467
468impl<'a> VarBuilder<'a> {
469    /// Initializes a `VarBuilder` using a custom backend.
470    ///
471    /// It is preferred to use one of the more specific constructors. This
472    /// constructor is provided to allow downstream users to define their own
473    /// backends.
474    pub fn from_backend(
475        backend: Box<dyn SimpleBackend + 'a>,
476        dtype: DType,
477        device: Device,
478    ) -> Self {
479        let data = TensorData { backend, device };
480        Self {
481            data: Arc::new(data),
482            path: vec![],
483            dtype,
484            _phantom: std::marker::PhantomData,
485        }
486    }
487
488    /// Initializes a `VarBuilder` that uses zeros for any tensor.
489    pub fn zeros(dtype: DType, dev: &Device) -> Self {
490        Self::from_backend(Box::new(Zeros), dtype, dev.clone())
491    }
492
493    /// Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is
494    /// returned if no tensor is available under the requested path or on shape mismatches.
495    pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
496        Self::from_backend(Box::new(ts), dtype, dev.clone())
497    }
498
499    /// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and
500    /// initialized on new paths, the same tensor is used if the same path is requested multiple
501    /// times. This is commonly used when initializing a model before training.
502    ///
503    /// Note that it is possible to load the tensor values after model creation using the `load`
504    /// method on `varmap`, this can be used to start model training from an existing checkpoint.
505    pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
506        Self::from_backend(Box::new(varmap.clone()), dtype, dev.clone())
507    }
508
509    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
510    /// files.
511    ///
512    /// # Safety
513    ///
514    /// The unsafe is inherited from [`memmap2::MmapOptions`].
515    pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
516        paths: &[P],
517        dtype: DType,
518        dev: &Device,
519    ) -> Result<Self> {
520        let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
521        Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
522    }
523
524    /// Initializes a `VarBuilder` from a binary buffer in the safetensor format.
525    pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
526        let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
527        Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
528    }
529
530    /// Initializes a `VarBuilder` from a binary slice in the safetensor format.
531    pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self> {
532        let tensors = candle::safetensors::SliceSafetensors::new(data)?;
533        Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
534    }
535
536    /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
537    pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
538        let npz = candle::npy::NpzTensors::new(p)?;
539        Ok(Self::from_backend(Box::new(npz), dtype, dev.clone()))
540    }
541
542    /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
543    pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
544        let pth = candle::pickle::PthTensors::new(p, None)?;
545        Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
546    }
547    /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
548    /// similar to [`from_pth`] but requires a `state_key`.
549    pub fn from_pth_with_state<P: AsRef<std::path::Path>>(
550        p: P,
551        dtype: DType,
552        state_key: &str,
553        dev: &Device,
554    ) -> Result<Self> {
555        let pth = candle::pickle::PthTensors::new(p, Some(state_key))?;
556        Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
557    }
558    /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
559    /// passing the new names to the inner VarBuilder.
560    ///
561    /// ```rust
562    /// use candle::{Tensor, DType, Device};
563    ///
564    /// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
565    /// let tensors: std::collections::HashMap<_, _> = [
566    ///     ("foo".to_string(), a),
567    /// ]
568    /// .into_iter()
569    /// .collect();
570    /// let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
571    /// assert!(vb.contains_tensor("foo"));
572    /// assert!(vb.get((2, 3), "foo").is_ok());
573    /// assert!(!vb.contains_tensor("bar"));
574    /// let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
575    /// assert!(vb.contains_tensor("bar"));
576    /// assert!(vb.contains_tensor("foo"));
577    /// assert!(vb.get((2, 3), "bar").is_ok());
578    /// assert!(vb.get((2, 3), "foo").is_ok());
579    /// assert!(!vb.contains_tensor("baz"));
580    /// # Ok::<(), candle::Error>(())
581    /// ```
582    pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(self, f: F) -> Self {
583        let f: Box<dyn Fn(&str) -> String + Sync + Send + 'static> = Box::new(f);
584        self.rename(f)
585    }
586
587    pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self {
588        let dtype = self.dtype();
589        let device = self.device().clone();
590        let path = self.path.clone();
591        let backend = Rename::new(self, renamer);
592        let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
593        let data = TensorData { backend, device };
594        Self {
595            data: Arc::new(data),
596            dtype,
597            path,
598            _phantom: std::marker::PhantomData,
599        }
600    }
601}
602
603pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
604
605pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>;
606
607impl ShardedSafeTensors {
608    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
609    /// files and make them usable in a sharded way.
610    ///
611    /// # Safety
612    ///
613    /// The unsafe is inherited from [`memmap2::MmapOptions`].
614    pub unsafe fn var_builder<P: AsRef<std::path::Path>>(
615        paths: &[P],
616        dtype: DType,
617        dev: &Device,
618    ) -> Result<ShardedVarBuilder<'static>> {
619        let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
620        let backend = ShardedSafeTensors(tensors);
621        Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
622    }
623}
624
/// Selects which part of a tensor to load when sharding it across several ranks.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct Shard {
    /// Dimension along which the tensor is split.
    pub dim: usize,
    /// Index of the block to load (the rank of the current process).
    pub rank: usize,
    /// Number of blocks the tensor is split into (the total number of ranks).
    pub world_size: usize,
}

/// The default shard is a single block (`world_size == 1`), i.e. no sharding at all.
impl Default for Shard {
    fn default() -> Self {
        Shard {
            world_size: 1,
            rank: 0,
            dim: 0,
        }
    }
}
641
642/// Get part of a tensor, typically used to do Tensor Parallelism sharding.
643///
644/// If the tensor is of size (1024, 1024).
645///
646/// `dim` corresponds to the dimension to slice into
647/// `rank` is the rank of the current process
648/// `world_size` is the total number of ranks in the process group
649///
650/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
651/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
652/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
653impl Backend for ShardedSafeTensors {
654    type Hints = Shard;
655
656    fn get(
657        &self,
658        target_shape: Shape, // The size is only checked when the world size is 1.
659        path: &str,
660        h: Self::Hints,
661        dtype: DType,
662        dev: &Device,
663    ) -> Result<Tensor> {
664        if h.world_size == 1 {
665            // There is no sharding to be applied here so we use the default backend to speed
666            // things up.
667            return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
668        }
669
670        let Shard {
671            dim,
672            rank,
673            world_size,
674        } = h;
675        let view = self.0.get(path)?;
676        let view_dtype = view.dtype();
677        let mut shape = view.shape().to_vec();
678        let size = shape[dim];
679
680        if size % world_size != 0 {
681            return Err(Error::ShapeMismatchSplit {
682                shape: shape.into(),
683                dim,
684                n_parts: world_size,
685            });
686        }
687        let block_size = size / world_size;
688        let start = rank * block_size;
689        let stop = (rank + 1) * block_size;
690
691        // Everything is expressed in tensor dimension
692        // bytes offsets is handled automatically for safetensors.
693
694        let iterator = if dim == 0 {
695            view.slice(start..stop).map_err(|_| {
696                Error::Msg(format!(
697                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
698                ))
699            })?
700        } else if dim == 1 {
701            view.slice((.., start..stop)).map_err(|_| {
702                Error::Msg(format!(
703                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
704                ))
705            })?
706        } else {
707            candle::bail!("Get sharded on dimensions != 0 or 1")
708        };
709
710        shape[dim] = block_size;
711
712        let view_dtype: DType = view_dtype.try_into()?;
713        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
714        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
715    }
716
717    fn contains_tensor(&self, name: &str) -> bool {
718        self.0.get(name).is_ok()
719    }
720}
721
/// Specifies how the names queried on a renaming `VarBuilder` are translated into the names
/// stored in the inner `VarBuilder`.
pub trait Renamer {
    /// Maps the name `v`, as requested from the outer `VarBuilder`, to the name that is looked up
    /// in the inner `VarBuilder`.
    ///
    /// Because of lifetime elision, the returned `Cow` may borrow from `self` but not from `v`.
    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str>;
}
729
730pub struct Rename<'a, R: Renamer> {
731    inner: VarBuilder<'a>,
732    renamer: R,
733}
734
735impl<R: Renamer + Sync + Send> SimpleBackend for Rename<'_, R> {
736    fn get(
737        &self,
738        s: Shape,
739        name: &str,
740        h: crate::Init,
741        dtype: DType,
742        dev: &Device,
743    ) -> Result<Tensor> {
744        let name = self.renamer.rename(name);
745        self.inner
746            .get_with_hints_dtype(s, &name, h, dtype)?
747            .to_device(dev)
748    }
749
750    fn contains_tensor(&self, name: &str) -> bool {
751        let name = self.renamer.rename(name);
752        self.inner.contains_tensor(&name)
753    }
754}
755
756impl<'a, R: Renamer> Rename<'a, R> {
757    pub fn new(inner: VarBuilder<'a>, renamer: R) -> Self {
758        Self { inner, renamer }
759    }
760}
761
762impl Renamer for Box<dyn Fn(&str) -> String + Sync + Send> {
763    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> {
764        std::borrow::Cow::Owned(self(v))
765    }
766}