// numr/autograd/ops/reduce/statistical.rs
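//! Backward (gradient) functions for the statistical reductions `var` and
//! `std`. For a reduction over `n` elements with a Bessel-style `correction`,
//! `d var / d x_i = 2 * (x_i - mean) / (n - correction)` and
//! `d std / d x_i = (x_i - mean) / ((n - correction) * std)`.
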
use crate::autograd::GradFn;
use crate::autograd::var::Var;
use crate::autograd::var_ops::var_mul;
use crate::error::Result;
use crate::ops::{BinaryOps, ReduceOps, ScalarOps, TensorOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::{Tensor, TensorId};
use std::sync::Arc;

use super::common::ensure_contiguous;

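/// Backward pass for a variance reduction over `dims` with the given
/// `correction`.
///
/// Uses `d var / d x_i = 2 * (x_i - mean) / (n - correction)`, where `n` is
/// the number of elements reduced. (The mean's own dependence on `x_i`
/// cancels because the centered values sum to zero.)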
pub struct VarBackward<R: Runtime> {
    input_id: TensorId,
    saved_input: Tensor<R>,
    dims: Vec<usize>,
    keepdim: bool,
    correction: usize,
    input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}

impl<R: Runtime> VarBackward<R> {
    pub fn new(
        input_id: TensorId,
        input: Tensor<R>,
        dims: &[usize],
        keepdim: bool,
        correction: usize,
        input_grad_fn: Option<Arc<dyn GradFn<R>>>,
    ) -> Self {
        Self {
            input_id,
            saved_input: input,
            dims: dims.to_vec(),
            keepdim,
            correction,
            input_grad_fn,
        }
    }
}

impl<R: Runtime> GradFn<R> for VarBackward<R>
where
    R::Client: TensorOps<R> + ScalarOps<R> + ReduceOps<R>,
{
    fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
        let client = R::default_client(grad_output.device());

        // Number of elements reduced over `dims`; assumes n > correction.
        let n: usize = self
            .dims
            .iter()
            .map(|&d| self.saved_input.shape()[d])
            .product();
        let n_minus_corr = (n - self.correction) as f64;

        let mean = client.mean(&self.saved_input, &self.dims, true)?;
        let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?);

        let centered = client.sub(&self.saved_input, &mean_broadcast)?;

        // d var / d x_i = 2 * (x_i - mean) / (n - correction).
        let scale = 2.0 / n_minus_corr;
        let grad_contrib = client.mul_scalar(&centered, scale)?;

        // If the reduction dropped dims, re-insert them so grad_output
        // broadcasts back to the input shape.
        let mut grad = grad_output.clone();
        if !self.keepdim {
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                grad = grad.unsqueeze(dim as isize)?;
            }
        }
        let grad_broadcast = ensure_contiguous(grad.broadcast_to(self.saved_input.shape())?);

        let grad_input = client.mul(&grad_broadcast, &grad_contrib)?;

        Ok(vec![Some(grad_input)])
    }

    fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
    where
        R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + ReduceOps<R>,
    {
        let client = R::default_client(grad_output.tensor().device());

        // Same math as `backward`, but routed through `Var` so the gradient
        // computation itself can be tracked by the autograd graph.
        let n: usize = self
            .dims
            .iter()
            .map(|&d| self.saved_input.shape()[d])
            .product();
        let n_minus_corr = (n - self.correction) as f64;

        let mean = client.mean(&self.saved_input, &self.dims, true)?;
        let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?);

        let centered = client.sub(&self.saved_input, &mean_broadcast)?;

        let scale = 2.0 / n_minus_corr;
        let grad_contrib = client.mul_scalar(&centered, scale)?;

        let mut grad_tensor = grad_output.tensor().clone();
        if !self.keepdim {
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                grad_tensor = grad_tensor.unsqueeze(dim as isize)?;
            }
        }
        let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(self.saved_input.shape())?);

        let grad_var = Var::new(grad_broadcast, grad_output.requires_grad());
        let contrib_var = Var::new(grad_contrib, false);

        let grad_input = var_mul(&grad_var, &contrib_var, &client)?;

        Ok(vec![Some(grad_input)])
    }

    fn inputs(&self) -> &[TensorId] {
        std::slice::from_ref(&self.input_id)
    }

    fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
        vec![self.input_grad_fn.clone()]
    }

    fn saved_tensors(&self) -> &[Tensor<R>] {
        std::slice::from_ref(&self.saved_input)
    }

    fn name(&self) -> &'static str {
        "VarBackward"
    }
}

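/// Backward pass for a standard-deviation reduction over `dims` with the
/// given `correction`.
///
/// Uses `d std / d x_i = (x_i - mean) / ((n - correction) * std)`, i.e. the
/// variance gradient divided by `2 * std`, reusing the saved forward output
/// for `std`.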
pub struct StdBackward<R: Runtime> {
    input_id: TensorId,
    saved_input: Tensor<R>,
    saved_output: Tensor<R>,
    dims: Vec<usize>,
    keepdim: bool,
    correction: usize,
    input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}

impl<R: Runtime> StdBackward<R> {
    pub fn new(
        input_id: TensorId,
        input: Tensor<R>,
        output: Tensor<R>,
        dims: &[usize],
        keepdim: bool,
        correction: usize,
        input_grad_fn: Option<Arc<dyn GradFn<R>>>,
    ) -> Self {
        Self {
            input_id,
            saved_input: input,
            saved_output: output,
            dims: dims.to_vec(),
            keepdim,
            correction,
            input_grad_fn,
        }
    }
}

impl<R: Runtime> GradFn<R> for StdBackward<R>
where
    R::Client: TensorOps<R> + ScalarOps<R> + ReduceOps<R>,
{
    fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
        let client = R::default_client(grad_output.device());

        // Number of elements reduced over `dims`; assumes n > correction.
        let n: usize = self
            .dims
            .iter()
            .map(|&d| self.saved_input.shape()[d])
            .product();
        let n_minus_corr = (n - self.correction) as f64;

        let mean = client.mean(&self.saved_input, &self.dims, true)?;
        let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?);

        // Reuse the saved forward output (std); re-insert reduced dims if
        // they were dropped so it broadcasts to the input shape.
        let std_for_broadcast = if self.keepdim {
            self.saved_output.clone()
        } else {
            let mut std_expanded = self.saved_output.clone();
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                std_expanded = std_expanded.unsqueeze(dim as isize)?;
            }
            std_expanded
        };
        let std_broadcast =
            ensure_contiguous(std_for_broadcast.broadcast_to(self.saved_input.shape())?);

        let centered = client.sub(&self.saved_input, &mean_broadcast)?;

        // d std / d x_i = (x_i - mean) / ((n - correction) * std).
        let denominator = client.mul_scalar(&std_broadcast, n_minus_corr)?;
        let grad_contrib = client.div(&centered, &denominator)?;

        let mut grad = grad_output.clone();
        if !self.keepdim {
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                grad = grad.unsqueeze(dim as isize)?;
            }
        }
        let grad_broadcast = ensure_contiguous(grad.broadcast_to(self.saved_input.shape())?);

        let grad_input = client.mul(&grad_broadcast, &grad_contrib)?;

        Ok(vec![Some(grad_input)])
    }

    fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
    where
        R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + ReduceOps<R>,
    {
        let client = R::default_client(grad_output.tensor().device());

        // Same math as `backward`, but routed through `Var` so the gradient
        // computation itself can be tracked by the autograd graph.
        let n: usize = self
            .dims
            .iter()
            .map(|&d| self.saved_input.shape()[d])
            .product();
        let n_minus_corr = (n - self.correction) as f64;

        let mean = client.mean(&self.saved_input, &self.dims, true)?;
        let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?);

        let std_for_broadcast = if self.keepdim {
            self.saved_output.clone()
        } else {
            let mut std_expanded = self.saved_output.clone();
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                std_expanded = std_expanded.unsqueeze(dim as isize)?;
            }
            std_expanded
        };
        let std_broadcast =
            ensure_contiguous(std_for_broadcast.broadcast_to(self.saved_input.shape())?);

        let centered = client.sub(&self.saved_input, &mean_broadcast)?;

        let denominator = client.mul_scalar(&std_broadcast, n_minus_corr)?;
        let grad_contrib = client.div(&centered, &denominator)?;

        let mut grad_tensor = grad_output.tensor().clone();
        if !self.keepdim {
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                grad_tensor = grad_tensor.unsqueeze(dim as isize)?;
            }
        }
        let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(self.saved_input.shape())?);

        let grad_var = Var::new(grad_broadcast, grad_output.requires_grad());
        let contrib_var = Var::new(grad_contrib, false);

        let grad_input = var_mul(&grad_var, &contrib_var, &client)?;

        Ok(vec![Some(grad_input)])
    }

    fn inputs(&self) -> &[TensorId] {
        std::slice::from_ref(&self.input_id)
    }

    fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
        vec![self.input_grad_fn.clone()]
    }

    fn saved_tensors(&self) -> &[Tensor<R>] {
        std::slice::from_ref(&self.saved_input)
    }

    fn name(&self) -> &'static str {
        "StdBackward"
    }
}
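
// Sketch test module: checks the analytic gradient formulas implemented
// above against central finite differences on plain f64 data. `variance`
// and `var_grad` are local helpers written for this check, not crate APIs.
#[cfg(test)]
mod grad_formula_checks {
    /// Variance over a flat slice with the given correction.
    fn variance(xs: &[f64], correction: usize) -> f64 {
        let n = xs.len() as f64;
        let mean = xs.iter().sum::<f64>() / n;
        xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - correction as f64)
    }

    /// Analytic gradient used by `VarBackward`:
    /// d var / d x_i = 2 * (x_i - mean) / (n - correction).
    fn var_grad(xs: &[f64], correction: usize) -> Vec<f64> {
        let n = xs.len() as f64;
        let mean = xs.iter().sum::<f64>() / n;
        xs.iter()
            .map(|x| 2.0 * (x - mean) / (n - correction as f64))
            .collect()
    }

    #[test]
    fn variance_grad_matches_finite_difference() {
        let xs = [1.0, 2.0, 4.0, 7.0];
        let grad = var_grad(&xs, 1);
        let eps = 1e-6;
        for i in 0..xs.len() {
            let (mut up, mut dn) = (xs, xs);
            up[i] += eps;
            dn[i] -= eps;
            let fd = (variance(&up, 1) - variance(&dn, 1)) / (2.0 * eps);
            assert!((fd - grad[i]).abs() < 1e-5, "i={i}: fd={fd} vs {}", grad[i]);
        }
    }

    #[test]
    fn std_grad_matches_finite_difference() {
        // Formula used by `StdBackward`:
        // d std / d x_i = (x_i - mean) / ((n - correction) * std).
        let xs = [1.0, 2.0, 4.0, 7.0];
        let n = xs.len() as f64;
        let mean = xs.iter().sum::<f64>() / n;
        let std = variance(&xs, 1).sqrt();
        let eps = 1e-6;
        for i in 0..xs.len() {
            let (mut up, mut dn) = (xs, xs);
            up[i] += eps;
            dn[i] -= eps;
            let fd = (variance(&up, 1).sqrt() - variance(&dn, 1).sqrt()) / (2.0 * eps);
            let analytic = (xs[i] - mean) / ((n - 1.0) * std);
            assert!((fd - analytic).abs() < 1e-5, "i={i}: fd={fd} vs {analytic}");
        }
    }
}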