candle_core_temp/
backprop.rs

use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;

// arg has been reduced to node over reduced_dims; expand the gradient back to the shape of arg.
// This has to handle both the keepdim=true and keepdim=false cases.
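// For example, if arg has shape (2, 3, 4) and was summed over dim 1 without keepdim, node has
// shape (2, 4): it is first reshaped to the keepdim shape (2, 1, 4) given by reduced_dims and
// then broadcast back to (2, 3, 4).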
fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
    if arg.rank() == node.rank() {
        // keepdim = true
        node.broadcast_as(arg.shape())
    } else {
        // keepdim = false
        // first expand the reduced dims.
        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
    }
}

impl Tensor {
    /// Return all the nodes that lead to this value in a topologically sorted vec: the first
    /// elements depend on the latter ones, so the first element, if any, is this tensor itself.
    /// This assumes that the op graph is a DAG.
    fn sorted_nodes(&self) -> Vec<&Tensor> {
        // The vec of sorted nodes is passed as an owned value rather than a mutable reference
        // to get around some lifetime limitations.
        fn walk<'a>(
            node: &'a Tensor,
            nodes: Vec<&'a Tensor>,
            already_seen: &mut HashMap<TensorId, bool>,
        ) -> (bool, Vec<&'a Tensor>) {
            if let Some(&tg) = already_seen.get(&node.id()) {
                return (tg, nodes);
            }
            let mut track_grad = false;
            let mut nodes = if node.is_variable() {
                // Do not call recursively on the "leaf" nodes.
                track_grad = true;
                nodes
            } else if node.dtype().is_int() {
                nodes
            } else if let Some(op) = node.op() {
                match op {
                    Op::IndexAdd(t1, t2, t3, _)
                    | Op::ScatterAdd(t1, t2, t3, _)
                    | Op::CustomOp3(t1, t2, t3, _)
                    | Op::WhereCond(t1, t2, t3) => {
                        let (tg, nodes) = walk(t1, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(t2, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(t3, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::Conv1D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::Conv2D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::ConvTranspose2D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::CustomOp2(lhs, rhs, _)
                    | Op::Binary(lhs, rhs, _)
                    | Op::Gather(lhs, rhs, _)
                    | Op::IndexSelect(lhs, rhs, _)
                    | Op::Matmul(lhs, rhs)
                    | Op::SliceScatter0(lhs, rhs, _) => {
                        let (tg, nodes) = walk(lhs, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(rhs, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
                        let (tg, nodes) = walk(arg, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }),
                    Op::Affine { arg, mul, .. } => {
                        if *mul == 0. {
                            nodes
                        } else {
                            let (tg, nodes) = walk(arg, nodes, already_seen);
                            track_grad |= tg;
                            nodes
                        }
                    }
                    Op::Unary(_node, UnaryOp::Ceil)
                    | Op::Unary(_node, UnaryOp::Floor)
                    | Op::Unary(_node, UnaryOp::Round) => nodes,
                    Op::Reshape(node)
                    | Op::UpsampleNearest1D(node)
                    | Op::UpsampleNearest2D(node)
                    | Op::AvgPool2D { arg: node, .. }
                    | Op::MaxPool2D { arg: node, .. }
                    | Op::Copy(node)
                    | Op::Broadcast(node)
                    | Op::Cmp(node, _)
                    | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
                    | Op::ToDevice(node)
                    | Op::Transpose(node, _, _)
                    | Op::Permute(node, _)
                    | Op::Narrow(node, _, _, _)
                    | Op::Unary(node, _)
                    | Op::Elu(node, _)
                    | Op::Powf(node, _)
                    | Op::CustomOp1(node, _) => {
                        let (tg, nodes) = walk(node, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::ToDType(node) => {
                        if node.dtype().is_float() {
                            let (tg, nodes) = walk(node, nodes, already_seen);
                            track_grad |= tg;
                            nodes
                        } else {
                            nodes
                        }
                    }
                    Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
                }
            } else {
                nodes
            };
            already_seen.insert(node.id(), track_grad);
            if track_grad {
                nodes.push(node);
            }
            (track_grad, nodes)
        }
        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
        nodes.reverse();
        nodes
    }

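    /// Run the backward pass from this tensor and return the gradients accumulated for the nodes
    /// it depends on.
    ///
    /// A minimal usage sketch (illustrative only; it assumes the published `candle_core` crate
    /// name and its `Var` API for trainable tensors):
    ///
    /// ```ignore
    /// use candle_core::{Device, Var};
    ///
    /// let x = Var::new(&[3f32], &Device::Cpu)?;
    /// let loss = x.sqr()?.sum_all()?; // loss = x^2, so d(loss)/dx = 2x
    /// let grads = loss.backward()?;
    /// let grad_x = grads.get(&x).expect("no gradient for x");
    /// ```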
    pub fn backward(&self) -> Result<GradStore> {
        let sorted_nodes = self.sorted_nodes();
        let mut grads = GradStore::new();
        grads.insert(self, self.ones_like()?.contiguous()?);
        for node in sorted_nodes.iter() {
            if node.is_variable() {
                continue;
            }
            let grad = grads.remove(node).unwrap();
            // TODO: We should perform all these operations in place (or at least not track the
            // whole graph). The only drawback would be if we wanted to support grad-of-grad, but
            // this is out of scope here.
            if let Some(op) = node.op() {
                match op {
                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Sub) => {
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Mul) => {
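                        // Product rule: d(lhs * rhs)/d(lhs) = rhs and d(lhs * rhs)/d(rhs) = lhs,
                        // so each operand receives the incoming gradient scaled by the other one.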
                        let lhs_grad = grad.mul(rhs)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                        let rhs_grad = grad.mul(lhs)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Div) => {
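                        // Quotient rule: d(lhs / rhs)/d(lhs) = 1 / rhs and
                        // d(lhs / rhs)/d(rhs) = -lhs / rhs^2, hence the subtraction below.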
                        let lhs_grad = grad.div(rhs)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Minimum)
                    | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
                        let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
                        let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;

                        // If both masks are 1 at the same point, we want to scale the
                        // gradient by 0.5 rather than 1.
                        let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;

                        let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::WhereCond(pred, t, f) => {
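                        // The gradient is routed to whichever branch was selected: t receives it
                        // where pred is true, f where pred is false, and zero elsewhere.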
                        let zeros = grad.zeros_like()?;
                        let t_sum_grad = grads.or_insert(t)?;
                        let t_grad = pred.where_cond(&grad, &zeros)?;
                        *t_sum_grad = t_sum_grad.add(&t_grad)?;
                        let f_sum_grad = grads.or_insert(f)?;
                        let f_grad = pred.where_cond(&zeros, &grad)?;
                        *f_sum_grad = f_sum_grad.add(&f_grad)?;
                    }
                    Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
                    Op::Conv2D {
                        arg,
                        kernel,
                        padding,
                        stride,
                        dilation,
                    } => {
                        // The output height for conv_transpose2d is:
                        // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
                        let grad_h = grad.dim(2)?;
                        let k_h = kernel.dim(2)?;
                        let out_size =
                            (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
                        let out_padding = arg.dim(2)? - out_size;
                        let grad_arg = grad.conv_transpose2d(
                            kernel,
                            *padding,
                            out_padding,
                            *stride,
                            *dilation,
                        )?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;

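                        // The kernel gradient is a convolution between the input and the output
                        // gradient with the batch and channel dimensions swapped; when the stride
                        // does not evenly divide the input size the result can be slightly larger
                        // than the kernel, hence the narrowing below.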
                        let grad_kernel = arg
                            .transpose(0, 1)?
                            .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                            .transpose(0, 1)?;
                        let sum_grad = grads.or_insert(kernel)?;
                        let (_, _, k0, k1) = kernel.dims4()?;
                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
                        } else {
                            grad_kernel
                        };
                        *sum_grad = sum_grad.add(&grad_kernel)?;
                    }
                    Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
                        op: "conv-transpose2d",
                    })?,
                    Op::AvgPool2D {
                        arg,
                        kernel_size,
                        stride,
                    } => {
                        if kernel_size != stride {
                            crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
                        }
                        let (_n, _c, h, w) = arg.dims4()?;
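                        // Each input element of a pooling window contributes 1/(k_h * k_w) to the
                        // averaged output, so the gradient is the upsampled output gradient scaled
                        // by that same factor.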
                        let grad_arg = grad.upsample_nearest2d(h, w)?;
                        let grad_arg =
                            (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;
                    }
                    Op::MaxPool2D {
                        arg,
                        kernel_size,
                        stride,
                    } => {
                        if kernel_size != stride {
                            crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
                        }
                        let (_n, _c, h, w) = arg.dims4()?;
                        // For the max-pool gradient, we compute a mask where a 1 means that the
                        // element is the maximum, then we apply this mask to the upsampled
                        // gradient (taking into account that multiple maxima may exist in a
                        // window, in which case the gradient is scaled accordingly).
                        let node_upsampled = node.upsample_nearest2d(h, w)?;
                        let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
                        let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
                        let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;
                    }
                    Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
                        op: "upsample-nearest1d",
                    })?,
                    Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                        op: "upsample-nearest2d",
                    })?,
                    Op::SliceScatter0(lhs, rhs, start_rhs) => {
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;

                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
                    }
                    Op::Gather(arg, indexes, dim) => {
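                        // gather and scatter_add are adjoint: the gradient of a gather is
                        // scattered back (and accumulated) at the indexes that were read.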
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
                    }
                    Op::ScatterAdd(init, indexes, src, dim) => {
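                        // The init tensor passes through unchanged, so it receives the incoming
                        // gradient as is; the src gradient is gathered at the same indexes.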
                        let init_sum_grad = grads.or_insert(init)?;
                        *init_sum_grad = init_sum_grad.add(&grad)?;

                        let src_grad = grad.gather(indexes, *dim)?;
                        let src_sum_grad = grads.or_insert(src)?;
                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
                    }
                    Op::IndexAdd(init, indexes, src, dim) => {
                        let init_sum_grad = grads.or_insert(init)?;
                        *init_sum_grad = init_sum_grad.add(&grad)?;

                        let src_grad = grad.index_select(indexes, *dim)?;
                        let src_sum_grad = grads.or_insert(src)?;
                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
                    }
                    Op::IndexSelect(arg, indexes, dim) => {
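                        // index_select and index_add are adjoint: the gradient is added back at
                        // the rows that were selected.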
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
                    }
                    Op::Matmul(lhs, rhs) => {
                        // The forward op succeeded, so we can skip the matmul shape checks here.

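                        // With c = lhs.matmul(rhs): d(loss)/d(lhs) = grad.matmul(rhs^T) and
                        // d(loss)/d(rhs) = lhs^T.matmul(grad).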
                        let lhs_grad = grad.matmul(&rhs.t()?)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;

                        let rhs_grad = lhs.t()?.matmul(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::Cat(args, dim) => {
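                        // The gradient of a concatenation is split back into per-argument slices
                        // along the concatenation dimension.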
                        let mut start_idx = 0;
                        for arg in args {
                            let len = arg.dims()[*dim];
                            let arg_grad = grad.narrow(*dim, start_idx, len)?;
                            let sum_grad = grads.or_insert(arg)?;
                            *sum_grad = sum_grad.add(&arg_grad)?;
                            start_idx += len;
                        }
                    }
                    Op::Broadcast(arg) => {
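                        // The adjoint of a broadcast is a sum over the broadcast dimensions:
                        // dimensions inserted on the left are squeezed away and dimensions that
                        // were expanded from size 1 are summed with keepdim.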
                        let arg_dims = arg.dims();
                        let node_dims = node.dims();
                        // The number of dims that have been inserted on the left.
                        let left_dims = node_dims.len() - arg_dims.len();
                        let mut sum_dims: Vec<usize> = (0..left_dims).collect();
                        for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
                            .iter()
                            .zip(arg_dims.iter())
                            .enumerate()
                        {
                            if node_dim != arg_dim {
                                sum_dims.push(dim + left_dims)
                            }
                        }

                        let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
                        for _i in 0..left_dims {
                            arg_grad = arg_grad.squeeze(0)?
                        }
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
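                        // Summation is linear, so the gradient is simply broadcast back to the
                        // shape of the argument.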
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad)?;
                    }
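                    // Comparisons are piecewise constant, so no gradient flows through them.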
                    Op::Cmp(_args, _) => {}
                    Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
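                        // The gradient only flows to the positions that are equal to the reduced
                        // maximum (the Min case below is handled the same way).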
                        let node = broadcast_back(arg, node, reduced_dims)?;
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
                        let node = broadcast_back(arg, node, reduced_dims)?;
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::ToDType(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
                    }
                    Op::Copy(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad)?
                    }
                    Op::Affine { arg, mul, .. } => {
                        let arg_grad = grad.affine(*mul, 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(arg, UnaryOp::Log) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(grad / arg)?)?
                    }
                    Op::Unary(arg, UnaryOp::Sin) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
                    }
                    Op::Unary(arg, UnaryOp::Cos) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                    }
                    Op::Unary(arg, UnaryOp::Tanh) => {
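                        // d/dx tanh(x) = 1 - tanh(x)^2, written below as subtracting
                        // grad * (tanh(x)^2 - 1) from the accumulated gradient.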
                        let sum_grad = grads.or_insert(arg)?;
                        let minus_dtanh = (node.sqr()? - 1.)?;
                        *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
                    }
                    Op::Unary(arg, UnaryOp::Abs) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let ones = arg.ones_like()?;
                        let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
                        *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
                    }
                    Op::Unary(arg, UnaryOp::Exp) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
                    }
                    Op::Unary(arg, UnaryOp::Neg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.sub(&grad)?
                    }
                    Op::Unary(arg, UnaryOp::Recip) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let grad = (grad / arg.sqr()?)?;
                        *sum_grad = sum_grad.sub(&grad)?
                    }
                    &Op::Narrow(ref arg, dim, start_idx, len) => {
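                        // The gradient of a narrow is the incoming gradient padded back with
                        // zeros on both sides of the narrowed dimension.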
                        let arg_dims = arg.dims();
                        let left_pad = if start_idx == 0 {
                            None
                        } else {
                            let mut dims = arg_dims.to_vec();
                            dims[dim] = start_idx;
                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
                        };
                        let right_pad = arg_dims[dim] - start_idx - len;
                        let right_pad = if right_pad == 0 {
                            None
                        } else {
                            let mut dims = arg_dims.to_vec();
                            dims[dim] = right_pad;
                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
                        };
                        let arg_grad = match (left_pad, right_pad) {
                            (None, None) => grad,
                            (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
                            (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
                            (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
                        };
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Reduce(_, ReduceOp::ArgMin, _) => {}
                    Op::Reduce(_, ReduceOp::ArgMax, _) => {}
                    Op::Reshape(arg) => {
                        let arg_grad = grad.reshape(arg.dims())?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
                    Op::Unary(_, UnaryOp::Floor) => {
                        Err(Error::BackwardNotSupported { op: "floor" })?
                    }
                    Op::Unary(_, UnaryOp::Round) => {
                        Err(Error::BackwardNotSupported { op: "round" })?
                    }
                    Op::Unary(arg, UnaryOp::Gelu) => {
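                        // Derivative of the tanh-based gelu approximation
                        // gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))):
                        // with u = 0.797885 * x + 0.0356774 * x^3, the gradient is
                        // 0.5 * (1 + tanh(u)) + (0.398942 * x + 0.0535161 * x^3) * (1 - tanh(u)^2).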
                        let sum_grad = grads.or_insert(arg)?;
                        let cube = arg.powf(3.)?;
                        let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
                        let gelu_grad = (((0.5 * &tanh)?
                            + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
                            + 0.5)?;
                        *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
                    }
                    Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
                    Op::Unary(_, UnaryOp::GeluErf) => {
                        Err(Error::BackwardNotSupported { op: "gelu-erf" })?
                    }
                    Op::Unary(arg, UnaryOp::Relu) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                        *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                    }
                    Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                    Op::Powf(arg, e) => {
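                        // d/dx x^e = e * x^(e - 1).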
                        let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::CustomOp1(arg, c) => {
                        if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
                            let sum_grad = grads.or_insert(arg)?;
                            *sum_grad = sum_grad.add(&arg_grad)?
                        }
                    }
                    Op::CustomOp2(arg1, arg2, c) => {
                        let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
                        if let Some(arg_grad1) = arg_grad1 {
                            let sum_grad = grads.or_insert(arg1)?;
                            *sum_grad = sum_grad.add(&arg_grad1)?
                        }
                        if let Some(arg_grad2) = arg_grad2 {
                            let sum_grad = grads.or_insert(arg2)?;
                            *sum_grad = sum_grad.add(&arg_grad2)?
                        }
                    }
                    Op::CustomOp3(arg1, arg2, arg3, c) => {
                        let (arg_grad1, arg_grad2, arg_grad3) =
                            c.bwd(arg1, arg2, arg3, node, &grad)?;
                        if let Some(arg_grad1) = arg_grad1 {
                            let sum_grad = grads.or_insert(arg1)?;
                            *sum_grad = sum_grad.add(&arg_grad1)?
                        }
                        if let Some(arg_grad2) = arg_grad2 {
                            let sum_grad = grads.or_insert(arg2)?;
                            *sum_grad = sum_grad.add(&arg_grad2)?
                        }
                        if let Some(arg_grad3) = arg_grad3 {
                            let sum_grad = grads.or_insert(arg3)?;
                            *sum_grad = sum_grad.add(&arg_grad3)?
                        }
                    }
                    Op::Unary(arg, UnaryOp::Sqr) => {
                        let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(arg, UnaryOp::Sqrt) => {
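                        // d/dx sqrt(x) = 1 / (2 * sqrt(x)) = 0.5 / node.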
                        let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::ToDevice(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let arg_grad = grad.to_device(sum_grad.device())?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Transpose(arg, dim1, dim2) => {
                        let arg_grad = grad.transpose(*dim1, *dim2)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Permute(arg, dims) => {
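                        // Permutations are undone by permuting the gradient with the inverse
                        // permutation (inv_dims), which sends each dimension back to its original
                        // position.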
                        let mut inv_dims = vec![0; dims.len()];
                        for (i, &dim_idx) in dims.iter().enumerate() {
                            inv_dims[dim_idx] = i
                        }
                        let arg_grad = grad.permute(inv_dims)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                };
            }
        }
        Ok(grads)
    }
}

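/// A store mapping tensor ids to the gradients accumulated for the corresponding tensors during
/// the backward pass.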
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);

impl GradStore {
    fn new() -> Self {
        GradStore(HashMap::new())
    }

    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
        self.0.get(&id)
    }

    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
        self.0.get(&tensor.id())
    }

    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
        self.0.remove(&tensor.id())
    }

    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
        self.0.insert(tensor.id(), grad)
    }

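    /// Return a mutable reference to the gradient accumulated for `tensor`, inserting a zero
    /// tensor of the same shape if no gradient has been recorded yet.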
    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
        use std::collections::hash_map::Entry;
        let grad = match self.0.entry(tensor.id()) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => {
                let grad = tensor.zeros_like()?;
                entry.insert(grad)
            }
        };
        Ok(grad)
    }
}