candle_core/
backprop.rs

1//! Methods for backpropagation of gradients.
2use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
3use crate::{Error, Result, Tensor, TensorId};
4use std::collections::HashMap;
5
6// arg has been reduced to node via reduce_dims, expand it back to arg.
7// This has to handle keepdims.
8fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
9    if arg.rank() == node.rank() {
10        // keepdim = true
11        node.broadcast_as(arg.shape())
12    } else {
13        // keepdim = false
14        // first expand the reduced dims.
15        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
16    }
17}
18
thread_local! {
    /// When true, `backward` keeps each propagated gradient as-is instead of calling
    /// `.detach()` on it.
    ///
    /// Read from the `CANDLE_GRAD_DO_NOT_DETACH` environment variable the first time it is
    /// accessed on each thread. Any non-empty value other than `"0"` enables it. An unset, empty,
    /// `"0"` or non-Unicode value leaves it disabled.
    static CANDLE_GRAD_DO_NOT_DETACH: bool = std::env::var("CANDLE_GRAD_DO_NOT_DETACH")
        .map(|value| !value.is_empty() && value != "0")
        .unwrap_or(false);
}
29
30impl Tensor {
31    /// Return all the nodes that lead to this value in a topologically sorted vec, the first
32    /// elements having dependencies on the latter ones, e.g. the first element if any is the
33    /// argument.
34    /// This assumes that the op graph is a DAG.
35    fn sorted_nodes(&self) -> Vec<&Tensor> {
36        // The vec of sorted nodes is passed as an owned value rather than a mutable reference
37        // to get around some lifetime limitations.
38        fn walk<'a>(
39            node: &'a Tensor,
40            nodes: Vec<&'a Tensor>,
41            already_seen: &mut HashMap<TensorId, bool>,
42        ) -> (bool, Vec<&'a Tensor>) {
43            if let Some(&tg) = already_seen.get(&node.id()) {
44                return (tg, nodes);
45            }
46            let mut track_grad = false;
47            let mut nodes = if node.is_variable() {
48                // Do not call recursively on the "leaf" nodes.
49                track_grad = true;
50                nodes
51            } else if node.dtype().is_int() {
52                nodes
53            } else if let Some(op) = node.op() {
54                match op {
55                    Op::IndexAdd(t1, t2, t3, _)
56                    | Op::ScatterAdd(t1, t2, t3, _)
57                    | Op::CustomOp3(t1, t2, t3, _)
58                    | Op::WhereCond(t1, t2, t3) => {
59                        let (tg, nodes) = walk(t1, nodes, already_seen);
60                        track_grad |= tg;
61                        let (tg, nodes) = walk(t2, nodes, already_seen);
62                        track_grad |= tg;
63                        let (tg, nodes) = walk(t3, nodes, already_seen);
64                        track_grad |= tg;
65                        nodes
66                    }
67                    Op::Conv1D {
68                        arg: lhs,
69                        kernel: rhs,
70                        ..
71                    }
72                    | Op::ConvTranspose1D {
73                        arg: lhs,
74                        kernel: rhs,
75                        ..
76                    }
77                    | Op::Conv2D {
78                        arg: lhs,
79                        kernel: rhs,
80                        ..
81                    }
82                    | Op::ConvTranspose2D {
83                        arg: lhs,
84                        kernel: rhs,
85                        ..
86                    }
87                    | Op::CustomOp2(lhs, rhs, _)
88                    | Op::Binary(lhs, rhs, _)
89                    | Op::Gather(lhs, rhs, _)
90                    | Op::IndexSelect(lhs, rhs, _)
91                    | Op::Matmul(lhs, rhs)
92                    | Op::SliceScatter0(lhs, rhs, _) => {
93                        let (tg, nodes) = walk(lhs, nodes, already_seen);
94                        track_grad |= tg;
95                        let (tg, nodes) = walk(rhs, nodes, already_seen);
96                        track_grad |= tg;
97                        nodes
98                    }
99                    Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
100                        let (tg, nodes) = walk(arg, nodes, already_seen);
101                        track_grad |= tg;
102                        nodes
103                    }),
104                    Op::Affine { arg, mul, .. } => {
105                        if *mul == 0. {
106                            nodes
107                        } else {
108                            let (tg, nodes) = walk(arg, nodes, already_seen);
109                            track_grad |= tg;
110                            nodes
111                        }
112                    }
113                    Op::Unary(_node, UnaryOp::Ceil)
114                    | Op::Unary(_node, UnaryOp::Floor)
115                    | Op::Unary(_node, UnaryOp::Round)
116                    | Op::Unary(_node, UnaryOp::Sign) => nodes,
117                    Op::Reshape(node)
118                    | Op::UpsampleNearest1D { arg: node, .. }
119                    | Op::UpsampleNearest2D { arg: node, .. }
120                    | Op::AvgPool2D { arg: node, .. }
121                    | Op::MaxPool2D { arg: node, .. }
122                    | Op::Copy(node)
123                    | Op::Broadcast(node)
124                    | Op::Cmp(node, _)
125                    | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
126                    | Op::ToDevice(node)
127                    | Op::Transpose(node, _, _)
128                    | Op::Permute(node, _)
129                    | Op::Narrow(node, _, _, _)
130                    | Op::Unary(node, _)
131                    | Op::Elu(node, _)
132                    | Op::Powf(node, _)
133                    | Op::CustomOp1(node, _) => {
134                        let (tg, nodes) = walk(node, nodes, already_seen);
135                        track_grad |= tg;
136                        nodes
137                    }
138                    Op::ToDType(node) => {
139                        if node.dtype().is_float() {
140                            let (tg, nodes) = walk(node, nodes, already_seen);
141                            track_grad |= tg;
142                            nodes
143                        } else {
144                            nodes
145                        }
146                    }
147                    Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
148                }
149            } else {
150                nodes
151            };
152            already_seen.insert(node.id(), track_grad);
153            if track_grad {
154                nodes.push(node);
155            }
156            (track_grad, nodes)
157        }
158        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
159        nodes.reverse();
160        nodes
161    }
162
163    pub fn backward(&self) -> Result<GradStore> {
164        let sorted_nodes = self.sorted_nodes();
165        let mut grads = GradStore::new();
166        grads.insert(self, self.ones_like()?.contiguous()?);
167        for node in sorted_nodes.iter() {
168            if node.is_variable() {
169                continue;
170            }
171            let grad = grads
172                .remove(node)
173                .expect("candle internal error - grad not populated");
174            // https://github.com/huggingface/candle/issues/1241
175            // Ideally, we would make these operations in place where possible to ensure that we
176            // do not have to allocate too often. Here we just call `.detach` to avoid computing
177            // the backprop graph of the backprop itself. This would be an issue for second order
178            // derivatives but these are out of scope at the moment.
179            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
180            let grad = if do_not_detach { grad } else { grad.detach() };
181            if let Some(op) = node.op() {
182                match op {
183                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
184                        let lhs_sum_grad = grads.or_insert(lhs)?;
185                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
186                        let rhs_sum_grad = grads.or_insert(rhs)?;
187                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
188                    }
189                    Op::Binary(lhs, rhs, BinaryOp::Sub) => {
190                        let lhs_sum_grad = grads.or_insert(lhs)?;
191                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
192                        let rhs_sum_grad = grads.or_insert(rhs)?;
193                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
194                    }
195                    Op::Binary(lhs, rhs, BinaryOp::Mul) => {
196                        let lhs_grad = grad.mul(rhs)?;
197                        let lhs_sum_grad = grads.or_insert(lhs)?;
198                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
199                        let rhs_grad = grad.mul(lhs)?;
200                        let rhs_sum_grad = grads.or_insert(rhs)?;
201                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
202                    }
203                    Op::Binary(lhs, rhs, BinaryOp::Div) => {
204                        let lhs_grad = grad.div(rhs)?;
205                        let lhs_sum_grad = grads.or_insert(lhs)?;
206                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
207                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
208                        let rhs_sum_grad = grads.or_insert(rhs)?;
209                        *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
210                    }
211                    Op::Binary(lhs, rhs, BinaryOp::Minimum)
212                    | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
213                        let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
214                        let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
215
216                        // If both masks are 1 one the same point, we want to scale the
217                        // gradient by 0.5 rather than 1.
218                        let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
219                        let lhs_sum_grad = grads.or_insert(lhs)?;
220                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
221
222                        let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
223                        let rhs_sum_grad = grads.or_insert(rhs)?;
224                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
225                    }
226                    Op::WhereCond(pred, t, f) => {
227                        let zeros = grad.zeros_like()?;
228                        let t_sum_grad = grads.or_insert(t)?;
229                        let t_grad = pred.where_cond(&grad, &zeros)?;
230                        *t_sum_grad = t_sum_grad.add(&t_grad)?;
231                        let f_sum_grad = grads.or_insert(f)?;
232                        let f_grad = pred.where_cond(&zeros, &grad)?;
233                        *f_sum_grad = f_sum_grad.add(&f_grad)?;
234                    }
235                    Op::Conv1D {
236                        arg,
237                        kernel,
238                        padding,
239                        stride,
240                        dilation,
241                    } => {
242                        // The output height for conv_transpose1d is:
243                        // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
244                        let grad_l_in = grad.dim(2)?;
245                        let k_size = kernel.dim(2)?;
246                        let out_size =
247                            (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
248                        let out_padding = arg.dim(2)? - out_size;
249                        let grad_arg = grad.conv_transpose1d(
250                            kernel,
251                            *padding,
252                            out_padding,
253                            *stride,
254                            *dilation,
255                            /* groups */ 1,
256                        )?;
257                        let sum_grad = grads.or_insert(arg)?;
258                        *sum_grad = sum_grad.add(&grad_arg)?;
259
260                        let grad_kernel = arg
261                            .transpose(0, 1)?
262                            .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
263                            .transpose(0, 1)?;
264                        let sum_grad = grads.or_insert(kernel)?;
265                        let (_, _, k0) = kernel.dims3()?;
266                        let (_, _, g_k0) = grad_kernel.dims3()?;
267                        let grad_kernel = if g_k0 != k0 {
268                            grad_kernel.narrow(2, 0, k0)?
269                        } else {
270                            grad_kernel
271                        };
272                        *sum_grad = sum_grad.add(&grad_kernel)?;
273                    }
274                    Op::Conv2D {
275                        arg,
276                        kernel,
277                        padding,
278                        stride,
279                        dilation,
280                    } => {
281                        // The output height for conv_transpose2d is:
282                        // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
283                        let grad_h = grad.dim(2)?;
284                        let k_h = kernel.dim(2)?;
285                        let out_size =
286                            (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
287                        let out_padding = arg.dim(2)? - out_size;
288                        let grad_arg = grad.conv_transpose2d(
289                            kernel,
290                            *padding,
291                            out_padding,
292                            *stride,
293                            *dilation,
294                        )?;
295                        let sum_grad = grads.or_insert(arg)?;
296                        *sum_grad = sum_grad.add(&grad_arg)?;
297
298                        let grad_kernel = arg
299                            .transpose(0, 1)?
300                            .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
301                            .transpose(0, 1)?;
302                        let sum_grad = grads.or_insert(kernel)?;
303                        let (_, _, k0, k1) = kernel.dims4()?;
304                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
305                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
306                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
307                        } else {
308                            grad_kernel
309                        };
310                        *sum_grad = sum_grad.add(&grad_kernel)?;
311                    }
312                    Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
313                        op: "conv-transpose1d",
314                    })?,
315                    Op::ConvTranspose2D {
316                        arg,
317                        kernel,
318                        padding,
319                        stride,
320                        dilation,
321                        output_padding: _output_padding,
322                    } => {
323                        let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
324                        let sum_grad = grads.or_insert(arg)?;
325                        *sum_grad = sum_grad.add(&grad_arg)?;
326
327                        let grad_kernel = grad
328                            .transpose(0, 1)?
329                            .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
330                            .transpose(0, 1)?;
331                        let sum_grad = grads.or_insert(kernel)?;
332                        let (_, _, k0, k1) = kernel.dims4()?;
333                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
334                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
335                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
336                        } else {
337                            grad_kernel
338                        };
339                        *sum_grad = sum_grad.add(&grad_kernel)?;
340                    }
341                    Op::AvgPool2D {
342                        arg,
343                        kernel_size,
344                        stride,
345                    } => {
346                        if kernel_size != stride {
347                            crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
348                        }
349                        let (_n, _c, h, w) = arg.dims4()?;
350                        let grad_arg = grad.upsample_nearest2d(h, w)?;
351                        let grad_arg =
352                            (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
353                        let sum_grad = grads.or_insert(arg)?;
354                        *sum_grad = sum_grad.add(&grad_arg)?;
355                    }
356                    Op::MaxPool2D {
357                        arg,
358                        kernel_size,
359                        stride,
360                    } => {
361                        if kernel_size != stride {
362                            crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
363                        }
364                        let (_n, _c, h, w) = arg.dims4()?;
365                        // For computing the max-pool gradient, we compute a mask where a 1 means
366                        // that the element is the maximum, then we apply this mask to the
367                        // upsampled gradient (taking into account that multiple max may exist so
368                        // we scale the gradient for this case).
369                        let node_upsampled = node.upsample_nearest2d(h, w)?;
370                        let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
371                        let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
372                        let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
373                        let sum_grad = grads.or_insert(arg)?;
374                        *sum_grad = sum_grad.add(&grad_arg)?;
375                    }
376                    Op::UpsampleNearest1D { arg, target_size } => {
377                        let (_n, c, size) = arg.dims3()?;
378                        if target_size % size != 0 {
379                            crate::bail!("backward not supported for non integer upscaling factors")
380                        }
381                        let scale = target_size / size;
382
383                        let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
384                        let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
385                        let sum_grad = grads.or_insert(arg)?;
386                        *sum_grad = conv_sum;
387                    }
388                    Op::UpsampleNearest2D {
389                        arg,
390                        target_h,
391                        target_w,
392                    } => {
393                        let (_n, c, h, w) = arg.dims4()?;
394                        if target_h % h != 0 || target_w % w != 0 {
395                            crate::bail!("backward not supported for non integer upscaling factors")
396                        }
397                        let scale_h = target_h / h;
398                        let scale_w = target_w / w;
399
400                        if scale_h != scale_w {
401                            crate::bail!("backward not supported for non uniform upscaling factors")
402                        };
403                        let kernel =
404                            Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
405                        let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
406                        let sum_grad = grads.or_insert(arg)?;
407                        *sum_grad = conv_sum;
408                    }
409                    Op::SliceScatter0(lhs, rhs, start_rhs) => {
410                        let rhs_sum_grad = grads.or_insert(rhs)?;
411                        let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
412                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
413
414                        let lhs_sum_grad = grads.or_insert(lhs)?;
415                        let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
416                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
417                    }
418                    Op::Gather(arg, indexes, dim) => {
419                        let sum_grad = grads.or_insert(arg)?;
420                        *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
421                    }
422                    Op::ScatterAdd(init, indexes, src, dim) => {
423                        let init_sum_grad = grads.or_insert(init)?;
424                        *init_sum_grad = init_sum_grad.add(&grad)?;
425
426                        let src_grad = grad.gather(indexes, *dim)?;
427                        let src_sum_grad = grads.or_insert(src)?;
428                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
429                    }
430                    Op::IndexAdd(init, indexes, src, dim) => {
431                        let init_sum_grad = grads.or_insert(init)?;
432                        *init_sum_grad = init_sum_grad.add(&grad)?;
433
434                        let src_grad = grad.index_select(indexes, *dim)?;
435                        let src_sum_grad = grads.or_insert(src)?;
436                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
437                    }
438                    Op::IndexSelect(arg, indexes, dim) => {
439                        let sum_grad = grads.or_insert(arg)?;
440                        *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
441                    }
442                    Op::Matmul(lhs, rhs) => {
443                        // Skipping checks, the op went ok, we can skip
444                        // the matmul size checks for now.
445
446                        let lhs_grad = grad.matmul(&rhs.t()?)?;
447                        let lhs_sum_grad = grads.or_insert(lhs)?;
448                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
449
450                        let rhs_grad = lhs.t()?.matmul(&grad)?;
451                        let rhs_sum_grad = grads.or_insert(rhs)?;
452                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
453                    }
454                    Op::Cat(args, dim) => {
455                        let mut start_idx = 0;
456                        for arg in args {
457                            let len = arg.dims()[*dim];
458                            let arg_grad = grad.narrow(*dim, start_idx, len)?;
459                            let sum_grad = grads.or_insert(arg)?;
460                            *sum_grad = sum_grad.add(&arg_grad)?;
461                            start_idx += len;
462                        }
463                    }
464                    Op::Broadcast(arg) => {
465                        let arg_dims = arg.dims();
466                        let node_dims = node.dims();
467                        // The number of dims that have been inserted on the left.
468                        let left_dims = node_dims.len() - arg_dims.len();
469                        let mut sum_dims: Vec<usize> = (0..left_dims).collect();
470                        for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
471                            .iter()
472                            .zip(arg_dims.iter())
473                            .enumerate()
474                        {
475                            if node_dim != arg_dim {
476                                sum_dims.push(dim + left_dims)
477                            }
478                        }
479
480                        let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
481                        for _i in 0..left_dims {
482                            arg_grad = arg_grad.squeeze(0)?
483                        }
484                        let sum_grad = grads.or_insert(arg)?;
485                        *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
486                    }
487                    Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
488                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
489                        let sum_grad = grads.or_insert(arg)?;
490                        *sum_grad = sum_grad.add(&grad)?;
491                    }
492                    Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
493                        let node = broadcast_back(arg, node, reduced_dims)?;
494                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
495                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
496                        let sum_grad = grads.or_insert(arg)?;
497                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
498                    }
499                    Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
500                        let node = broadcast_back(arg, node, reduced_dims)?;
501                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
502                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
503                        let sum_grad = grads.or_insert(arg)?;
504                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
505                    }
506                    Op::ToDType(arg) => {
507                        let sum_grad = grads.or_insert(arg)?;
508                        *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
509                    }
510                    Op::Copy(arg) => {
511                        let sum_grad = grads.or_insert(arg)?;
512                        *sum_grad = sum_grad.add(&grad)?
513                    }
514                    Op::Affine { arg, mul, .. } => {
515                        let arg_grad = grad.affine(*mul, 0.)?;
516                        let sum_grad = grads.or_insert(arg)?;
517                        *sum_grad = sum_grad.add(&arg_grad)?
518                    }
519                    Op::Unary(arg, UnaryOp::Log) => {
520                        let sum_grad = grads.or_insert(arg)?;
521                        *sum_grad = sum_grad.add(&(grad / arg)?)?
522                    }
523                    Op::Unary(arg, UnaryOp::Sin) => {
524                        let sum_grad = grads.or_insert(arg)?;
525                        *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
526                    }
527                    Op::Unary(arg, UnaryOp::Cos) => {
528                        let sum_grad = grads.or_insert(arg)?;
529                        *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
530                    }
531                    Op::Unary(arg, UnaryOp::Tanh) => {
532                        let sum_grad = grads.or_insert(arg)?;
533                        let minus_dtanh = (node.sqr()? - 1.)?;
534                        *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
535                    }
536                    Op::Unary(arg, UnaryOp::Abs) => {
537                        let sum_grad = grads.or_insert(arg)?;
538                        let ones = arg.ones_like()?;
539                        let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
540                        *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
541                    }
542                    Op::Unary(arg, UnaryOp::Exp) => {
543                        let sum_grad = grads.or_insert(arg)?;
544                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
545                    }
546                    Op::Unary(arg, UnaryOp::Neg) => {
547                        let sum_grad = grads.or_insert(arg)?;
548                        *sum_grad = sum_grad.sub(&grad)?
549                    }
550                    Op::Unary(arg, UnaryOp::Recip) => {
551                        let sum_grad = grads.or_insert(arg)?;
552                        let grad = (grad / arg.sqr()?)?;
553                        *sum_grad = sum_grad.sub(&grad)?
554                    }
555                    &Op::Narrow(ref arg, dim, start_idx, len) => {
556                        let arg_dims = arg.dims();
557                        let left_pad = if start_idx == 0 {
558                            None
559                        } else {
560                            let mut dims = arg_dims.to_vec();
561                            dims[dim] = start_idx;
562                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
563                        };
564                        let right_pad = arg_dims[dim] - start_idx - len;
565                        let right_pad = if right_pad == 0 {
566                            None
567                        } else {
568                            let mut dims = arg_dims.to_vec();
569                            dims[dim] = right_pad;
570                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
571                        };
572                        let arg_grad = match (left_pad, right_pad) {
573                            (None, None) => grad,
574                            (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
575                            (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
576                            (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
577                        };
578                        let sum_grad = grads.or_insert(arg)?;
579                        *sum_grad = sum_grad.add(&arg_grad)?
580                    }
581                    Op::Unary(_, UnaryOp::Floor)
582                    | Op::Unary(_, UnaryOp::Round)
583                    | Op::Reduce(_, ReduceOp::ArgMin, _)
584                    | Op::Reduce(_, ReduceOp::ArgMax, _)
585                    | Op::Unary(_, UnaryOp::Sign)
586                    | Op::Cmp(_, _) => {}
587                    Op::Reshape(arg) => {
588                        let arg_grad = grad.reshape(arg.dims())?;
589                        let sum_grad = grads.or_insert(arg)?;
590                        *sum_grad = sum_grad.add(&arg_grad)?
591                    }
592                    Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
593                    Op::Unary(arg, UnaryOp::Gelu) => {
594                        let sum_grad = grads.or_insert(arg)?;
595                        let cube = arg.powf(3.)?;
596                        let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
597                        let gelu_grad = (((0.5 * &tanh)?
598                            + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
599                            + 0.5)?;
600                        *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
601                    }
602                    Op::Unary(arg, UnaryOp::Erf) => {
603                        let sum_grad = grads.or_insert(arg)?;
604                        // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
605                        let erf_grad =
606                            (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
607                        *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
608                    }
609                    Op::Unary(arg, UnaryOp::GeluErf) => {
610                        let sum_grad = grads.or_insert(arg)?;
611                        // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
612                        let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
613                        let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
614                        let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
615                        let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
616                        let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
617                        *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
618                    }
619                    Op::Unary(arg, UnaryOp::Relu) => {
620                        let sum_grad = grads.or_insert(arg)?;
621                        let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
622                        *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
623                    }
624                    Op::Unary(arg, UnaryOp::Silu) => {
625                        let sum_grad = grads.or_insert(arg)?;
626                        // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
627                        let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
628                        let silu_grad = &sigmoid_arg * (1. - *node) + *node;
629                        *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
630                    }
631                    Op::Elu(arg, alpha) => {
632                        // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
633                        let sum_grad = grads.or_insert(arg)?;
634                        let zeros = arg.zeros_like()?;
635                        let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
636                        let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
637                        // node == alpha * (e^x - 1) for x <= 0, reuse it
638                        let negative_exp_mask = (negative_mask * (*node + *alpha))?;
639                        let combined_mask = (positive_mask + negative_exp_mask)?;
640                        *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
641                    }
642                    Op::Powf(arg, e) => {
643                        let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
644                        let sum_grad = grads.or_insert(arg)?;
645                        *sum_grad = sum_grad.add(&arg_grad)?
646                    }
647                    Op::CustomOp1(arg, c) => {
648                        if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
649                            let sum_grad = grads.or_insert(arg)?;
650                            *sum_grad = sum_grad.add(&arg_grad)?
651                        }
652                    }
653                    Op::CustomOp2(arg1, arg2, c) => {
654                        let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
655                        if let Some(arg_grad1) = arg_grad1 {
656                            let sum_grad = grads.or_insert(arg1)?;
657                            *sum_grad = sum_grad.add(&arg_grad1)?
658                        }
659                        if let Some(arg_grad2) = arg_grad2 {
660                            let sum_grad = grads.or_insert(arg2)?;
661                            *sum_grad = sum_grad.add(&arg_grad2)?
662                        }
663                    }
664                    Op::CustomOp3(arg1, arg2, arg3, c) => {
665                        let (arg_grad1, arg_grad2, arg_grad3) =
666                            c.bwd(arg1, arg2, arg3, node, &grad)?;
667                        if let Some(arg_grad1) = arg_grad1 {
668                            let sum_grad = grads.or_insert(arg1)?;
669                            *sum_grad = sum_grad.add(&arg_grad1)?
670                        }
671                        if let Some(arg_grad2) = arg_grad2 {
672                            let sum_grad = grads.or_insert(arg2)?;
673                            *sum_grad = sum_grad.add(&arg_grad2)?
674                        }
675                        if let Some(arg_grad3) = arg_grad3 {
676                            let sum_grad = grads.or_insert(arg3)?;
677                            *sum_grad = sum_grad.add(&arg_grad3)?
678                        }
679                    }
680                    Op::Unary(arg, UnaryOp::Sqr) => {
681                        let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
682                        let sum_grad = grads.or_insert(arg)?;
683                        *sum_grad = sum_grad.add(&arg_grad)?
684                    }
685                    Op::Unary(arg, UnaryOp::Sqrt) => {
686                        let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
687                        let sum_grad = grads.or_insert(arg)?;
688                        *sum_grad = sum_grad.add(&arg_grad)?
689                    }
690                    Op::ToDevice(arg) => {
691                        let sum_grad = grads.or_insert(arg)?;
692                        let arg_grad = grad.to_device(sum_grad.device())?;
693                        *sum_grad = sum_grad.add(&arg_grad)?
694                    }
695                    Op::Transpose(arg, dim1, dim2) => {
696                        let arg_grad = grad.transpose(*dim1, *dim2)?;
697                        let sum_grad = grads.or_insert(arg)?;
698                        *sum_grad = sum_grad.add(&arg_grad)?
699                    }
700                    Op::Permute(arg, dims) => {
701                        let mut inv_dims = vec![0; dims.len()];
702                        for (i, &dim_idx) in dims.iter().enumerate() {
703                            inv_dims[dim_idx] = i
704                        }
705                        let arg_grad = grad.permute(inv_dims)?;
706                        let sum_grad = grads.or_insert(arg)?;
707                        *sum_grad = sum_grad.add(&arg_grad)?
708                    }
709                };
710            }
711        }
712        Ok(grads)
713    }
714}
715
716/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
717#[derive(Debug)]
718pub struct GradStore(HashMap<TensorId, Tensor>);
719
720impl GradStore {
721    /// Create a new gradient store
722    fn new() -> Self {
723        GradStore(HashMap::new())
724    }
725
726    /// Get the gradient tensor corresponding to the given tensor id
727    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
728        self.0.get(&id)
729    }
730
731    /// Get the gradient tensor associated with the given tensor
732    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
733        self.0.get(&tensor.id())
734    }
735
736    /// Remove the gradient tensor associated with the given tensor, returning it if it exists
737    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
738        self.0.remove(&tensor.id())
739    }
740
741    /// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
742    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
743        self.0.insert(tensor.id(), grad)
744    }
745
746    /// Get the gradient tensor associated with the given tensor, or, if it does not exist,
747    /// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
748    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
749        use std::collections::hash_map::Entry;
750        let grad = match self.0.entry(tensor.id()) {
751            Entry::Occupied(entry) => entry.into_mut(),
752            Entry::Vacant(entry) => {
753                let grad = tensor.zeros_like()?;
754                entry.insert(grad)
755            }
756        };
757        Ok(grad)
758    }
759
760    /// Get the tensor ids of the stored gradient tensors
761    pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
762        self.0.keys()
763    }
764}