candle_core/variable.rs

// Variables are wrappers around tensors that can be modified; they are typically used to hold
// weights that get updated by gradient descent.
// We do not expose a public way to create variables as this would break the invariant that the
// tensor within a variable actually has `is_variable` set to `true`.
use crate::{DType, Device, Error, Result, Shape, Tensor};

/// A variable is a wrapper around a tensor; however, variables can have their content modified,
/// whereas tensors are immutable.
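///
/// A minimal usage sketch (assuming the crate is used as `candle_core` and a CPU device):
///
/// ```rust
/// use candle_core::{DType, Device, Var};
///
/// // Create a 2x3 variable filled with zeros.
/// let w = Var::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// // Through `Deref`, a `Var` can be used wherever a `&Tensor` is expected.
/// let total = w.sum_all()?;
/// # Ok::<(), candle_core::Error>(())
/// ```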
#[derive(Clone, Debug)]
pub struct Var(Tensor);

impl std::fmt::Display for Var {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl std::ops::Deref for Var {
    type Target = Tensor;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

impl Var {
    pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
        let inner = Tensor::zeros_impl(shape, dtype, device, true)?;
        Ok(Self(inner))
    }

    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
        let inner = Tensor::ones_impl(shape, dtype, device, true)?;
        Ok(Self(inner))
    }

    /// Converts a tensor into a variable; if the tensor is already a variable, it is returned as is.
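    ///
    /// A minimal usage sketch (a CPU device is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor, Var};
    ///
    /// let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    /// let v = Var::from_tensor(&t)?;
    /// // The wrapped tensor now has `is_variable` set to `true`.
    /// assert!(v.is_variable());
    /// # Ok::<(), candle_core::Error>(())
    /// ```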
    pub fn from_tensor(t: &Tensor) -> Result<Self> {
        if t.is_variable() {
            Ok(Self(t.clone()))
        } else {
            let inner = t.make_var()?;
            Ok(Self(inner))
        }
    }

    pub fn rand_f64<S: Into<Shape>>(
        lo: f64,
        up: f64,
        s: S,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
        Ok(Self(inner))
    }

    pub fn randn_f64<S: Into<Shape>>(
        mean: f64,
        std: f64,
        s: S,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
        Ok(Self(inner))
    }

    pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
        lo: T,
        up: T,
        s: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::rand_impl(lo, up, s, device, true)?;
        Ok(Self(inner))
    }

    pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
        mean: T,
        std: T,
        s: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::randn_impl(mean, std, s, device, true)?;
        Ok(Self(inner))
    }

    /// Creates a new tensor on the specified device using the content and shape of the input.
    /// This is similar to `Tensor::new` but the resulting tensor is a variable.
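    ///
    /// A minimal usage sketch (a CPU device is assumed):
    ///
    /// ```rust
    /// use candle_core::{Device, Var};
    ///
    /// // Build a 2x2 variable directly from a nested array.
    /// let v = Var::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
    /// assert_eq!(v.dims(), &[2, 2]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```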
    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
        let shape = array.shape()?;
        let inner = Tensor::new_impl(array, shape, device, true)?;
        Ok(Self(inner))
    }

    pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
        data: Vec<D>,
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::from_vec_impl(data, shape, device, true)?;
        Ok(Self(inner))
    }

    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
        array: &[D],
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        let inner = Tensor::new_impl(array, shape.into(), device, true)?;
        Ok(Self(inner))
    }

    pub fn as_detached_tensor(&self) -> Tensor {
        self.0.detach()
    }

    pub fn as_tensor(&self) -> &Tensor {
        &self.0
    }

    /// Consumes this `Var` and returns the underlying tensor.
    pub fn into_inner(self) -> Tensor {
        self.0
    }

    /// Sets the content of the inner tensor; this does not require a mutable reference as interior
    /// mutability is used.
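    ///
    /// A minimal sketch of an in-place update (a CPU device is assumed):
    ///
    /// ```rust
    /// use candle_core::{DType, Device, Tensor, Var};
    ///
    /// let w = Var::zeros((2, 2), DType::F32, &Device::Cpu)?;
    /// let new_value = Tensor::ones((2, 2), DType::F32, &Device::Cpu)?;
    /// // Overwrite the variable's contents; the new value must not be derived from the
    /// // variable's own storage and the shapes must match.
    /// w.set(&new_value)?;
    /// # Ok::<(), candle_core::Error>(())
    /// ```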
    pub fn set(&self, src: &Tensor) -> Result<()> {
        // Refuse to overwrite the variable with a tensor that shares its storage.
        if self.same_storage(src) {
            let msg = "cannot set a variable to a tensor that is derived from its value";
            Err(Error::CannotSetVar { msg }.bt())?
        }
        let (mut dst, layout) = self.storage_mut_and_layout();
        if !layout.is_contiguous() {
            let msg = "cannot set a non-contiguous variable";
            Err(Error::CannotSetVar { msg }.bt())?
        }
        let (src, src_l) = src.storage_and_layout();
        if layout.shape() != src_l.shape() {
            Err(Error::ShapeMismatchBinaryOp {
                lhs: layout.shape().clone(),
                rhs: src_l.shape().clone(),
                op: "set",
            }
            .bt())?
        }
        // Copy the source data into the variable's contiguous storage.
        src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
        Ok(())
    }
}