use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;

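/// Expand a gradient tensor `node` back to the shape of `arg`, the argument of a
/// reduction op. `reduced_dims` appears to be the post-reduction shape with the
/// reduced axes kept as size-1 dims, so a reshape followed by a broadcast recovers
/// `arg`'s shape.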
fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
    if arg.rank() == node.rank() {
        node.broadcast_as(arg.shape())
    } else {
        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
    }
}

impl Tensor {
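    /// Return all the nodes that lead to this value, topologically sorted so that each
    /// node appears before its arguments (the output comes first). Only nodes that
    /// transitively depend on a gradient-tracking variable are kept.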
    fn sorted_nodes(&self) -> Vec<&Tensor> {
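        // Depth-first walk of the op graph. `nodes` is threaded through as an owned
        // value rather than a mutable reference, which sidesteps lifetime issues in the
        // recursion. The returned flag reports whether the node needs gradient tracking,
        // and `already_seen` memoizes it per tensor id so shared subgraphs are visited once.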
        fn walk<'a>(
            node: &'a Tensor,
            nodes: Vec<&'a Tensor>,
            already_seen: &mut HashMap<TensorId, bool>,
        ) -> (bool, Vec<&'a Tensor>) {
            if let Some(&tg) = already_seen.get(&node.id()) {
                return (tg, nodes);
            }
            let mut track_grad = false;
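            // Variables are the gradient-carrying leaves of the graph; integer tensors
            // can never carry a gradient. Neither case recurses into arguments.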
            let mut nodes = if node.is_variable() {
                track_grad = true;
                nodes
            } else if node.dtype().is_int() {
                nodes
            } else if let Some(op) = node.op() {
                match op {
                    Op::IndexAdd(t1, t2, t3, _)
                    | Op::ScatterAdd(t1, t2, t3, _)
                    | Op::CustomOp3(t1, t2, t3, _)
                    | Op::WhereCond(t1, t2, t3) => {
                        let (tg, nodes) = walk(t1, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(t2, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(t3, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::Conv1D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::Conv2D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::ConvTranspose2D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::CustomOp2(lhs, rhs, _)
                    | Op::Binary(lhs, rhs, _)
                    | Op::Gather(lhs, rhs, _)
                    | Op::IndexSelect(lhs, rhs, _)
                    | Op::Matmul(lhs, rhs)
                    | Op::SliceScatter0(lhs, rhs, _) => {
                        let (tg, nodes) = walk(lhs, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(rhs, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
                        let (tg, nodes) = walk(arg, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }),
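                    // An affine transform with mul == 0 maps everything to a constant,
                    // so no gradient can flow back to its argument.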
                    Op::Affine { arg, mul, .. } => {
                        if *mul == 0. {
                            nodes
                        } else {
                            let (tg, nodes) = walk(arg, nodes, already_seen);
                            track_grad |= tg;
                            nodes
                        }
                    }
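                    // Ceil, floor, and round are piecewise constant: their gradient is
                    // zero almost everywhere, so their arguments need no tracking.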
                    Op::Unary(_node, UnaryOp::Ceil)
                    | Op::Unary(_node, UnaryOp::Floor)
                    | Op::Unary(_node, UnaryOp::Round) => nodes,
                    Op::Reshape(node)
                    | Op::UpsampleNearest1D(node)
                    | Op::UpsampleNearest2D(node)
                    | Op::AvgPool2D { arg: node, .. }
                    | Op::MaxPool2D { arg: node, .. }
                    | Op::Copy(node)
                    | Op::Broadcast(node)
                    | Op::Cmp(node, _)
                    | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
                    | Op::ToDevice(node)
                    | Op::Transpose(node, _, _)
                    | Op::Permute(node, _)
                    | Op::Narrow(node, _, _, _)
                    | Op::Unary(node, _)
                    | Op::Elu(node, _)
                    | Op::Powf(node, _)
                    | Op::CustomOp1(node, _) => {
                        let (tg, nodes) = walk(node, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::ToDType(node) => {
                        if node.dtype().is_float() {
                            let (tg, nodes) = walk(node, nodes, already_seen);
                            track_grad |= tg;
                            nodes
                        } else {
                            nodes
                        }
                    }
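                    // ArgMin/ArgMax return integer indices; they are not differentiable.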
                    Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
                }
            } else {
                nodes
            };
            already_seen.insert(node.id(), track_grad);
            if track_grad {
                nodes.push(node);
            }
            (track_grad, nodes)
        }
        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
        nodes.reverse();
        nodes
    }

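    /// Run the backward pass from this tensor and return a [`GradStore`] holding the
    /// gradient of `self` with respect to every variable it depends on.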
    pub fn backward(&self) -> Result<GradStore> {
        let sorted_nodes = self.sorted_nodes();
        let mut grads = GradStore::new();
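        // Seed the pass: the gradient of the output with respect to itself is all ones.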
        grads.insert(self, self.ones_like()?.contiguous()?);
        for node in sorted_nodes.iter() {
            if node.is_variable() {
                continue;
            }
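            // Nodes are visited from the output towards the leaves, so by the time a
            // node is popped here its gradient has been fully accumulated by all of its
            // consumers.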
            let grad = grads.remove(node).unwrap();
            if let Some(op) = node.op() {
                match op {
                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Sub) => {
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Mul) => {
                        let lhs_grad = grad.mul(rhs)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                        let rhs_grad = grad.mul(lhs)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Div) => {
                        let lhs_grad = grad.div(rhs)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Minimum)
                    | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
                        let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
                        let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;

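                        // If both masks are 1 at the same position (lhs == rhs there),
                        // scale the gradient by 0.5 so the two branches share it evenly.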
                        let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;

                        let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::WhereCond(pred, t, f) => {
                        let zeros = grad.zeros_like()?;
                        let t_sum_grad = grads.or_insert(t)?;
                        let t_grad = pred.where_cond(&grad, &zeros)?;
                        *t_sum_grad = t_sum_grad.add(&t_grad)?;
                        let f_sum_grad = grads.or_insert(f)?;
                        let f_grad = pred.where_cond(&zeros, &grad)?;
                        *f_sum_grad = f_sum_grad.add(&f_grad)?;
                    }
                    Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
                    Op::Conv2D {
                        arg,
                        kernel,
                        padding,
                        stride,
                        dilation,
                    } => {
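                        // The gradient w.r.t. the input is a transposed convolution of
                        // the output gradient. The output padding makes up for input
                        // rows/columns the forward convolution never covered:
                        // out_size = (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding.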
                        let grad_h = grad.dim(2)?;
                        let k_h = kernel.dim(2)?;
                        let out_size =
                            (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
                        let out_padding = arg.dim(2)? - out_size;
                        let grad_arg = grad.conv_transpose2d(
                            kernel,
                            *padding,
                            out_padding,
                            *stride,
                            *dilation,
                        )?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;

                        let grad_kernel = arg
                            .transpose(0, 1)?
                            .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                            .transpose(0, 1)?;
                        let sum_grad = grads.or_insert(kernel)?;
                        let (_, _, k0, k1) = kernel.dims4()?;
                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
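                        // For some stride/dilation combinations the computed kernel
                        // gradient comes out spatially larger than the kernel; trim it.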
                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
                        } else {
                            grad_kernel
                        };
                        *sum_grad = sum_grad.add(&grad_kernel)?;
                    }
                    Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
                        op: "conv-transpose2d",
                    })?,
                    Op::AvgPool2D {
                        arg,
                        kernel_size,
                        stride,
                    } => {
                        if kernel_size != stride {
                            crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
                        }
                        let (_n, _c, h, w) = arg.dims4()?;
                        let grad_arg = grad.upsample_nearest2d(h, w)?;
                        let grad_arg =
                            (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;
                    }
                    Op::MaxPool2D {
                        arg,
                        kernel_size,
                        stride,
                    } => {
                        if kernel_size != stride {
                            crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
                        }
                        let (_n, _c, h, w) = arg.dims4()?;
                        let node_upsampled = node.upsample_nearest2d(h, w)?;
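                        // Build a mask of the positions that attained the maximum in each
                        // pooling window and route the upsampled gradient through it; the
                        // avg-pooled mask rescales the gradient where several positions tie.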
                        let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
                        let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
                        let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;
                    }
                    Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
                        op: "upsample-nearest1d",
                    })?,
                    Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                        op: "upsample-nearest2d",
                    })?,
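                    // The rhs of a slice-scatter receives the slice of the gradient it
                    // was written to; the lhs receives the gradient with that slice
                    // zeroed out.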
                    Op::SliceScatter0(lhs, rhs, start_rhs) => {
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;

                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
                    }
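                    // gather/scatter_add and index_select/index_add are adjoint pairs,
                    // so each one's backward pass is expressed via its counterpart.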
                    Op::Gather(arg, indexes, dim) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
                    }
                    Op::ScatterAdd(init, indexes, src, dim) => {
                        let init_sum_grad = grads.or_insert(init)?;
                        *init_sum_grad = init_sum_grad.add(&grad)?;

                        let src_grad = grad.gather(indexes, *dim)?;
                        let src_sum_grad = grads.or_insert(src)?;
                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
                    }
                    Op::IndexAdd(init, indexes, src, dim) => {
                        let init_sum_grad = grads.or_insert(init)?;
                        *init_sum_grad = init_sum_grad.add(&grad)?;

                        let src_grad = grad.index_select(indexes, *dim)?;
                        let src_sum_grad = grads.or_insert(src)?;
                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
                    }
                    Op::IndexSelect(arg, indexes, dim) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
                    }
                    Op::Matmul(lhs, rhs) => {
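                        // Standard matmul backward: d(lhs) = grad . rhs^T and
                        // d(rhs) = lhs^T . grad.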
                        let lhs_grad = grad.matmul(&rhs.t()?)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;

                        let rhs_grad = lhs.t()?.matmul(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::Cat(args, dim) => {
                        let mut start_idx = 0;
                        for arg in args {
                            let len = arg.dims()[*dim];
                            let arg_grad = grad.narrow(*dim, start_idx, len)?;
                            let sum_grad = grads.or_insert(arg)?;
                            *sum_grad = sum_grad.add(&arg_grad)?;
                            start_idx += len;
                        }
                    }
                    Op::Broadcast(arg) => {
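                        // Undo a broadcast by summing the gradient over every dim that
                        // was prepended or expanded, then squeeze away the prepended dims.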
                        let arg_dims = arg.dims();
                        let node_dims = node.dims();
                        let left_dims = node_dims.len() - arg_dims.len();
                        let mut sum_dims: Vec<usize> = (0..left_dims).collect();
                        for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
                            .iter()
                            .zip(arg_dims.iter())
                            .enumerate()
                        {
                            if node_dim != arg_dim {
                                sum_dims.push(dim + left_dims)
                            }
                        }

                        let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
                        for _i in 0..left_dims {
                            arg_grad = arg_grad.squeeze(0)?
                        }
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad)?;
                    }
                    Op::Cmp(_args, _) => {}
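                    // For min/max reductions the gradient flows only to the positions
                    // that attained the extremum (to all of them in case of ties).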
                    Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
                        let node = broadcast_back(arg, node, reduced_dims)?;
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
                        let node = broadcast_back(arg, node, reduced_dims)?;
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::ToDType(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
                    }
                    Op::Copy(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad)?
                    }
                    Op::Affine { arg, mul, .. } => {
                        let arg_grad = grad.affine(*mul, 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(arg, UnaryOp::Log) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(grad / arg)?)?
                    }
                    Op::Unary(arg, UnaryOp::Sin) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
                    }
                    Op::Unary(arg, UnaryOp::Cos) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                    }
                    Op::Unary(arg, UnaryOp::Tanh) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let minus_dtanh = (node.sqr()? - 1.)?;
                        *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
                    }
                    Op::Unary(arg, UnaryOp::Abs) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let ones = arg.ones_like()?;
                        let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
                        *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
                    }
                    Op::Unary(arg, UnaryOp::Exp) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
                    }
                    Op::Unary(arg, UnaryOp::Neg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.sub(&grad)?
                    }
                    Op::Unary(arg, UnaryOp::Recip) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let grad = (grad / arg.sqr()?)?;
                        *sum_grad = sum_grad.sub(&grad)?
                    }
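                    // narrow backward: pad the gradient with zeros on both sides along
                    // `dim` so that it lines back up with the original tensor.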
                    &Op::Narrow(ref arg, dim, start_idx, len) => {
                        let arg_dims = arg.dims();
                        let left_pad = if start_idx == 0 {
                            None
                        } else {
                            let mut dims = arg_dims.to_vec();
                            dims[dim] = start_idx;
                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
                        };
                        let right_pad = arg_dims[dim] - start_idx - len;
                        let right_pad = if right_pad == 0 {
                            None
                        } else {
                            let mut dims = arg_dims.to_vec();
                            dims[dim] = right_pad;
                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
                        };
                        let arg_grad = match (left_pad, right_pad) {
                            (None, None) => grad,
                            (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
                            (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
                            (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
                        };
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Reduce(_, ReduceOp::ArgMin, _) => {}
                    Op::Reduce(_, ReduceOp::ArgMax, _) => {}
                    Op::Reshape(arg) => {
                        let arg_grad = grad.reshape(arg.dims())?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
                    Op::Unary(_, UnaryOp::Floor) => {
                        Err(Error::BackwardNotSupported { op: "floor" })?
                    }
                    Op::Unary(_, UnaryOp::Round) => {
                        Err(Error::BackwardNotSupported { op: "round" })?
                    }
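                    // Gradient of the tanh approximation of gelu:
                    // gelu(x) ~ 0.5 * x * (1 + tanh(0.797885 * x + 0.0356774 * x^3)),
                    // with 0.797885 ~ sqrt(2/pi) and 0.0356774 ~ 0.044715 * sqrt(2/pi).
                    // Differentiating gives the expression below, where
                    // 0.398942 = 0.797885 / 2 and 0.0535161 = 3 * 0.0356774 / 2.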
                    Op::Unary(arg, UnaryOp::Gelu) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let cube = arg.powf(3.)?;
                        let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
                        let gelu_grad = (((0.5 * &tanh)?
                            + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
                            + 0.5)?;
                        *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
                    }
                    Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
                    Op::Unary(_, UnaryOp::GeluErf) => {
                        Err(Error::BackwardNotSupported { op: "gelu-erf" })?
                    }
                    Op::Unary(arg, UnaryOp::Relu) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                        *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                    }
                    Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                    Op::Powf(arg, e) => {
                        let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::CustomOp1(arg, c) => {
                        if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
                            let sum_grad = grads.or_insert(arg)?;
                            *sum_grad = sum_grad.add(&arg_grad)?
                        }
                    }
                    Op::CustomOp2(arg1, arg2, c) => {
                        let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
                        if let Some(arg_grad1) = arg_grad1 {
                            let sum_grad = grads.or_insert(arg1)?;
                            *sum_grad = sum_grad.add(&arg_grad1)?
                        }
                        if let Some(arg_grad2) = arg_grad2 {
                            let sum_grad = grads.or_insert(arg2)?;
                            *sum_grad = sum_grad.add(&arg_grad2)?
                        }
                    }
                    Op::CustomOp3(arg1, arg2, arg3, c) => {
                        let (arg_grad1, arg_grad2, arg_grad3) =
                            c.bwd(arg1, arg2, arg3, node, &grad)?;
                        if let Some(arg_grad1) = arg_grad1 {
                            let sum_grad = grads.or_insert(arg1)?;
                            *sum_grad = sum_grad.add(&arg_grad1)?
                        }
                        if let Some(arg_grad2) = arg_grad2 {
                            let sum_grad = grads.or_insert(arg2)?;
                            *sum_grad = sum_grad.add(&arg_grad2)?
                        }
                        if let Some(arg_grad3) = arg_grad3 {
                            let sum_grad = grads.or_insert(arg3)?;
                            *sum_grad = sum_grad.add(&arg_grad3)?
                        }
                    }
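                    // d/dx x^2 = 2x, so the argument's gradient is 2 * arg * grad.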
                    Op::Unary(arg, UnaryOp::Sqr) => {
                        let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
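                    // d/dx sqrt(x) = 1 / (2 * sqrt(x)); `node` is the forward output
                    // sqrt(x), so it can be reused here.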
                    Op::Unary(arg, UnaryOp::Sqrt) => {
                        let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::ToDevice(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let arg_grad = grad.to_device(sum_grad.device())?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Transpose(arg, dim1, dim2) => {
                        let arg_grad = grad.transpose(*dim1, *dim2)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
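                    // Invert the permutation (if dims[i] == j then inv_dims[j] == i) and
                    // permute the gradient back through it.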
                    Op::Permute(arg, dims) => {
                        let mut inv_dims = vec![0; dims.len()];
                        for (i, &dim_idx) in dims.iter().enumerate() {
                            inv_dims[dim_idx] = i
                        }
                        let arg_grad = grad.permute(inv_dims)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                };
            }
        }
        Ok(grads)
    }
}

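/// A store keyed by tensor id that accumulates the gradient associated with each tensor.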
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);

impl GradStore {
    fn new() -> Self {
        GradStore(HashMap::new())
    }

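    /// Get the gradient tensor associated with the given tensor id, if any.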
    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
        self.0.get(&id)
    }

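    /// Get the gradient tensor associated with the given tensor, if any.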
    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
        self.0.get(&tensor.id())
    }

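    /// Remove and return the gradient tensor associated with the given tensor, if any.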
    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
        self.0.remove(&tensor.id())
    }

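    /// Insert a gradient tensor for the given tensor, returning the previous entry if any.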
    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
        self.0.insert(tensor.id(), grad)
    }

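    /// Return the gradient entry for `tensor`, inserting a zero-filled tensor first if
    /// no entry is present yet.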
    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
        use std::collections::hash_map::Entry;
        let grad = match self.0.entry(tensor.id()) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => {
                let grad = tensor.zeros_like()?;
                entry.insert(grad)
            }
        };
        Ok(grad)
    }
}