candle_core/
backprop.rs

1//! Methods for backpropagation of gradients.
2use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
3use crate::{Error, Result, Tensor, TensorId};
4use std::collections::HashMap;
5
6// arg has been reduced to node via reduce_dims, expand it back to arg.
7// This has to handle keepdims.
8fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
9    if arg.rank() == node.rank() {
10        // keepdim = true
11        node.broadcast_as(arg.shape())
12    } else {
13        // keepdim = false
14        // first expand the reduced dims.
15        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
16    }
17}
18
thread_local! {
    /// Whether the `CANDLE_GRAD_DO_NOT_DETACH` environment variable asks `backward` to keep
    /// gradients attached. The flag is `true` only when the variable can be read and holds a
    /// non-empty value other than `"0"`; an unset or unreadable variable yields `false`.
    static CANDLE_GRAD_DO_NOT_DETACH: bool = std::env::var("CANDLE_GRAD_DO_NOT_DETACH")
        .map(|value| !(value.is_empty() || value == "0"))
        .unwrap_or(false);
}
29
30impl Tensor {
31    /// Return all the nodes that lead to this value in a topologically sorted vec, the first
32    /// elements having dependencies on the latter ones, e.g. the first element if any is the
33    /// argument.
34    /// This assumes that the op graph is a DAG.
35    pub fn sorted_nodes(&self) -> Vec<&Tensor> {
36        // The vec of sorted nodes is passed as an owned value rather than a mutable reference
37        // to get around some lifetime limitations.
38        fn walk<'a>(
39            node: &'a Tensor,
40            nodes: Vec<&'a Tensor>,
41            already_seen: &mut HashMap<TensorId, bool>,
42        ) -> (bool, Vec<&'a Tensor>) {
43            if let Some(&tg) = already_seen.get(&node.id()) {
44                return (tg, nodes);
45            }
46            let mut track_grad = false;
47            let mut nodes = if node.is_variable() {
48                // Do not call recursively on the "leaf" nodes.
49                track_grad = true;
50                nodes
51            } else if node.dtype().is_int() {
52                nodes
53            } else if let Some(op) = node.op() {
54                match op {
55                    Op::IndexAdd(t1, t2, t3, _)
56                    | Op::Scatter(t1, t2, t3, _)
57                    | Op::ScatterAdd(t1, t2, t3, _)
58                    | Op::CustomOp3(t1, t2, t3, _)
59                    | Op::WhereCond(t1, t2, t3) => {
60                        let (tg, nodes) = walk(t1, nodes, already_seen);
61                        track_grad |= tg;
62                        let (tg, nodes) = walk(t2, nodes, already_seen);
63                        track_grad |= tg;
64                        let (tg, nodes) = walk(t3, nodes, already_seen);
65                        track_grad |= tg;
66                        nodes
67                    }
68                    Op::Conv1D {
69                        arg: lhs,
70                        kernel: rhs,
71                        ..
72                    }
73                    | Op::ConvTranspose1D {
74                        arg: lhs,
75                        kernel: rhs,
76                        ..
77                    }
78                    | Op::Conv2D {
79                        arg: lhs,
80                        kernel: rhs,
81                        ..
82                    }
83                    | Op::ConvTranspose2D {
84                        arg: lhs,
85                        kernel: rhs,
86                        ..
87                    }
88                    | Op::CustomOp2(lhs, rhs, _)
89                    | Op::Binary(lhs, rhs, _)
90                    | Op::Gather(lhs, rhs, _)
91                    | Op::IndexSelect(lhs, rhs, _)
92                    | Op::Matmul(lhs, rhs)
93                    | Op::SliceScatter0(lhs, rhs, _) => {
94                        let (tg, nodes) = walk(lhs, nodes, already_seen);
95                        track_grad |= tg;
96                        let (tg, nodes) = walk(rhs, nodes, already_seen);
97                        track_grad |= tg;
98                        nodes
99                    }
100                    Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
101                        let (tg, nodes) = walk(arg, nodes, already_seen);
102                        track_grad |= tg;
103                        nodes
104                    }),
105                    Op::Affine { arg, mul, .. } => {
106                        if *mul == 0. {
107                            nodes
108                        } else {
109                            let (tg, nodes) = walk(arg, nodes, already_seen);
110                            track_grad |= tg;
111                            nodes
112                        }
113                    }
114                    Op::Unary(_node, UnaryOp::Ceil)
115                    | Op::Unary(_node, UnaryOp::Floor)
116                    | Op::Unary(_node, UnaryOp::Round)
117                    | Op::Unary(_node, UnaryOp::Sign) => nodes,
118                    Op::Reshape(node)
119                    | Op::UpsampleNearest1D { arg: node, .. }
120                    | Op::UpsampleNearest2D { arg: node, .. }
121                    | Op::UpsampleBilinear2D { arg: node, .. }
122                    | Op::AvgPool2D { arg: node, .. }
123                    | Op::MaxPool2D { arg: node, .. }
124                    | Op::Copy(node)
125                    | Op::Broadcast(node)
126                    | Op::Cmp(node, _)
127                    | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
128                    | Op::ToDevice(node)
129                    | Op::Transpose(node, _, _)
130                    | Op::Permute(node, _)
131                    | Op::Narrow(node, _, _, _)
132                    | Op::Unary(node, _)
133                    | Op::Elu(node, _)
134                    | Op::Powf(node, _)
135                    | Op::CustomOp1(node, _) => {
136                        let (tg, nodes) = walk(node, nodes, already_seen);
137                        track_grad |= tg;
138                        nodes
139                    }
140                    Op::ToDType(node) => {
141                        if node.dtype().is_float() {
142                            let (tg, nodes) = walk(node, nodes, already_seen);
143                            track_grad |= tg;
144                            nodes
145                        } else {
146                            nodes
147                        }
148                    }
149                    Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
150                }
151            } else {
152                nodes
153            };
154            already_seen.insert(node.id(), track_grad);
155            if track_grad {
156                nodes.push(node);
157            }
158            (track_grad, nodes)
159        }
160        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
161        nodes.reverse();
162        nodes
163    }
164
165    pub fn backward(&self) -> Result<GradStore> {
166        let sorted_nodes = self.sorted_nodes();
167        let mut grads = GradStore::new();
168        grads.insert(self, self.ones_like()?.contiguous()?);
169        for node in sorted_nodes.iter() {
170            if node.is_variable() {
171                continue;
172            }
173            let grad = grads
174                .remove(node)
175                .expect("candle internal error - grad not populated");
176            // https://github.com/huggingface/candle/issues/1241
177            // Ideally, we would make these operations in place where possible to ensure that we
178            // do not have to allocate too often. Here we just call `.detach` to avoid computing
179            // the backprop graph of the backprop itself. This would be an issue for second order
180            // derivatives but these are out of scope at the moment.
181            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
182            let grad = if do_not_detach { grad } else { grad.detach() };
183            if let Some(op) = node.op() {
184                match op {
185                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
186                        let lhs_sum_grad = grads.or_insert(lhs)?;
187                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
188                        let rhs_sum_grad = grads.or_insert(rhs)?;
189                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
190                    }
191                    Op::Binary(lhs, rhs, BinaryOp::Sub) => {
192                        let lhs_sum_grad = grads.or_insert(lhs)?;
193                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
194                        let rhs_sum_grad = grads.or_insert(rhs)?;
195                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
196                    }
197                    Op::Binary(lhs, rhs, BinaryOp::Mul) => {
198                        let lhs_grad = grad.mul(rhs)?;
199                        let lhs_sum_grad = grads.or_insert(lhs)?;
200                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
201                        let rhs_grad = grad.mul(lhs)?;
202                        let rhs_sum_grad = grads.or_insert(rhs)?;
203                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
204                    }
205                    Op::Binary(lhs, rhs, BinaryOp::Div) => {
206                        let lhs_grad = grad.div(rhs)?;
207                        let lhs_sum_grad = grads.or_insert(lhs)?;
208                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
209                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
210                        let rhs_sum_grad = grads.or_insert(rhs)?;
211                        *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
212                    }
213                    Op::Binary(lhs, rhs, BinaryOp::Minimum)
214                    | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
215                        let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
216                        let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
217
                        // If both masks are 1 at the same point, we want to scale the
                        // gradient by 0.5 rather than 1.
220                        let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
221                        let lhs_sum_grad = grads.or_insert(lhs)?;
222                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
223
224                        let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
225                        let rhs_sum_grad = grads.or_insert(rhs)?;
226                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
227                    }
228                    Op::WhereCond(pred, t, f) => {
229                        let zeros = grad.zeros_like()?;
230                        let t_sum_grad = grads.or_insert(t)?;
231                        let t_grad = pred.where_cond(&grad, &zeros)?;
232                        *t_sum_grad = t_sum_grad.add(&t_grad)?;
233                        let f_sum_grad = grads.or_insert(f)?;
234                        let f_grad = pred.where_cond(&zeros, &grad)?;
235                        *f_sum_grad = f_sum_grad.add(&f_grad)?;
236                    }
237                    Op::Conv1D {
238                        arg,
239                        kernel,
240                        padding,
241                        stride,
242                        dilation,
243                    } => {
244                        // The output height for conv_transpose1d is:
245                        // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
246                        let grad_l_in = grad.dim(2)?;
247                        let k_size = kernel.dim(2)?;
248                        let out_size =
249                            (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
250                        let out_padding = arg.dim(2)? - out_size;
251                        let grad_arg = grad.conv_transpose1d(
252                            kernel,
253                            *padding,
254                            out_padding,
255                            *stride,
256                            *dilation,
257                            /* groups */ 1,
258                        )?;
259                        let sum_grad = grads.or_insert(arg)?;
260                        *sum_grad = sum_grad.add(&grad_arg)?;
261
262                        let grad_kernel = arg
263                            .transpose(0, 1)?
264                            .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
265                            .transpose(0, 1)?;
266                        let sum_grad = grads.or_insert(kernel)?;
267                        let (_, _, k0) = kernel.dims3()?;
268                        let (_, _, g_k0) = grad_kernel.dims3()?;
269                        let grad_kernel = if g_k0 != k0 {
270                            grad_kernel.narrow(2, 0, k0)?
271                        } else {
272                            grad_kernel
273                        };
274                        *sum_grad = sum_grad.add(&grad_kernel)?;
275                    }
276                    Op::Conv2D {
277                        arg,
278                        kernel,
279                        padding,
280                        stride,
281                        dilation,
282                    } => {
283                        // The output height for conv_transpose2d is:
284                        // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
285                        let grad_h = grad.dim(2)?;
286                        let k_h = kernel.dim(2)?;
287                        let out_size =
288                            (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
289                        let out_padding = arg.dim(2)? - out_size;
290                        let grad_arg = grad.conv_transpose2d(
291                            kernel,
292                            *padding,
293                            out_padding,
294                            *stride,
295                            *dilation,
296                        )?;
297                        let sum_grad = grads.or_insert(arg)?;
298                        *sum_grad = sum_grad.add(&grad_arg)?;
299
300                        let grad_kernel = arg
301                            .transpose(0, 1)?
302                            .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
303                            .transpose(0, 1)?;
304                        let sum_grad = grads.or_insert(kernel)?;
305                        let (_, _, k0, k1) = kernel.dims4()?;
306                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
307                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
308                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
309                        } else {
310                            grad_kernel
311                        };
312                        *sum_grad = sum_grad.add(&grad_kernel)?;
313                    }
314                    Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
315                        op: "conv-transpose1d",
316                    })?,
317                    Op::ConvTranspose2D {
318                        arg,
319                        kernel,
320                        padding,
321                        stride,
322                        dilation,
323                        output_padding: _output_padding,
324                    } => {
325                        let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
326                        let sum_grad = grads.or_insert(arg)?;
327                        *sum_grad = sum_grad.add(&grad_arg)?;
328
329                        let grad_kernel = grad
330                            .transpose(0, 1)?
331                            .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
332                            .transpose(0, 1)?;
333                        let sum_grad = grads.or_insert(kernel)?;
334                        let (_, _, k0, k1) = kernel.dims4()?;
335                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
336                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
337                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
338                        } else {
339                            grad_kernel
340                        };
341                        *sum_grad = sum_grad.add(&grad_kernel)?;
342                    }
343                    Op::AvgPool2D {
344                        arg,
345                        kernel_size,
346                        stride,
347                    } => {
348                        if kernel_size != stride {
349                            crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
350                        }
351                        let (_n, _c, h, w) = arg.dims4()?;
352                        let grad_arg = grad.upsample_nearest2d(h, w)?;
353                        let grad_arg =
354                            (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
355                        let sum_grad = grads.or_insert(arg)?;
356                        *sum_grad = sum_grad.add(&grad_arg)?;
357                    }
358                    Op::MaxPool2D {
359                        arg,
360                        kernel_size,
361                        stride,
362                    } => {
363                        if kernel_size != stride {
364                            crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
365                        }
366                        let (_n, _c, h, w) = arg.dims4()?;
367                        // For computing the max-pool gradient, we compute a mask where a 1 means
368                        // that the element is the maximum, then we apply this mask to the
369                        // upsampled gradient (taking into account that multiple max may exist so
370                        // we scale the gradient for this case).
371                        let node_upsampled = node.upsample_nearest2d(h, w)?;
372                        let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
373                        let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
374                        let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
375                        let sum_grad = grads.or_insert(arg)?;
376                        *sum_grad = sum_grad.add(&grad_arg)?;
377                    }
378                    Op::UpsampleNearest1D { arg, target_size } => {
379                        let (_n, c, size) = arg.dims3()?;
380                        if target_size % size != 0 {
381                            crate::bail!("backward not supported for non integer upscaling factors")
382                        }
383                        let scale = target_size / size;
384
385                        let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
386                        let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
387                        let sum_grad = grads.or_insert(arg)?;
388                        *sum_grad = conv_sum;
389                    }
390                    Op::UpsampleNearest2D {
391                        arg,
392                        target_h,
393                        target_w,
394                    } => {
395                        let (_n, c, h, w) = arg.dims4()?;
396                        if target_h % h != 0 || target_w % w != 0 {
397                            crate::bail!("backward not supported for non integer upscaling factors")
398                        }
399                        let scale_h = target_h / h;
400                        let scale_w = target_w / w;
401
402                        if scale_h != scale_w {
403                            crate::bail!("backward not supported for non uniform upscaling factors")
404                        };
405                        let kernel =
406                            Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
407                        let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
408                        let sum_grad = grads.or_insert(arg)?;
409                        *sum_grad = conv_sum;
410                    }
411                    Op::UpsampleBilinear2D { .. } => {
412                        crate::bail!("backward not supported for upsample_bilinear2d")
413                    }
414                    Op::SliceScatter0(lhs, rhs, start_rhs) => {
415                        let rhs_sum_grad = grads.or_insert(rhs)?;
416                        let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
417                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
418
419                        let lhs_sum_grad = grads.or_insert(lhs)?;
420                        let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
421                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
422                    }
423                    Op::Gather(arg, indexes, dim) => {
424                        let sum_grad = grads.or_insert(arg)?;
425                        *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
426                    }
427                    Op::Scatter(init, indexes, src, dim) => {
428                        let init_sum_grad = grads.or_insert(init)?;
429                        *init_sum_grad = init_sum_grad.add(&grad)?;
430
431                        let src_grad = grad.gather(indexes, *dim)?;
432                        let src_sum_grad = grads.or_insert(src)?;
433                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
434                    }
435                    Op::ScatterAdd(init, indexes, src, dim) => {
436                        let init_sum_grad = grads.or_insert(init)?;
437                        let mask = init.ones_like()?;
438                        let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
439                        *init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
440
441                        let src_grad = grad.gather(indexes, *dim)?;
442                        let src_sum_grad = grads.or_insert(src)?;
443                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
444                    }
445                    Op::IndexAdd(init, indexes, src, dim) => {
446                        let init_sum_grad = grads.or_insert(init)?;
447                        *init_sum_grad = init_sum_grad.add(&grad)?;
448
449                        let src_grad = grad.index_select(indexes, *dim)?;
450                        let src_sum_grad = grads.or_insert(src)?;
451                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
452                    }
453                    Op::IndexSelect(arg, indexes, dim) => {
454                        let sum_grad = grads.or_insert(arg)?;
455                        *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
456                    }
457                    Op::Matmul(lhs, rhs) => {
458                        // Skipping checks, the op went ok, we can skip
459                        // the matmul size checks for now.
460
461                        let lhs_grad = grad.matmul(&rhs.t()?)?;
462                        let lhs_sum_grad = grads.or_insert(lhs)?;
463                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
464
465                        let rhs_grad = lhs.t()?.matmul(&grad)?;
466                        let rhs_sum_grad = grads.or_insert(rhs)?;
467                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
468                    }
469                    Op::Cat(args, dim) => {
470                        let mut start_idx = 0;
471                        for arg in args {
472                            let len = arg.dims()[*dim];
473                            let arg_grad = grad.narrow(*dim, start_idx, len)?;
474                            let sum_grad = grads.or_insert(arg)?;
475                            *sum_grad = sum_grad.add(&arg_grad)?;
476                            start_idx += len;
477                        }
478                    }
479                    Op::Broadcast(arg) => {
480                        let arg_dims = arg.dims();
481                        let node_dims = node.dims();
482                        // The number of dims that have been inserted on the left.
483                        let left_dims = node_dims.len() - arg_dims.len();
484                        let mut sum_dims: Vec<usize> = (0..left_dims).collect();
485                        for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
486                            .iter()
487                            .zip(arg_dims.iter())
488                            .enumerate()
489                        {
490                            if node_dim != arg_dim {
491                                sum_dims.push(dim + left_dims)
492                            }
493                        }
494
495                        let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
496                        for _i in 0..left_dims {
497                            arg_grad = arg_grad.squeeze(0)?
498                        }
499                        let sum_grad = grads.or_insert(arg)?;
500                        *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
501                    }
502                    Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
503                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
504                        let sum_grad = grads.or_insert(arg)?;
505                        *sum_grad = sum_grad.add(&grad)?;
506                    }
507                    Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
508                        let node = broadcast_back(arg, node, reduced_dims)?;
509                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
510                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
511                        let sum_grad = grads.or_insert(arg)?;
512                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
513                    }
514                    Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
515                        let node = broadcast_back(arg, node, reduced_dims)?;
516                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
517                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
518                        let sum_grad = grads.or_insert(arg)?;
519                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
520                    }
521                    Op::ToDType(arg) => {
522                        let sum_grad = grads.or_insert(arg)?;
523                        *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
524                    }
525                    Op::Copy(arg) => {
526                        let sum_grad = grads.or_insert(arg)?;
527                        *sum_grad = sum_grad.add(&grad)?
528                    }
529                    Op::Affine { arg, mul, .. } => {
530                        let arg_grad = grad.affine(*mul, 0.)?;
531                        let sum_grad = grads.or_insert(arg)?;
532                        *sum_grad = sum_grad.add(&arg_grad)?
533                    }
534                    Op::Unary(arg, UnaryOp::Log) => {
535                        let sum_grad = grads.or_insert(arg)?;
536                        *sum_grad = sum_grad.add(&(grad / arg)?)?
537                    }
538                    Op::Unary(arg, UnaryOp::Sin) => {
539                        let sum_grad = grads.or_insert(arg)?;
540                        *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
541                    }
542                    Op::Unary(arg, UnaryOp::Cos) => {
543                        let sum_grad = grads.or_insert(arg)?;
544                        *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
545                    }
546                    Op::Unary(arg, UnaryOp::Tanh) => {
547                        let sum_grad = grads.or_insert(arg)?;
548                        let minus_dtanh = (node.sqr()? - 1.)?;
549                        *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
550                    }
551                    Op::Unary(arg, UnaryOp::Abs) => {
552                        let sum_grad = grads.or_insert(arg)?;
553                        let ones = arg.ones_like()?;
554                        let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
555                        *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
556                    }
557                    Op::Unary(arg, UnaryOp::Exp) => {
558                        let sum_grad = grads.or_insert(arg)?;
559                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
560                    }
561                    Op::Unary(arg, UnaryOp::Neg) => {
562                        let sum_grad = grads.or_insert(arg)?;
563                        *sum_grad = sum_grad.sub(&grad)?
564                    }
565                    Op::Unary(arg, UnaryOp::Recip) => {
566                        let sum_grad = grads.or_insert(arg)?;
567                        let grad = (grad / arg.sqr()?)?;
568                        *sum_grad = sum_grad.sub(&grad)?
569                    }
570                    &Op::Narrow(ref arg, dim, start_idx, len) => {
571                        let arg_dims = arg.dims();
572                        let left_pad = if start_idx == 0 {
573                            None
574                        } else {
575                            let mut dims = arg_dims.to_vec();
576                            dims[dim] = start_idx;
577                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
578                        };
579                        let right_pad = arg_dims[dim] - start_idx - len;
580                        let right_pad = if right_pad == 0 {
581                            None
582                        } else {
583                            let mut dims = arg_dims.to_vec();
584                            dims[dim] = right_pad;
585                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
586                        };
587                        let arg_grad = match (left_pad, right_pad) {
588                            (None, None) => grad,
589                            (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
590                            (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
591                            (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
592                        };
593                        let sum_grad = grads.or_insert(arg)?;
594                        *sum_grad = sum_grad.add(&arg_grad)?
595                    }
596                    Op::Unary(_, UnaryOp::Floor)
597                    | Op::Unary(_, UnaryOp::Round)
598                    | Op::Reduce(_, ReduceOp::ArgMin, _)
599                    | Op::Reduce(_, ReduceOp::ArgMax, _)
600                    | Op::Unary(_, UnaryOp::Sign)
601                    | Op::Cmp(_, _) => {}
602                    Op::Reshape(arg) => {
603                        let arg_grad = grad.reshape(arg.dims())?;
604                        let sum_grad = grads.or_insert(arg)?;
605                        *sum_grad = sum_grad.add(&arg_grad)?
606                    }
607                    Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
608                    Op::Unary(arg, UnaryOp::Gelu) => {
609                        let sum_grad = grads.or_insert(arg)?;
610                        let cube = arg.powf(3.)?;
611                        let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
612                        let gelu_grad = (((0.5 * &tanh)?
613                            + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
614                            + 0.5)?;
615                        *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
616                    }
617                    Op::Unary(arg, UnaryOp::Erf) => {
618                        let sum_grad = grads.or_insert(arg)?;
619                        // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
620                        let erf_grad =
621                            (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
622                        *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
623                    }
624                    Op::Unary(arg, UnaryOp::GeluErf) => {
625                        let sum_grad = grads.or_insert(arg)?;
626                        // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
627                        let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
628                        let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
629                        let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
630                        let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
631                        let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
632                        *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
633                    }
634                    Op::Unary(arg, UnaryOp::Relu) => {
635                        let sum_grad = grads.or_insert(arg)?;
636                        let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
637                        *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
638                    }
639                    Op::Unary(arg, UnaryOp::Silu) => {
640                        let sum_grad = grads.or_insert(arg)?;
641                        // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
642                        let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
643                        let silu_grad = &sigmoid_arg * (1. - *node) + *node;
644                        *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
645                    }
646                    Op::Elu(arg, alpha) => {
647                        // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
648                        let sum_grad = grads.or_insert(arg)?;
649                        let zeros = arg.zeros_like()?;
650                        let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
651                        let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
652                        // node == alpha * (e^x - 1) for x <= 0, reuse it
653                        let negative_exp_mask = (negative_mask * (*node + *alpha))?;
654                        let combined_mask = (positive_mask + negative_exp_mask)?;
655                        *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
656                    }
657                    Op::Powf(arg, e) => {
658                        let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
659                        let sum_grad = grads.or_insert(arg)?;
660                        *sum_grad = sum_grad.add(&arg_grad)?
661                    }
662                    Op::CustomOp1(arg, c) => {
663                        if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
664                            let sum_grad = grads.or_insert(arg)?;
665                            *sum_grad = sum_grad.add(&arg_grad)?
666                        }
667                    }
668                    Op::CustomOp2(arg1, arg2, c) => {
669                        let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
670                        if let Some(arg_grad1) = arg_grad1 {
671                            let sum_grad = grads.or_insert(arg1)?;
672                            *sum_grad = sum_grad.add(&arg_grad1)?
673                        }
674                        if let Some(arg_grad2) = arg_grad2 {
675                            let sum_grad = grads.or_insert(arg2)?;
676                            *sum_grad = sum_grad.add(&arg_grad2)?
677                        }
678                    }
679                    Op::CustomOp3(arg1, arg2, arg3, c) => {
680                        let (arg_grad1, arg_grad2, arg_grad3) =
681                            c.bwd(arg1, arg2, arg3, node, &grad)?;
682                        if let Some(arg_grad1) = arg_grad1 {
683                            let sum_grad = grads.or_insert(arg1)?;
684                            *sum_grad = sum_grad.add(&arg_grad1)?
685                        }
686                        if let Some(arg_grad2) = arg_grad2 {
687                            let sum_grad = grads.or_insert(arg2)?;
688                            *sum_grad = sum_grad.add(&arg_grad2)?
689                        }
690                        if let Some(arg_grad3) = arg_grad3 {
691                            let sum_grad = grads.or_insert(arg3)?;
692                            *sum_grad = sum_grad.add(&arg_grad3)?
693                        }
694                    }
695                    Op::Unary(arg, UnaryOp::Sqr) => {
696                        let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
697                        let sum_grad = grads.or_insert(arg)?;
698                        *sum_grad = sum_grad.add(&arg_grad)?
699                    }
700                    Op::Unary(arg, UnaryOp::Sqrt) => {
701                        let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
702                        let sum_grad = grads.or_insert(arg)?;
703                        *sum_grad = sum_grad.add(&arg_grad)?
704                    }
705                    Op::ToDevice(arg) => {
706                        let sum_grad = grads.or_insert(arg)?;
707                        let arg_grad = grad.to_device(sum_grad.device())?;
708                        *sum_grad = sum_grad.add(&arg_grad)?
709                    }
710                    Op::Transpose(arg, dim1, dim2) => {
711                        let arg_grad = grad.transpose(*dim1, *dim2)?;
712                        let sum_grad = grads.or_insert(arg)?;
713                        *sum_grad = sum_grad.add(&arg_grad)?
714                    }
715                    Op::Permute(arg, dims) => {
716                        let mut inv_dims = vec![0; dims.len()];
717                        for (i, &dim_idx) in dims.iter().enumerate() {
718                            inv_dims[dim_idx] = i
719                        }
720                        let arg_grad = grad.permute(inv_dims)?;
721                        let sum_grad = grads.or_insert(arg)?;
722                        *sum_grad = sum_grad.add(&arg_grad)?
723                    }
724                };
725            }
726        }
727        Ok(grads)
728    }
729}
730
731/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
732#[derive(Debug)]
733pub struct GradStore(HashMap<TensorId, Tensor>);
734
735impl GradStore {
736    /// Create a new gradient store
737    fn new() -> Self {
738        GradStore(HashMap::new())
739    }
740
741    /// Get the gradient tensor corresponding to the given tensor id
742    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
743        self.0.get(&id)
744    }
745
746    /// Get the gradient tensor associated with the given tensor
747    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
748        self.0.get(&tensor.id())
749    }
750
751    /// Remove the gradient tensor associated with the given tensor, returning it if it exists
752    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
753        self.0.remove(&tensor.id())
754    }
755
756    /// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
757    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
758        self.0.insert(tensor.id(), grad)
759    }
760
761    /// Insert a gradient tensor associated with the given tensor id, returning the previous gradient tensor if it existed
762    pub fn insert_id(&mut self, id: TensorId, grad: Tensor) -> Option<Tensor> {
763        self.0.insert(id, grad)
764    }
765
766    /// Get the gradient tensor associated with the given tensor, or, if it does not exist,
767    /// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
768    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
769        use std::collections::hash_map::Entry;
770        let grad = match self.0.entry(tensor.id()) {
771            Entry::Occupied(entry) => entry.into_mut(),
772            Entry::Vacant(entry) => {
773                let grad = tensor.zeros_like()?;
774                entry.insert(grad)
775            }
776        };
777        Ok(grad)
778    }
779
780    /// Get the tensor ids of the stored gradient tensors
781    pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
782        self.0.keys()
783    }
784}