numr/autograd/ops/reduce/statistical.rs

//! Backward implementations for variance and standard deviation reductions

use crate::autograd::GradFn;
use crate::autograd::var::Var;
use crate::autograd::var_ops::var_mul;
use crate::error::Result;
use crate::ops::{BinaryOps, ReduceOps, ScalarOps, TensorOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::{Tensor, TensorId};
use std::sync::Arc;

use super::common::ensure_contiguous;

// ============================================================================
// VarBackward
// ============================================================================

/// Backward for variance reduction: z = var(a, dims, correction)
///
/// The gradient of variance is:
/// dL/da = dL/dz * 2 * (a - mean(a)) / (N - correction)
///
/// where N is the number of elements being reduced.
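///
/// Derivation sketch: the implicit dependence of mean(a) on each a_i drops
/// out, since
/// d/da_i sum_j (a_j - m)^2 = 2*(a_i - m) - (2/N) * sum_j (a_j - m)
/// and sum_j (a_j - m) = 0 by definition of m = mean(a), leaving
/// dvar/da_i = 2 * (a_i - m) / (N - correction).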
pub struct VarBackward<R: Runtime> {
    input_id: TensorId,
    saved_input: Tensor<R>,
    dims: Vec<usize>,
    keepdim: bool,
    correction: usize,
    input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}

impl<R: Runtime> VarBackward<R> {
    /// Create a new VarBackward
    pub fn new(
        input_id: TensorId,
        input: Tensor<R>,
        dims: &[usize],
        keepdim: bool,
        correction: usize,
        input_grad_fn: Option<Arc<dyn GradFn<R>>>,
    ) -> Self {
        Self {
            input_id,
            saved_input: input,
            dims: dims.to_vec(),
            keepdim,
            correction,
            input_grad_fn,
        }
    }
}

impl<R: Runtime> GradFn<R> for VarBackward<R>
where
    R::Client: TensorOps<R> + ScalarOps<R> + ReduceOps<R>,
{
    fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
        let client = R::default_client(grad_output.device());

        // N = number of elements covered by the reduction.
        let n: usize = self
            .dims
            .iter()
            .map(|&d| self.saved_input.shape()[d])
            .product();
        // Subtract in f64 so a correction >= n cannot underflow the usize.
        let n_minus_corr = n as f64 - self.correction as f64;

        let mean = client.mean(&self.saved_input, &self.dims, true)?;
        let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?);

        let centered = client.sub(&self.saved_input, &mean_broadcast)?;

        // dvar/da = 2 * (a - mean) / (N - correction); see the doc comment above.
        let scale = 2.0 / n_minus_corr;
        let grad_contrib = client.mul_scalar(&centered, scale)?;

        // Re-insert the reduced dims (in ascending order, so each index lands
        // at its final position) before broadcasting grad_output back to the
        // input shape.
        let mut grad = grad_output.clone();
        if !self.keepdim {
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                grad = grad.unsqueeze(dim as isize)?;
            }
        }
        let grad_broadcast = ensure_contiguous(grad.broadcast_to(self.saved_input.shape())?);

        let grad_input = client.mul(&grad_broadcast, &grad_contrib)?;

        Ok(vec![Some(grad_input)])
    }

    // Same computation as `backward`, but expressed on `Var`s via `var_mul`
    // so the multiply is recorded when `grad_output` itself requires grad.
    fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
    where
        R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + ReduceOps<R>,
    {
        let client = R::default_client(grad_output.tensor().device());

        let n: usize = self
            .dims
            .iter()
            .map(|&d| self.saved_input.shape()[d])
            .product();
        // As in `backward`: f64 subtraction avoids usize underflow.
        let n_minus_corr = n as f64 - self.correction as f64;

        let mean = client.mean(&self.saved_input, &self.dims, true)?;
        let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?);

        let centered = client.sub(&self.saved_input, &mean_broadcast)?;

        let scale = 2.0 / n_minus_corr;
        let grad_contrib = client.mul_scalar(&centered, scale)?;

        // Re-insert the reduced dims in ascending order, as in `backward`.
        let mut grad_tensor = grad_output.tensor().clone();
        if !self.keepdim {
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                grad_tensor = grad_tensor.unsqueeze(dim as isize)?;
            }
        }
        let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(self.saved_input.shape())?);

        let grad_var = Var::new(grad_broadcast, grad_output.requires_grad());
        let contrib_var = Var::new(grad_contrib, false);

        let grad_input = var_mul(&grad_var, &contrib_var, &client)?;

        Ok(vec![Some(grad_input)])
    }

    fn inputs(&self) -> &[TensorId] {
        std::slice::from_ref(&self.input_id)
    }

    fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
        vec![self.input_grad_fn.clone()]
    }

    fn saved_tensors(&self) -> &[Tensor<R>] {
        std::slice::from_ref(&self.saved_input)
    }

    fn name(&self) -> &'static str {
        "VarBackward"
    }
}

// ============================================================================
// StdBackward
// ============================================================================

/// Backward for standard deviation reduction: z = std(a, dims, correction)
///
/// std = sqrt(var), so by the chain rule:
/// dL/da = dL/dz * d(sqrt(var))/dvar * dvar/da
///       = dL/dz * 1/(2*std) * 2*(a - mean) / (N - correction)
///       = dL/dz * (a - mean) / ((N - correction) * std)
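///
/// Note: when std == 0 along the reduced dims (a constant slice), the
/// division below yields inf/NaN, mirroring the non-differentiability of
/// sqrt at 0. The forward output is saved and reused here so sqrt(var) is
/// not recomputed in the backward pass.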
pub struct StdBackward<R: Runtime> {
    input_id: TensorId,
    saved_input: Tensor<R>,
    saved_output: Tensor<R>,
    dims: Vec<usize>,
    keepdim: bool,
    correction: usize,
    input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}

impl<R: Runtime> StdBackward<R> {
    /// Create a new StdBackward
    pub fn new(
        input_id: TensorId,
        input: Tensor<R>,
        output: Tensor<R>,
        dims: &[usize],
        keepdim: bool,
        correction: usize,
        input_grad_fn: Option<Arc<dyn GradFn<R>>>,
    ) -> Self {
        Self {
            input_id,
            saved_input: input,
            saved_output: output,
            dims: dims.to_vec(),
            keepdim,
            correction,
            input_grad_fn,
        }
    }
}

impl<R: Runtime> GradFn<R> for StdBackward<R>
where
    R::Client: TensorOps<R> + ScalarOps<R> + ReduceOps<R>,
{
    fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
        let client = R::default_client(grad_output.device());

        // N = number of elements covered by the reduction.
        let n: usize = self
            .dims
            .iter()
            .map(|&d| self.saved_input.shape()[d])
            .product();
        // Subtract in f64 so a correction >= n cannot underflow the usize.
        let n_minus_corr = n as f64 - self.correction as f64;

        let mean = client.mean(&self.saved_input, &self.dims, true)?;
        let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?);

        // Reuse the saved forward output (std) instead of recomputing
        // sqrt(var); if keepdim was false, re-insert the reduced dims first.
        let std_for_broadcast = if self.keepdim {
            self.saved_output.clone()
        } else {
            let mut std_expanded = self.saved_output.clone();
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                std_expanded = std_expanded.unsqueeze(dim as isize)?;
            }
            std_expanded
        };
        let std_broadcast =
            ensure_contiguous(std_for_broadcast.broadcast_to(self.saved_input.shape())?);

        let centered = client.sub(&self.saved_input, &mean_broadcast)?;

        // dstd/da = (a - mean) / ((N - correction) * std); see the doc comment above.
        let denominator = client.mul_scalar(&std_broadcast, n_minus_corr)?;
        let grad_contrib = client.div(&centered, &denominator)?;

        // Re-insert the reduced dims (ascending) before broadcasting
        // grad_output back to the input shape.
        let mut grad = grad_output.clone();
        if !self.keepdim {
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                grad = grad.unsqueeze(dim as isize)?;
            }
        }
        let grad_broadcast = ensure_contiguous(grad.broadcast_to(self.saved_input.shape())?);

        let grad_input = client.mul(&grad_broadcast, &grad_contrib)?;

        Ok(vec![Some(grad_input)])
    }

    // Same computation as `backward`, but expressed on `Var`s via `var_mul`
    // so the multiply is recorded when `grad_output` itself requires grad.
    fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
    where
        R::Client: RuntimeClient<R> + TensorOps<R> + ScalarOps<R> + ReduceOps<R>,
    {
        let client = R::default_client(grad_output.tensor().device());

        let n: usize = self
            .dims
            .iter()
            .map(|&d| self.saved_input.shape()[d])
            .product();
        // As in `backward`: f64 subtraction avoids usize underflow.
        let n_minus_corr = n as f64 - self.correction as f64;

        let mean = client.mean(&self.saved_input, &self.dims, true)?;
        let mean_broadcast = ensure_contiguous(mean.broadcast_to(self.saved_input.shape())?);

        // Reuse the saved forward output (std), re-inserting reduced dims if needed.
        let std_for_broadcast = if self.keepdim {
            self.saved_output.clone()
        } else {
            let mut std_expanded = self.saved_output.clone();
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                std_expanded = std_expanded.unsqueeze(dim as isize)?;
            }
            std_expanded
        };
        let std_broadcast =
            ensure_contiguous(std_for_broadcast.broadcast_to(self.saved_input.shape())?);

        let centered = client.sub(&self.saved_input, &mean_broadcast)?;

        let denominator = client.mul_scalar(&std_broadcast, n_minus_corr)?;
        let grad_contrib = client.div(&centered, &denominator)?;

        // Re-insert the reduced dims in ascending order, as in `backward`.
        let mut grad_tensor = grad_output.tensor().clone();
        if !self.keepdim {
            let mut sorted_dims = self.dims.clone();
            sorted_dims.sort();
            for &dim in &sorted_dims {
                grad_tensor = grad_tensor.unsqueeze(dim as isize)?;
            }
        }
        let grad_broadcast = ensure_contiguous(grad_tensor.broadcast_to(self.saved_input.shape())?);

        let grad_var = Var::new(grad_broadcast, grad_output.requires_grad());
        let contrib_var = Var::new(grad_contrib, false);

        let grad_input = var_mul(&grad_var, &contrib_var, &client)?;

        Ok(vec![Some(grad_input)])
    }

    fn inputs(&self) -> &[TensorId] {
        std::slice::from_ref(&self.input_id)
    }

    fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
        vec![self.input_grad_fn.clone()]
    }

    fn saved_tensors(&self) -> &[Tensor<R>] {
        // Only the input is exposed here; `saved_output` is held in a separate
        // field and so cannot be part of this contiguous slice.
        std::slice::from_ref(&self.saved_input)
    }

    fn name(&self) -> &'static str {
        "StdBackward"
    }
}
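
// ----------------------------------------------------------------------------
// Formula checks (sketch)
// ----------------------------------------------------------------------------
// A minimal, self-contained finite-difference check of the gradient formulas
// implemented above. It deliberately uses plain f64 slices rather than numr
// tensors, so it exercises only the math, not the Tensor/Var machinery; treat
// it as an illustrative sketch rather than part of the crate's test suite.
#[cfg(test)]
mod formula_checks {
    /// Reference variance with the given correction (divisor N - correction).
    fn var(xs: &[f64], correction: usize) -> f64 {
        let n = xs.len() as f64;
        let mean = xs.iter().sum::<f64>() / n;
        xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - correction as f64)
    }

    #[test]
    fn var_and_std_gradients_match_finite_differences() {
        let xs = [0.5_f64, -1.25, 2.0, 0.75];
        let correction = 1usize;
        let n = xs.len() as f64;
        let mean = xs.iter().sum::<f64>() / n;
        let std = var(&xs, correction).sqrt();
        let eps = 1e-6;

        for i in 0..xs.len() {
            // Analytic gradients per VarBackward / StdBackward (grad_output = 1).
            let var_analytic = 2.0 * (xs[i] - mean) / (n - correction as f64);
            let std_analytic = (xs[i] - mean) / ((n - correction as f64) * std);

            // Central finite differences of the forward reductions.
            let (mut plus, mut minus) = (xs, xs);
            plus[i] += eps;
            minus[i] -= eps;
            let var_numeric = (var(&plus, correction) - var(&minus, correction)) / (2.0 * eps);
            let std_numeric =
                (var(&plus, correction).sqrt() - var(&minus, correction).sqrt()) / (2.0 * eps);

            assert!((var_analytic - var_numeric).abs() < 1e-5);
            assert!((std_analytic - std_numeric).abs() < 1e-5);
        }
    }
}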