1use crate::{DType, Device, Error, Result, Shape, Tensor};
6
7#[derive(Clone, Debug)]
10pub struct Var(Tensor);
11
12impl std::fmt::Display for Var {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 std::fmt::Display::fmt(&self.0, f)
15 }
16}
17
18impl std::ops::Deref for Var {
19 type Target = Tensor;
20
21 fn deref(&self) -> &Self::Target {
22 self.0.as_ref()
23 }
24}
25
26impl Var {
27 pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
28 let inner = Tensor::zeros_impl(shape, dtype, device, true)?;
29 Ok(Self(inner))
30 }
31
32 pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
33 let inner = Tensor::ones_impl(shape, dtype, device, true)?;
34 Ok(Self(inner))
35 }
36
37 pub fn from_tensor(t: &Tensor) -> Result<Self> {
39 if t.is_variable() {
40 Ok(Self(t.clone()))
41 } else {
42 let inner = t.make_var()?;
43 Ok(Self(inner))
44 }
45 }
46
47 pub fn rand_f64<S: Into<Shape>>(
48 lo: f64,
49 up: f64,
50 s: S,
51 dtype: DType,
52 device: &Device,
53 ) -> Result<Self> {
54 let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
55 Ok(Self(inner))
56 }
57
58 pub fn randn_f64<S: Into<Shape>>(
59 mean: f64,
60 std: f64,
61 s: S,
62 dtype: DType,
63 device: &Device,
64 ) -> Result<Self> {
65 let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
66 Ok(Self(inner))
67 }
68
69 pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
70 lo: T,
71 up: T,
72 s: S,
73 device: &Device,
74 ) -> Result<Self> {
75 let inner = Tensor::rand_impl(lo, up, s, device, true)?;
76 Ok(Self(inner))
77 }
78
79 pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
80 mean: T,
81 std: T,
82 s: S,
83 device: &Device,
84 ) -> Result<Self> {
85 let inner = Tensor::randn_impl(mean, std, s, device, true)?;
86 Ok(Self(inner))
87 }
88
89 pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
92 let shape = array.shape()?;
93 let inner = Tensor::new_impl(array, shape, device, true)?;
94 Ok(Self(inner))
95 }
96
97 pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
98 data: Vec<D>,
99 shape: S,
100 device: &Device,
101 ) -> Result<Self> {
102 let inner = Tensor::from_vec_impl(data, shape, device, true)?;
103 Ok(Self(inner))
104 }
105
106 pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
107 array: &[D],
108 shape: S,
109 device: &Device,
110 ) -> Result<Self> {
111 let inner = Tensor::new_impl(array, shape.into(), device, true)?;
112 Ok(Self(inner))
113 }
114
115 pub fn as_detached_tensor(&self) -> Tensor {
116 self.0.detach()
117 }
118
119 pub fn as_tensor(&self) -> &Tensor {
120 &self.0
121 }
122
123 pub fn into_inner(self) -> Tensor {
125 self.0
126 }
127
128 pub fn set(&self, src: &Tensor) -> Result<()> {
131 if self.same_storage(src) {
132 let msg = "cannot set a variable to a tensor that is derived from its value";
133 Err(Error::CannotSetVar { msg }.bt())?
134 }
135 let (mut dst, layout) = self.storage_mut_and_layout();
136 if !layout.is_contiguous() {
137 let msg = "cannot set a non-contiguous variable";
138 Err(Error::CannotSetVar { msg }.bt())?
139 }
140 let (src, src_l) = src.storage_and_layout();
141 if layout.shape() != src_l.shape() {
142 Err(Error::ShapeMismatchBinaryOp {
143 lhs: layout.shape().clone(),
144 rhs: src_l.shape().clone(),
145 op: "set",
146 }
147 .bt())?
148 }
149 src.copy_strided_src(&mut dst, layout.start_offset(), src_l)?;
150 Ok(())
151 }
152}