
lumen_core/grad/backprop.rs

use std::collections::HashMap;
use crate::{FloatDType, Tensor, TensorId};

use super::{BinaryOp, GradStore, Op, ReduceOp, UnaryOp};

impl<T: FloatDType> Tensor<T> {

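    /// Runs reverse-mode automatic differentiation starting from this tensor.
    ///
    /// The gradient of `self` is seeded with ones, every node of the computation
    /// graph is then visited in reverse topological order, and the chain rule for
    /// its `Op` is applied, accumulating input gradients into the returned `GradStore`.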
    pub fn backward(&self) -> crate::Result<GradStore<T>> {
        let _guard = crate::NoGradGuard::new();

        let sorted_nodes = self.sorted_nodes();
        let mut grads = GradStore::new();
        grads.insert(self, self.ones_like()?);

        for node in sorted_nodes.iter() {
            match node.op() {
                None => {
                    assert!(node.is_leaf());
                    continue
                }
                Some(op) => {
                    let grad = grads
                        .remove(node)
                        .expect("grad not populated");

                    match op {
                        //=========================================================================================//
                        //           Binary
                        //=========================================================================================//
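                        // y = l + r and y = l - r: dy/dl = 1 and dy/dr = ±1, so the upstream
                        // gradient is added to (or subtracted from) each operand's accumulator.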
                        Op::Binary(lhs, rhs, BinaryOp::Add) => {
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&grad)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&grad)?;
                        }
                        Op::Binary(lhs, rhs, BinaryOp::Sub) => {
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&grad)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.sub_(&grad)?;
                        }
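                        // Product rule: d(l * r)/dl = r and d(l * r)/dr = l.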
                        Op::Binary(lhs, rhs, BinaryOp::Mul) => {
                            let lhs_grad = grad.mul(rhs)?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;

                            let rhs_grad = grad.mul(lhs)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }
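                        // Quotient rule: d(l / r)/dl = 1 / r and d(l / r)/dr = -l / r^2.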
                        Op::Binary(lhs, rhs, BinaryOp::Div) => {
                            let lhs_grad = grad.div(rhs)?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;

                            let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.sub_(&rhs_grad)?;
                        }
                        Op::Binary(lhs, rhs, BinaryOp::Minimum)
                        | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
                            let mask_lhs = (*node).eq(lhs)?.cast()?;
                            let mask_rhs = (*node).eq(rhs)?.cast()?;

                            // If both masks are 1 on the same point, we want to scale the
                            // gradient by 0.5 rather than 1.
                            let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + T::one()))?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;

                            let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + T::one()))?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }

                        //=========================================================================================//
                        //           BinaryScalarRhs
                        //=========================================================================================//
                        Op::BinaryScalarRhs(lhs, _, BinaryOp::Add) => {
                            // y = x + c => dy/dx = 1
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&grad)?;
                        }
                        Op::BinaryScalarRhs(lhs, _, BinaryOp::Sub) => {
                            // y = x - c => dy/dx = 1
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&grad)?;
                        }
                        Op::BinaryScalarRhs(lhs, rhs, BinaryOp::Mul) => {
                            // y = x * c => dy/dx = c
                            let lhs_grad = grad.mul_scalar(*rhs)?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;
                        }
                        Op::BinaryScalarRhs(lhs, rhs, BinaryOp::Div) => {
                            // y = x / c => dy/dx = 1/c
                            let lhs_grad = grad.div_scalar(*rhs)?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;
                        }
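                        // y = max(x, c) or y = min(x, c): the gradient flows to x only where the
                        // result equals x; ties with the constant are halved via the
                        // (mask_rhs + 1) divisor, as in the tensor/tensor case above.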
                        Op::BinaryScalarRhs(lhs, rhs, BinaryOp::Maximum) |
                        Op::BinaryScalarRhs(lhs, rhs, BinaryOp::Minimum) => {
                            let mask_lhs = (*node).eq(lhs)?.cast()?;
                            let mask_rhs = (*node).eq(*rhs)?.cast()?;
                            let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + T::one()))?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;
                        }

                        //=========================================================================================//
                        //           BinaryScalarLhs
                        //=========================================================================================//
                        Op::BinaryScalarLhs(_, rhs, BinaryOp::Add) => {
                            // y = c + x => dy/dx = 1
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&grad)?;
                        }
                        Op::BinaryScalarLhs(_, rhs, BinaryOp::Sub) => {
                            // y = c - x => dy/dx = -1
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.sub_(&grad)?;
                        }
                        Op::BinaryScalarLhs(lhs, rhs, BinaryOp::Mul) => {
                            // y = c * x => dy/dx = c
                            let rhs_grad = grad.mul_scalar(*lhs)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }
                        Op::BinaryScalarLhs(lhs, rhs, BinaryOp::Div) => {
                            // y = c / x = c * x^(-1)
                            // dy/dx = -c * x^(-2) = -c / (x^2)
                            // grad_input = grad * (-c / x^2)
                            let numerator = grad.mul_scalar(-*lhs)?;
                            let denominator = rhs.mul(rhs)?;
                            let rhs_grad = numerator.div(&denominator)?;

                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }
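                        // y = max(c, x) or y = min(c, x): same masking scheme as above, with the
                        // tensor operand on the right-hand side.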
                        Op::BinaryScalarLhs(lhs, rhs, BinaryOp::Maximum) |
                        Op::BinaryScalarLhs(lhs, rhs, BinaryOp::Minimum) => {
                            let mask_lhs = (*node).eq(*lhs)?.cast()?;
                            let mask_rhs = (*node).eq(rhs)?.cast()?;
                            let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + T::one()))?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }

                        //=========================================================================================//
                        //           Unary
                        //=========================================================================================//
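                        // Piecewise-constant ops have zero gradient almost everywhere and are
                        // rejected rather than silently producing zeros.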
                        Op::Unary(_, UnaryOp::Ceil) => Err(crate::Error::BackwardNotSupported("ceil"))?,
                        Op::Unary(_, UnaryOp::Floor) => Err(crate::Error::BackwardNotSupported("floor"))?,
                        Op::Unary(_, UnaryOp::Round) => Err(crate::Error::BackwardNotSupported("round"))?,
                        Op::Unary(_, UnaryOp::Sign) => Err(crate::Error::BackwardNotSupported("sign"))?,
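                        // y = exp(x): dy/dx = exp(x) = y, so the forward output is reused.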
                        Op::Unary(arg, UnaryOp::Exp) => {
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&(&grad * *node))?;
                        }
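                        // y = ln(x): dy/dx = 1 / x.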
                        Op::Unary(arg, UnaryOp::Ln) => {
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&(grad / arg))?;
                        }
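                        // y = sin(x): dy/dx = cos(x).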
                        Op::Unary(arg, UnaryOp::Sin) => {
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&(&grad * arg.cos()?))?;
                        }
                        Op::Unary(arg, UnaryOp::Cos) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // y = cos(x) -> y' = -sin(x) -> grad = grad * -sin(x) -> grad -= grad * sin(x)
                            sum_grad.sub_(&(&grad * arg.sin()?))?;
                        }
                        Op::Unary(arg, UnaryOp::Tanh) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let minus_dtanh = node.sqr()? - T::one();
                            // y = tanh(x) -> y' = 1 - tanh^2(x) = 1 - y^2 = -(y^2 - 1)
                            sum_grad.sub_(&(&grad * &minus_dtanh))?;
                        }
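                        // y = x^2: dy/dx = 2x.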
                        Op::Unary(arg, UnaryOp::Sqr) => {
                            let arg_grad = arg.mul(&grad)?.affine(T::two(), T::zero())?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }
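                        // y = sqrt(x): dy/dx = 1 / (2 * sqrt(x)) = 1 / (2y).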
                        Op::Unary(arg, UnaryOp::Sqrt) => {
                            let arg_grad = grad.div(*node)?.affine(T::half(), T::zero())?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }
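                        // y = |x|: dy/dx = sign(x), taken as +1 at x = 0.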
                        Op::Unary(arg, UnaryOp::Abs) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let ones = arg.ones_like()?;
                            let abs_grad = arg.ge(&arg.zeros_like()?)?.if_else(&ones, ones.neg()?)?;
                            sum_grad.add_(&(&grad * abs_grad))?;
                        }
                        Op::Unary(arg, UnaryOp::Neg) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // dy/dx = -1 -> sub(grad)
                            sum_grad.sub_(&grad)?;
                        }
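                        // y = 1 / x: dy/dx = -1 / x^2, hence the subtraction below.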
                        Op::Unary(arg, UnaryOp::Recip) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let grad = grad / arg.sqr()?;
                            sum_grad.sub_(&grad)?;
                        }
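                        // Tanh-approximated GELU: y ~= 0.5 * x * (1 + tanh(0.797885 x + 0.0356774 x^3)).
                        // Its derivative is 0.5 * tanh(u) + (0.398942 x + 0.0535161 x^3) * (1 - tanh^2(u)) + 0.5,
                        // which is what the constants below implement.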
                        Op::Unary(arg, UnaryOp::Gelu) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let cube = arg.pow(T::from_f64(3.))?;
                            let tanh = (&cube * T::from_f64(0.0356774) + (arg * T::from_f64(0.797885))).tanh()?;
                            let gelu_grad =
                                &tanh / T::two()
                                + (cube * T::from_f64(0.0535161) + arg * T::from_f64(0.398942)) * (tanh.pow(T::two())?.neg()? + T::one())
                                + T::half();
                            sum_grad.add_(&(&grad * gelu_grad))?;
                        }
                        Op::Unary(arg, UnaryOp::Erf) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
                            let erf_grad = arg.sqr()?.neg()?.exp()? * (T::two() / T::pi().sqrt());
                            sum_grad.add_(&(&grad * erf_grad))?;
                        }
                        Op::Unary(arg, UnaryOp::GeluErf) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
                            let neg_half_square = arg.sqr()?.neg()? / T::two();
                            let scaled_exp_arg = T::from_f64(0.398942) * neg_half_square.exp()? * arg;
                            let arg_scaled_sqrt = arg / T::two().sqrt();
                            let erf_scaled_sqrt = arg_scaled_sqrt.erf()? / T::two();
                            let gelu_erf_grad = scaled_exp_arg + erf_scaled_sqrt + T::half();
                            sum_grad.add_(&(&grad * gelu_erf_grad))?;
                        }
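                        // ReLU: the gradient is 1 where x >= 0 and 0 elsewhere.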
                        Op::Unary(arg, UnaryOp::Relu) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let relu_grad = arg.ge(&arg.zeros_like()?)?.cast::<T>()?;
                            sum_grad.add_(&(&grad * relu_grad))?;
                        }
                        Op::Unary(arg, UnaryOp::Silu) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
                            let sigmoid_arg = (arg.neg()?.exp()? + T::one()).recip()?;
                            let silu_grad = &sigmoid_arg * (T::one() - *node) + *node;
                            sum_grad.add_(&(&grad * silu_grad))?;
                        }
                        Op::Unary(arg, UnaryOp::Sigmoid) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // y = sigmoid(x) = *node
                            let local_deriv = *node * (T::one() - *node);
                            sum_grad.add_(&(&grad * local_deriv))?;
                        }
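                        // LeakyReLU: the derivative is 1 for x >= 0 and `negative_slope` otherwise,
                        // built here from the 0/1 mask and its complement.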
                        Op::Unary(arg, UnaryOp::LeakyRelu(negative_slope)) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let mask = arg.ge(&arg.zeros_like()?)?.cast::<T>()?;

                            let ones = mask.ones_like()?;
                            let inv_mask = ones.sub(&mask)?;

                            let slope_part = inv_mask.mul_scalar(*negative_slope)?;
                            let local_deriv = mask.add(&slope_part)?;

                            sum_grad.add_(&(&grad * local_deriv))?;
                        }

                        //=========================================================================================//
                        //           Matmul
                        //=========================================================================================//
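                        // C = A @ B: dL/dA = dL/dC @ B^T and dL/dB = A^T @ dL/dC.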
                        Op::Matmul(lhs, rhs) => {
                            let lhs_grad = grad.matmul(&rhs.transpose_last()?)?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;

                            let rhs_grad = lhs.transpose_last()?.matmul(&grad)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }

                        //=========================================================================================//
                        //           Pow
                        //=========================================================================================//
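                        // y = x^e with a constant exponent e: dy/dx = e * x^(e - 1).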
                        Op::Pow(arg, e) => {
                            let arg_grad = &(grad * arg.pow(*e - T::one())?) * *e;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        //           Reduce
                        //=========================================================================================//
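                        // Sum reduction: every input element contributes with weight 1, so the
                        // output gradient is broadcast back to the input shape.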
                        Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
                            let grad = Self::broadcast_back(arg, &grad, reduced_dims)?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&grad)?;
                        }
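                        // Max/Min reductions: only the elements equal to the reduced result
                        // receive gradient (ties each receive the full gradient).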
                        Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
                            let node = Self::broadcast_back(arg, node, reduced_dims)?;
                            let grad = Self::broadcast_back(arg, &grad, reduced_dims)?;
                            let grad = node.eq(arg)?.cast()?.mul(&grad)?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&grad.broadcast_as(sum_grad.dims())?)?;
                        }
                        Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
                            let node = Self::broadcast_back(arg, node, reduced_dims)?;
                            let grad = Self::broadcast_back(arg, &grad, reduced_dims)?;
                            let grad = node.eq(arg)?.cast()?.mul(&grad)?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&grad.broadcast_as(sum_grad.dims())?)?;
                        }
                        Op::Reduce(arg, ReduceOp::Mean, reduced_dims) => {
                            let grad_output = Self::broadcast_back(arg, &grad, reduced_dims)?;
                            let n = arg.element_count() / node.element_count();

                            // grad_input = grad_output / n
                            let grad_input = grad_output / T::from_usize(n);

                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&grad_input)?;
                        }

                        //=========================================================================================//
                        //           Broadcast
                        //=========================================================================================//
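                        // Broadcast backward: sum the gradient over every expanded dimension,
                        // then squeeze away the leading dimensions that were added.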
                        Op::Broadcast(arg) => {
                            let arg_dims = arg.dims();
                            let node_dims = node.dims();
                            let left_dims = node_dims.len() - arg_dims.len();
                            let mut sum_dims: Vec<usize> = (0..left_dims).collect();
                            for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
                                .iter()
                                .zip(arg_dims.iter())
                                .enumerate()
                            {
                                if node_dim != arg_dim {
                                    sum_dims.push(dim + left_dims)
                                }
                            }

                            let mut arg_grad = grad;
                            for &dim in sum_dims.iter() {
                                arg_grad = arg_grad.sum_keepdim(dim)?;
                            }

                            for _i in 0..left_dims {
                                arg_grad = arg_grad.squeeze(0)?
                            }
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad.broadcast_as(sum_grad.dims())?)?;
                        }

                        //=========================================================================================//
                        //           Narrow
                        //=========================================================================================//
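                        // Narrow backward: place the slice gradient back at its original offset
                        // along `dim`, zero-padding on either side to recover the input shape.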
                        &Op::Narrow(ref arg, dim, start_idx, len) => {
                            let arg_dims = arg.dims();
                            let left_pad = if start_idx == 0 {
                                None
                            } else {
                                let mut dims = arg_dims.to_vec();
                                dims[dim] = start_idx;
                                Some(Tensor::zeros(dims)?)
                            };
                            let right_pad = arg_dims[dim] - start_idx - len;
                            let right_pad = if right_pad == 0 {
                                None
                            } else {
                                let mut dims = arg_dims.to_vec();
                                dims[dim] = right_pad;
                                Some(Tensor::zeros(dims)?)
                            };
                            let arg_grad = match (left_pad, right_pad) {
                                (None, None) => grad,
                                (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
                                (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
                                (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
                            };
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        //           Slice
                        //=========================================================================================//
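                        // Strided slice backward: for step > 1 the gradient is first dilated with
                        // step - 1 zeros between consecutive elements along `dim`, then zero-padded
                        // like Narrow to recover the full input shape.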
                        &Op::Slice(ref arg, dim, start, _end, step) => {
                            let arg_dims = arg.dims();

                            let body_grad = if step == 1 {
                                // Narrow
                                grad
                            } else {
                                let grad_len = grad.dims()[dim];
                                let span_len = if grad_len > 0 { (grad_len - 1) * step + 1 } else { 0 };

                                let mut unsqueezed_shape = grad.dims().to_vec();
                                unsqueezed_shape.insert(dim + 1, 1);
                                let grad_unsqueezed = grad.reshape(&unsqueezed_shape)?;

                                let mut zeros_shape = unsqueezed_shape.clone();
                                zeros_shape[dim + 1] = step - 1;
                                let zeros_gap = Tensor::zeros(zeros_shape)?;

                                let dilated = Tensor::cat(&[&grad_unsqueezed, &zeros_gap], dim + 1)?;

                                let mut flattened_shape = grad.dims().to_vec();
                                flattened_shape[dim] = grad_len * step;
                                let flattened = dilated.reshape(flattened_shape)?;

                                flattened.narrow(dim, 0, span_len)?
                            };

                            let body_len = body_grad.dims()[dim];

                            let left_pad = if start == 0 {
                                None
                            } else {
                                let mut dims = arg_dims.to_vec();
                                dims[dim] = start;
                                Some(Tensor::zeros(dims)?)
                            };

                            let right_pad_len = arg_dims[dim] - start - body_len;
                            let right_pad = if right_pad_len == 0 {
                                None
                            } else {
                                let mut dims = arg_dims.to_vec();
                                dims[dim] = right_pad_len;
                                Some(Tensor::zeros(dims)?)
                            };

                            let arg_grad = match (left_pad, right_pad) {
                                (None, None) => body_grad,
                                (Some(l), None) => Tensor::cat(&[&l, &body_grad], dim)?,
                                (None, Some(r)) => Tensor::cat(&[&body_grad, &r], dim)?,
                                (Some(l), Some(r)) => Tensor::cat(&[&l, &body_grad, &r], dim)?,
                            };

                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        //           Reshape
                        //=========================================================================================//
                        Op::Reshape(arg) => {
                            let arg_grad = grad.reshape(arg.dims())?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        //           Transpose
                        //=========================================================================================//
                        Op::Transpose(arg, dim1, dim2) => {
                            let arg_grad = grad.transpose(*dim1, *dim2)?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        //           Permute
                        //=========================================================================================//
                        Op::Permute(arg, dims) => {
                            let mut inv_dims = vec![0; dims.len()];
                            for (i, &dim_idx) in dims.iter().enumerate() {
                                inv_dims[dim_idx] = i
                            }
                            let arg_grad = grad.permute(inv_dims)?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        //           Cat
                        //=========================================================================================//
                        Op::Cat(args, dim) => {
                            let mut start_idx = 0;
                            for arg in args {
                                let len = arg.dims()[*dim];
                                let arg_grad = grad.narrow(*dim, start_idx, len)?;
                                let sum_grad = grads.or_insert(arg)?;
                                sum_grad.add_(&arg_grad)?;
                                start_idx += len;
                            }
                        }

                        //=========================================================================================//
                        //           Copy
                        //=========================================================================================//
                        Op::Copy(arg) => {
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&grad)?;
                        }

                        //=========================================================================================//
                        //           IfElse
                        //=========================================================================================//
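                        // if_else backward: the gradient flows only through the branch that was
                        // selected, i.e. it is masked by the condition for the true value and by
                        // its complement for the false value.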
                        Op::IfElse(mask, tv, fv) => {
                            if let Some(tv) = tv {
                                let masked_grad = mask.if_else(&grad, T::zero())?;
                                let sum_grad = grads.or_insert(tv)?;
                                sum_grad.add_(&masked_grad)?;
                            }

                            if let Some(fv) = fv {
                                let masked_grad = mask.if_else(T::zero(), &grad)?;
                                let sum_grad = grads.or_insert(fv)?;
                                sum_grad.add_(&masked_grad)?;
                            }
                        }

                        //=========================================================================================//
                        //           IndexSelect
                        //=========================================================================================//
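                        // index_select backward: scatter the gradient back into the selected rows
                        // with index_add along the same dimension.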
                        Op::IndexSelect(arg, indexes, dim) => {
                            let sum_grad = grads.or_insert(arg)?;
                            *sum_grad = sum_grad.index_add(indexes.clone(), &grad, *dim)?;
                        }

                        //=========================================================================================//
                        //           IndexAdd
                        //=========================================================================================//
                        Op::IndexAdd(init, indexes, src, dim) => {
                            let init_sum_grad = grads.or_insert(init)?;
                            *init_sum_grad = init_sum_grad.add(&grad)?;

                            let src_grad = grad.index_select(indexes.clone(), *dim)?;
                            let src_sum_grad = grads.or_insert(src)?;
                            *src_sum_grad = src_sum_grad.add(&src_grad)?;
                        }

                        //=========================================================================================//
                        //           ScatterAdd
                        //=========================================================================================//
                        #[allow(unused)]
                        Op::ScatterAdd(init, indexes, src, dim) => {
                            unimplemented!()
                        }

                        //=========================================================================================//
                        //           Gather
                        //=========================================================================================//
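                        // gather backward: scatter_add the gradient back to the gathered positions
                        // along the same dimension.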
                        Op::Gather(arg, indexes, dim) => {
                            let arg_grad = grads.or_insert(arg)?;
                            *arg_grad = arg_grad.scatter_add(indexes.clone(), &grad, *dim)?;
                        }
                    }
                }
            }
        }

        Ok(grads)
    }

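    /// Returns the nodes of the computation graph that contribute to `self` and
    /// (transitively) depend on a leaf tensor, with `self` first and its inputs
    /// following in reverse topological order.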
    pub fn sorted_nodes(&self) -> Vec<&Tensor<T>> {
        // The vec of sorted nodes is passed as an owned value rather than a mutable reference
        // to get around some lifetime limitations.
        fn walk<'a, T: FloatDType>(
            node: &'a Tensor<T>,
            nodes: Vec<&'a Tensor<T>>,
            already_seen: &mut HashMap<TensorId, bool>,
        ) -> (bool, Vec<&'a Tensor<T>>) {
            if let Some(&tg) = already_seen.get(&node.id()) {
                return (tg, nodes);
            }
            let mut track_grad = false;
            let mut nodes = if node.is_leaf() {
                track_grad = true;
                nodes
            } else if node.dtype().is_int() {
                nodes
            } else if let Some(op) = node.op() {
                match op {
                    | Op::Binary(lhs, rhs, _)
                    | Op::Matmul(lhs, rhs)
                    | Op::IfElse(_, Some(lhs), Some(rhs))
                    | Op::IndexAdd(lhs, _, rhs, _)
                    | Op::ScatterAdd(lhs, _, rhs, _)
                    => {
                        let (tg, nodes) = walk(lhs, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(rhs, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }

                    | Op::Unary(_node, UnaryOp::Ceil)
                    | Op::Unary(_node, UnaryOp::Floor)
                    | Op::Unary(_node, UnaryOp::Round)
                    | Op::Unary(_node, UnaryOp::Sign) => nodes,

                    | Op::IfElse(_, None, None) => nodes,

                    | Op::BinaryScalarLhs(_, node, _)
                    | Op::BinaryScalarRhs(node, _, _)
                    | Op::Broadcast(node)
                    | Op::Unary(node, _)
                    | Op::Pow(node, _)
                    | Op::Reduce(node, _, _)
                    | Op::Narrow(node, _, _, _)
                    | Op::Slice(node, _, _, _, _)
                    | Op::Reshape(node)
                    | Op::Transpose(node, _, _)
                    | Op::Permute(node, _)
                    | Op::Copy(node)
                    | Op::Gather(node, _, _)
                    | Op::IndexSelect(node, _, _)
                    | Op::IfElse(_, Some(node), None)
                    | Op::IfElse(_, None, Some(node)) => {
                        let (tg, nodes) = walk(node, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }

                    | Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
                        let (tg, nodes) = walk(arg, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }),
                }
            } else {
                nodes
            };
            already_seen.insert(node.id(), track_grad);
            if track_grad {
                nodes.push(node);
            }
            (track_grad, nodes)
        }
        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
        nodes.reverse();
        nodes
    }

    fn broadcast_back(arg: &Tensor<T>, node: &Tensor<T>, reduced_dims: &[usize]) -> crate::Result<Tensor<T>> {
        if arg.rank() == node.rank() {
            node.broadcast_as(arg.shape())
        } else {
            node.reshape(reduced_dims)?.broadcast_as(arg.shape())
        }
    }
}