lumen_core/grad/backprop.rs
use std::collections::HashMap;
use crate::{FloatDType, Tensor, TensorId};

use super::{BinaryOp, GradStore, Op, ReduceOp, UnaryOp};

impl<T: FloatDType> Tensor<T> {

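    /// Runs reverse-mode automatic differentiation starting from this tensor.
    ///
    /// The graph is walked from this output towards the leaves (see `sorted_nodes`),
    /// accumulating the gradient of every tracked tensor into the returned `GradStore`.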
    pub fn backward(&self) -> crate::Result<GradStore<T>> {
        let _guard = crate::NoGradGuard::new();

        let sorted_nodes = self.sorted_nodes();
        let mut grads = GradStore::new();
        grads.insert(self, self.ones_like()?);

        for node in sorted_nodes.iter() {
            match node.op() {
                None => {
                    assert!(node.is_leaf());
                    continue
                }
                Some(op) => {
                    let grad = grads
                        .remove(node)
                        .expect("grad not populated");

                    match op {
                        //=========================================================================================//
                        // Binary
                        //=========================================================================================//
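                        // y = lhs + rhs  =>  dy/dlhs = 1, dy/drhs = 1: the incoming gradient
                        // flows through to both operands unchanged.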
                        Op::Binary(lhs, rhs, BinaryOp::Add) => {
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&grad)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&grad)?;
                        }
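                        // y = lhs - rhs  =>  dy/dlhs = 1, dy/drhs = -1.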
                        Op::Binary(lhs, rhs, BinaryOp::Sub) => {
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&grad)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.sub_(&grad)?;
                        }
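                        // y = lhs * rhs  =>  dy/dlhs = rhs, dy/drhs = lhs.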
                        Op::Binary(lhs, rhs, BinaryOp::Mul) => {
                            let lhs_grad = grad.mul(rhs)?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;

                            let rhs_grad = grad.mul(lhs)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }
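                        // y = lhs / rhs  =>  dy/dlhs = 1 / rhs, dy/drhs = -lhs / rhs^2.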
                        Op::Binary(lhs, rhs, BinaryOp::Div) => {
                            let lhs_grad = grad.div(rhs)?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;

                            let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.sub_(&rhs_grad)?;
                        }
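                        // For min/max the gradient is routed to whichever operand produced the
                        // result, using equality masks against the forward output.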
                        Op::Binary(lhs, rhs, BinaryOp::Minimum)
                        | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
                            let mask_lhs = (*node).eq(lhs)?.cast()?;
                            let mask_rhs = (*node).eq(rhs)?.cast()?;

                            // If both masks are 1 on the same point, we want to scale the
                            // gradient by 0.5 rather than 1.
                            let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + T::one()))?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;

                            let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + T::one()))?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }

                        //=========================================================================================//
                        // BinaryScalarRhs
                        //=========================================================================================//
                        Op::BinaryScalarRhs(lhs, _, BinaryOp::Add) => {
                            // y = x + c => dy/dx = 1
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&grad)?;
                        }
                        Op::BinaryScalarRhs(lhs, _, BinaryOp::Sub) => {
                            // y = x - c => dy/dx = 1
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&grad)?;
                        }
                        Op::BinaryScalarRhs(lhs, rhs, BinaryOp::Mul) => {
                            // y = x * c => dy/dx = c
                            let lhs_grad = grad.mul_scalar(*rhs)?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;
                        }
                        Op::BinaryScalarRhs(lhs, rhs, BinaryOp::Div) => {
                            // y = x / c => dy/dx = 1/c
                            let lhs_grad = grad.div_scalar(*rhs)?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;
                        }
                        Op::BinaryScalarRhs(lhs, rhs, BinaryOp::Maximum)
                        | Op::BinaryScalarRhs(lhs, rhs, BinaryOp::Minimum) => {
                            let mask_lhs = (*node).eq(lhs)?.cast()?;
                            let mask_rhs = (*node).eq(*rhs)?.cast()?;
                            let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + T::one()))?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;
                        }

                        //=========================================================================================//
                        // BinaryScalarLhs
                        //=========================================================================================//
                        Op::BinaryScalarLhs(_, rhs, BinaryOp::Add) => {
                            // y = c + x => dy/dx = 1
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&grad)?;
                        }
                        Op::BinaryScalarLhs(_, rhs, BinaryOp::Sub) => {
                            // y = c - x => dy/dx = -1
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.sub_(&grad)?;
                        }
                        Op::BinaryScalarLhs(lhs, rhs, BinaryOp::Mul) => {
                            // y = c * x => dy/dx = c
                            let rhs_grad = grad.mul_scalar(*lhs)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }
                        Op::BinaryScalarLhs(lhs, rhs, BinaryOp::Div) => {
                            // y = c / x = c * x^(-1)
                            // dy/dx = -c * x^(-2) = -c / (x^2)
                            // grad_input = grad * (-c / x^2)
                            let numerator = grad.mul_scalar(-*lhs)?;
                            let denominator = rhs.mul(rhs)?;
                            let rhs_grad = numerator.div(&denominator)?;

                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }
                        Op::BinaryScalarLhs(lhs, rhs, BinaryOp::Maximum)
                        | Op::BinaryScalarLhs(lhs, rhs, BinaryOp::Minimum) => {
                            let mask_lhs = (*node).eq(*lhs)?.cast()?;
                            let mask_rhs = (*node).eq(rhs)?.cast()?;
                            let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + T::one()))?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }

                        //=========================================================================================//
                        // Unary
                        //=========================================================================================//
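                        // Ceil/floor/round/sign are piecewise constant (zero gradient almost
                        // everywhere), so backprop through them is not supported.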
                        Op::Unary(_, UnaryOp::Ceil) => Err(crate::Error::BackwardNotSupported("ceil"))?,
                        Op::Unary(_, UnaryOp::Floor) => Err(crate::Error::BackwardNotSupported("floor"))?,
                        Op::Unary(_, UnaryOp::Round) => Err(crate::Error::BackwardNotSupported("round"))?,
                        Op::Unary(_, UnaryOp::Sign) => Err(crate::Error::BackwardNotSupported("sign"))?,
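                        // y = e^x => dy/dx = e^x = y, so the forward output is reused.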
                        Op::Unary(arg, UnaryOp::Exp) => {
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&(&grad * *node))?;
                        }
                        Op::Unary(arg, UnaryOp::Ln) => {
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&(grad / arg))?;
                        }
                        Op::Unary(arg, UnaryOp::Sin) => {
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&(&grad * arg.cos()?))?;
                        }
                        Op::Unary(arg, UnaryOp::Cos) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // y = cos(x) -> y' = -sin(x) -> grad = grad * -sin(x) -> grad -= grad * sin(x)
                            sum_grad.sub_(&(&grad * arg.sin()?))?;
                        }
                        Op::Unary(arg, UnaryOp::Tanh) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let minus_dtanh = node.sqr()? - T::one();
                            // y = tanh(x) -> y' = 1 - tanh^2(x) = 1 - y^2 = -(y^2 - 1)
                            sum_grad.sub_(&(&grad * &minus_dtanh))?;
                        }
                        Op::Unary(arg, UnaryOp::Sqr) => {
                            let arg_grad = arg.mul(&grad)?.affine(T::two(), T::zero())?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }
                        Op::Unary(arg, UnaryOp::Sqrt) => {
                            let arg_grad = grad.div(*node)?.affine(T::half(), T::zero())?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }
                        Op::Unary(arg, UnaryOp::Abs) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let ones = arg.ones_like()?;
                            let abs_grad = arg.ge(&arg.zeros_like()?)?.if_else(&ones, ones.neg()?)?;
                            sum_grad.add_(&(&grad * abs_grad))?;
                        }
                        Op::Unary(arg, UnaryOp::Neg) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // dy/dx = -1 -> sub(grad)
                            sum_grad.sub_(&grad)?;
                        }
                        Op::Unary(arg, UnaryOp::Recip) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let grad = grad / arg.sqr()?;
                            sum_grad.sub_(&grad)?;
                        }
                        Op::Unary(arg, UnaryOp::Gelu) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let cube = arg.pow(T::from_f64(3.))?;
                            let tanh = (&cube * T::from_f64(0.0356774) + (arg * T::from_f64(0.797885))).tanh()?;
                            let gelu_grad = &tanh / T::two()
                                + (cube * T::from_f64(0.0535161) + arg * T::from_f64(0.398942))
                                    * (tanh.pow(T::two())?.neg()? + T::one())
                                + T::half();
                            sum_grad.add_(&(&grad * gelu_grad))?;
                        }
                        Op::Unary(arg, UnaryOp::Erf) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
                            let erf_grad = arg.sqr()?.neg()?.exp()? * (T::two() / T::pi().sqrt());
                            sum_grad.add_(&(&grad * erf_grad))?;
                        }
                        Op::Unary(arg, UnaryOp::GeluErf) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
                            let neg_half_square = arg.sqr()?.neg()? / T::two();
                            let scaled_exp_arg = T::from_f64(0.398942) * neg_half_square.exp()? * arg;
                            let arg_scaled_sqrt = arg / T::two().sqrt();
                            let erf_scaled_sqrt = arg_scaled_sqrt.erf()? / T::two();
                            let gelu_erf_grad = scaled_exp_arg + erf_scaled_sqrt + T::half();
                            sum_grad.add_(&(&grad * gelu_erf_grad))?;
                        }
                        Op::Unary(arg, UnaryOp::Relu) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let relu_grad = arg.ge(&arg.zeros_like()?)?.cast::<T>()?;
                            sum_grad.add_(&(&grad * relu_grad))?;
                        }
                        Op::Unary(arg, UnaryOp::Silu) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
                            let sigmoid_arg = (arg.neg()?.exp()? + T::one()).recip()?;
                            let silu_grad = &sigmoid_arg * (T::one() - *node) + *node;
                            sum_grad.add_(&(&grad * silu_grad))?;
                        }
                        Op::Unary(arg, UnaryOp::Sigmoid) => {
                            let sum_grad = grads.or_insert(arg)?;
                            // y = sigmoid(x) = *node
                            let local_deriv = *node * (T::one() - *node);
                            sum_grad.add_(&(&grad * local_deriv))?;
                        }
                        Op::Unary(arg, UnaryOp::LeakyRelu(negative_slope)) => {
                            let sum_grad = grads.or_insert(arg)?;
                            let mask = arg.ge(&arg.zeros_like()?)?.cast::<T>()?;

                            let ones = mask.ones_like()?;
                            let inv_mask = ones.sub(&mask)?;

                            let slope_part = inv_mask.mul_scalar(*negative_slope)?;
                            let local_deriv = mask.add(&slope_part)?;

                            sum_grad.add_(&(&grad * local_deriv))?;
                        }

                        //=========================================================================================//
                        // Matmul
                        //=========================================================================================//
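                        // y = lhs @ rhs  =>  dL/dlhs = grad @ rhs^T, dL/drhs = lhs^T @ grad.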
                        Op::Matmul(lhs, rhs) => {
                            let lhs_grad = grad.matmul(&rhs.transpose_last()?)?;
                            let lhs_sum_grad = grads.or_insert(lhs)?;
                            lhs_sum_grad.add_(&lhs_grad)?;

                            let rhs_grad = lhs.transpose_last()?.matmul(&grad)?;
                            let rhs_sum_grad = grads.or_insert(rhs)?;
                            rhs_sum_grad.add_(&rhs_grad)?;
                        }

                        //=========================================================================================//
                        // Pow
                        //=========================================================================================//
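                        // y = x^e  =>  dy/dx = e * x^(e - 1).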
                        Op::Pow(arg, e) => {
                            let arg_grad = &(grad * arg.pow(*e - T::one())?) * *e;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        // Reduce
                        //=========================================================================================//
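                        // Sum: every input element contributes with weight 1, so the reduced
                        // gradient is simply broadcast back to the input shape.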
                        Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
                            let grad = Self::broadcast_back(arg, &grad, reduced_dims)?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&grad)?;
                        }
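                        // Max/min: only the elements equal to the reduced value receive the
                        // gradient, selected via an equality mask against the forward output.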
                        Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
                            let node = Self::broadcast_back(arg, node, reduced_dims)?;
                            let grad = Self::broadcast_back(arg, &grad, reduced_dims)?;
                            let grad = node.eq(arg)?.cast()?.mul(&grad)?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&grad.broadcast_as(sum_grad.dims())?)?;
                        }
                        Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
                            let node = Self::broadcast_back(arg, node, reduced_dims)?;
                            let grad = Self::broadcast_back(arg, &grad, reduced_dims)?;
                            let grad = node.eq(arg)?.cast()?.mul(&grad)?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&grad.broadcast_as(sum_grad.dims())?)?;
                        }
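                        // Mean: like sum, but scaled by 1/n where n is the number of input
                        // elements folded into each output element.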
                        Op::Reduce(arg, ReduceOp::Mean, reduced_dims) => {
                            let grad_output = Self::broadcast_back(arg, &grad, reduced_dims)?;
                            let n = arg.element_count() / node.element_count();

                            // grad_input = grad_output / n
                            let grad_input = grad_output / T::from_usize(n);

                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&grad_input)?;
                        }

                        //=========================================================================================//
                        // Broadcast
                        //=========================================================================================//
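                        // Broadcasting is the adjoint of summation: the gradient is summed over
                        // every dimension that was expanded, then the prepended dims are squeezed.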
                        Op::Broadcast(arg) => {
                            let arg_dims = arg.dims();
                            let node_dims = node.dims();
                            let left_dims = node_dims.len() - arg_dims.len();
                            let mut sum_dims: Vec<usize> = (0..left_dims).collect();
                            for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
                                .iter()
                                .zip(arg_dims.iter())
                                .enumerate()
                            {
                                if node_dim != arg_dim {
                                    sum_dims.push(dim + left_dims)
                                }
                            }

                            let mut arg_grad = grad;
                            for &dim in sum_dims.iter() {
                                arg_grad = arg_grad.sum_keepdim(dim)?;
                            }

                            for _i in 0..left_dims {
                                arg_grad = arg_grad.squeeze(0)?
                            }
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad.broadcast_as(sum_grad.dims())?)?;
                        }

                        //=========================================================================================//
                        // Narrow
                        //=========================================================================================//
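                        // Narrow: the gradient of the selected window is re-embedded into a
                        // tensor of the original size by padding with zeros on either side.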
                        &Op::Narrow(ref arg, dim, start_idx, len) => {
                            let arg_dims = arg.dims();
                            let left_pad = if start_idx == 0 {
                                None
                            } else {
                                let mut dims = arg_dims.to_vec();
                                dims[dim] = start_idx;
                                Some(Tensor::zeros(dims)?)
                            };
                            let right_pad = arg_dims[dim] - start_idx - len;
                            let right_pad = if right_pad == 0 {
                                None
                            } else {
                                let mut dims = arg_dims.to_vec();
                                dims[dim] = right_pad;
                                Some(Tensor::zeros(dims)?)
                            };
                            let arg_grad = match (left_pad, right_pad) {
                                (None, None) => grad,
                                (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
                                (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
                                (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
                            };
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        // Slice
                        //=========================================================================================//
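                        // Slice: for step == 1 this reduces to the narrow case; for larger steps
                        // the gradient is first dilated with zeros between the selected elements
                        // before being zero-padded back to the input extent.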
                        &Op::Slice(ref arg, dim, start, _end, step) => {
                            let arg_dims = arg.dims();

                            let body_grad = if step == 1 {
                                // Narrow
                                grad
                            } else {
                                let grad_len = grad.dims()[dim];
                                let span_len = if grad_len > 0 { (grad_len - 1) * step + 1 } else { 0 };

                                let mut unsqueezed_shape = grad.dims().to_vec();
                                unsqueezed_shape.insert(dim + 1, 1);
                                let grad_unsqueezed = grad.reshape(&unsqueezed_shape)?;

                                let mut zeros_shape = unsqueezed_shape.clone();
                                zeros_shape[dim + 1] = step - 1;
                                let zeros_gap = Tensor::zeros(zeros_shape)?;

                                let dilated = Tensor::cat(&[&grad_unsqueezed, &zeros_gap], dim + 1)?;

                                let mut flattened_shape = grad.dims().to_vec();
                                flattened_shape[dim] = grad_len * step;
                                let flattened = dilated.reshape(flattened_shape)?;

                                flattened.narrow(dim, 0, span_len)?
                            };

                            let body_len = body_grad.dims()[dim];

                            let left_pad = if start == 0 {
                                None
                            } else {
                                let mut dims = arg_dims.to_vec();
                                dims[dim] = start;
                                Some(Tensor::zeros(dims)?)
                            };

                            let right_pad_len = arg_dims[dim] - start - body_len;
                            let right_pad = if right_pad_len == 0 {
                                None
                            } else {
                                let mut dims = arg_dims.to_vec();
                                dims[dim] = right_pad_len;
                                Some(Tensor::zeros(dims)?)
                            };

                            let arg_grad = match (left_pad, right_pad) {
                                (None, None) => body_grad,
                                (Some(l), None) => Tensor::cat(&[&l, &body_grad], dim)?,
                                (None, Some(r)) => Tensor::cat(&[&body_grad, &r], dim)?,
                                (Some(l), Some(r)) => Tensor::cat(&[&l, &body_grad, &r], dim)?,
                            };

                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        // Reshape
                        //=========================================================================================//
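                        // Reshape does not move data, so the gradient is just reshaped back to
                        // the argument's original shape.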
                        Op::Reshape(arg) => {
                            let arg_grad = grad.reshape(arg.dims())?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        // Transpose
                        //=========================================================================================//
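                        // Transposing the same two dims is its own inverse, so it is applied to
                        // the gradient as well.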
                        Op::Transpose(arg, dim1, dim2) => {
                            let arg_grad = grad.transpose(*dim1, *dim2)?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        // Permute
                        //=========================================================================================//
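                        // Permute: the gradient is permuted with the inverse permutation, built
                        // by mapping each destination axis back to its source axis.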
                        Op::Permute(arg, dims) => {
                            let mut inv_dims = vec![0; dims.len()];
                            for (i, &dim_idx) in dims.iter().enumerate() {
                                inv_dims[dim_idx] = i
                            }
                            let arg_grad = grad.permute(inv_dims)?;
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&arg_grad)?;
                        }

                        //=========================================================================================//
                        // Cat
                        //=========================================================================================//
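                        // Cat: the gradient is split back into per-input chunks by narrowing
                        // along the concatenation dim.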
                        Op::Cat(args, dim) => {
                            let mut start_idx = 0;
                            for arg in args {
                                let len = arg.dims()[*dim];
                                let arg_grad = grad.narrow(*dim, start_idx, len)?;
                                let sum_grad = grads.or_insert(arg)?;
                                sum_grad.add_(&arg_grad)?;
                                start_idx += len;
                            }
                        }

                        //=========================================================================================//
                        // Copy
                        //=========================================================================================//
                        Op::Copy(arg) => {
                            let sum_grad = grads.or_insert(arg)?;
                            sum_grad.add_(&grad)?;
                        }

                        //=========================================================================================//
                        // IfElse
                        //=========================================================================================//
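                        // IfElse: each branch only receives the gradient where the mask selected
                        // it; the other positions are zeroed out.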
                        Op::IfElse(mask, tv, fv) => {
                            if let Some(tv) = tv {
                                let masked_grad = mask.if_else(&grad, T::zero())?;
                                let sum_grad = grads.or_insert(tv)?;
                                sum_grad.add_(&masked_grad)?;
                            }

                            if let Some(fv) = fv {
                                let masked_grad = mask.if_else(T::zero(), &grad)?;
                                let sum_grad = grads.or_insert(fv)?;
                                sum_grad.add_(&masked_grad)?;
                            }
                        }

                        //=========================================================================================//
                        // IndexSelect
                        //=========================================================================================//
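                        // IndexSelect: scatter the gradient back to the selected rows, summing
                        // wherever the same index was picked more than once.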
                        Op::IndexSelect(arg, indexes, dim) => {
                            let sum_grad = grads.or_insert(arg)?;
                            *sum_grad = sum_grad.index_add(indexes.clone(), &grad, *dim)?;
                        }

                        //=========================================================================================//
                        // IndexAdd
                        //=========================================================================================//
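                        // IndexAdd: the destination gets the gradient unchanged, while the
                        // source gets the gradient gathered at the indexed positions.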
                        Op::IndexAdd(init, indexes, src, dim) => {
                            let init_sum_grad = grads.or_insert(init)?;
                            *init_sum_grad = init_sum_grad.add(&grad)?;

                            let src_grad = grad.index_select(indexes.clone(), *dim)?;
                            let src_sum_grad = grads.or_insert(src)?;
                            *src_sum_grad = src_sum_grad.add(&src_grad)?;
                        }

                        //=========================================================================================//
                        // ScatterAdd
                        //=========================================================================================//
                        #[allow(unused)]
                        Op::ScatterAdd(init, indexes, src, dim) => {
                            unimplemented!()
                        }

                        //=========================================================================================//
                        // Gather
                        //=========================================================================================//
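                        // Gather is the adjoint of scatter-add: the gathered gradient is
                        // scattered back (with accumulation) to the positions it came from.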
                        Op::Gather(arg, indexes, dim) => {
                            let arg_grad = grads.or_insert(arg)?;
                            *arg_grad = arg_grad.scatter_add(indexes.clone(), &grad, *dim)?;
                        }
                    }
                }
            }
        }

        Ok(grads)
    }

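    /// Returns the nodes of the computation graph reachable from `self` in reverse
    /// topological order: the output comes first and every node appears before the
    /// nodes it was computed from. Sub-graphs that do not require a gradient are pruned.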
    pub fn sorted_nodes(&self) -> Vec<&Tensor<T>> {
        // The vec of sorted nodes is passed as an owned value rather than a mutable reference
        // to get around some lifetime limitations.
        fn walk<'a, T: FloatDType>(
            node: &'a Tensor<T>,
            nodes: Vec<&'a Tensor<T>>,
            already_seen: &mut HashMap<TensorId, bool>,
        ) -> (bool, Vec<&'a Tensor<T>>) {
            if let Some(&tg) = already_seen.get(&node.id()) {
                return (tg, nodes);
            }
            let mut track_grad = false;
            let mut nodes = if node.is_leaf() {
                track_grad = true;
                nodes
            } else if node.dtype().is_int() {
                nodes
            } else if let Some(op) = node.op() {
                match op {
                    | Op::Binary(lhs, rhs, _)
                    | Op::Matmul(lhs, rhs)
                    | Op::IfElse(_, Some(lhs), Some(rhs))
                    | Op::IndexAdd(lhs, _, rhs, _)
                    | Op::ScatterAdd(lhs, _, rhs, _) => {
                        let (tg, nodes) = walk(lhs, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(rhs, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }

                    | Op::Unary(_node, UnaryOp::Ceil)
                    | Op::Unary(_node, UnaryOp::Floor)
                    | Op::Unary(_node, UnaryOp::Round)
                    | Op::Unary(_node, UnaryOp::Sign) => nodes,

                    | Op::IfElse(_, None, None) => nodes,

                    | Op::BinaryScalarLhs(_, node, _)
                    | Op::BinaryScalarRhs(node, _, _)
                    | Op::Broadcast(node)
                    | Op::Unary(node, _)
                    | Op::Pow(node, _)
                    | Op::Reduce(node, _, _)
                    | Op::Narrow(node, _, _, _)
                    | Op::Slice(node, _, _, _, _)
                    | Op::Reshape(node)
                    | Op::Transpose(node, _, _)
                    | Op::Permute(node, _)
                    | Op::Copy(node)
                    | Op::Gather(node, _, _)
                    | Op::IndexSelect(node, _, _)
                    | Op::IfElse(_, Some(node), None)
                    | Op::IfElse(_, None, Some(node)) => {
                        let (tg, nodes) = walk(node, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }

                    | Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
                        let (tg, nodes) = walk(arg, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }),
                }
            } else {
                nodes
            };
            already_seen.insert(node.id(), track_grad);
            if track_grad {
                nodes.push(node);
            }
            (track_grad, nodes)
        }
        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
        nodes.reverse();
        nodes
    }

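    /// Expands a gradient produced by a reduction back to the shape of `arg`: if the rank was
    /// preserved a plain broadcast suffices, otherwise the gradient is first reshaped via
    /// `reduced_dims` so the broadcast lines up with the original axes.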
    fn broadcast_back(arg: &Tensor<T>, node: &Tensor<T>, reduced_dims: &[usize]) -> crate::Result<Tensor<T>> {
        if arg.rank() == node.rank() {
            node.broadcast_as(arg.shape())
        } else {
            node.reshape(reduced_dims)?.broadcast_as(arg.shape())
        }
    }
}
637}