candle_core/
backprop.rs

1//! Methods for backpropagation of gradients.
2use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
3use crate::{Error, Result, Tensor, TensorId};
4use std::collections::HashMap;
5
6// arg has been reduced to node via reduce_dims, expand it back to arg.
7// This has to handle keepdims.
8fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
9    if arg.rank() == node.rank() {
10        // keepdim = true
11        node.broadcast_as(arg.shape())
12    } else {
13        // keepdim = false
14        // first expand the reduced dims.
15        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
16    }
17}
18
thread_local! {
    // True when the `CANDLE_GRAD_DO_NOT_DETACH` environment variable is set to anything other
    // than "" or "0". It is false when the variable is unset or not valid unicode.
    // Being a thread-local, the variable is read lazily, once per thread, on first access.
    // The backward pass consults it to decide whether to skip `detach` on gradients.
    static CANDLE_GRAD_DO_NOT_DETACH: bool = matches!(
        std::env::var("CANDLE_GRAD_DO_NOT_DETACH"),
        Ok(value) if !value.is_empty() && value != "0"
    );
}
29
30impl Tensor {
31    /// Return all the nodes that lead to this value in a topologically sorted vec, the first
32    /// elements having dependencies on the latter ones, e.g. the first element if any is the
33    /// argument.
34    /// This assumes that the op graph is a DAG.
35    pub fn sorted_nodes(&self) -> Vec<&Tensor> {
36        // The vec of sorted nodes is passed as an owned value rather than a mutable reference
37        // to get around some lifetime limitations.
38        fn walk<'a>(
39            node: &'a Tensor,
40            nodes: Vec<&'a Tensor>,
41            already_seen: &mut HashMap<TensorId, bool>,
42        ) -> (bool, Vec<&'a Tensor>) {
43            if let Some(&tg) = already_seen.get(&node.id()) {
44                return (tg, nodes);
45            }
46            let mut track_grad = false;
47            let mut nodes = if node.is_variable() {
48                // Do not call recursively on the "leaf" nodes.
49                track_grad = true;
50                nodes
51            } else if node.dtype().is_int() {
52                nodes
53            } else if let Some(op) = node.op() {
54                match op {
55                    Op::IndexAdd(t1, t2, t3, _)
56                    | Op::Scatter(t1, t2, t3, _)
57                    | Op::ScatterAdd(t1, t2, t3, _)
58                    | Op::CustomOp3(t1, t2, t3, _)
59                    | Op::WhereCond(t1, t2, t3) => {
60                        let (tg, nodes) = walk(t1, nodes, already_seen);
61                        track_grad |= tg;
62                        let (tg, nodes) = walk(t2, nodes, already_seen);
63                        track_grad |= tg;
64                        let (tg, nodes) = walk(t3, nodes, already_seen);
65                        track_grad |= tg;
66                        nodes
67                    }
68                    Op::Conv1D {
69                        arg: lhs,
70                        kernel: rhs,
71                        ..
72                    }
73                    | Op::ConvTranspose1D {
74                        arg: lhs,
75                        kernel: rhs,
76                        ..
77                    }
78                    | Op::Conv2D {
79                        arg: lhs,
80                        kernel: rhs,
81                        ..
82                    }
83                    | Op::ConvTranspose2D {
84                        arg: lhs,
85                        kernel: rhs,
86                        ..
87                    }
88                    | Op::CustomOp2(lhs, rhs, _)
89                    | Op::Binary(lhs, rhs, _)
90                    | Op::Gather(lhs, rhs, _)
91                    | Op::IndexSelect(lhs, rhs, _)
92                    | Op::Matmul(lhs, rhs)
93                    | Op::SliceScatter0(lhs, rhs, _) => {
94                        let (tg, nodes) = walk(lhs, nodes, already_seen);
95                        track_grad |= tg;
96                        let (tg, nodes) = walk(rhs, nodes, already_seen);
97                        track_grad |= tg;
98                        nodes
99                    }
100                    Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
101                        let (tg, nodes) = walk(arg, nodes, already_seen);
102                        track_grad |= tg;
103                        nodes
104                    }),
105                    Op::Affine { arg, mul, .. } => {
106                        if *mul == 0. {
107                            nodes
108                        } else {
109                            let (tg, nodes) = walk(arg, nodes, already_seen);
110                            track_grad |= tg;
111                            nodes
112                        }
113                    }
114                    Op::Unary(_node, UnaryOp::Ceil)
115                    | Op::Unary(_node, UnaryOp::Floor)
116                    | Op::Unary(_node, UnaryOp::Round)
117                    | Op::Unary(_node, UnaryOp::Sign) => nodes,
118                    Op::Reshape(node)
119                    | Op::UpsampleNearest1D { arg: node, .. }
120                    | Op::UpsampleNearest2D { arg: node, .. }
121                    | Op::AvgPool2D { arg: node, .. }
122                    | Op::MaxPool2D { arg: node, .. }
123                    | Op::Copy(node)
124                    | Op::Broadcast(node)
125                    | Op::Cmp(node, _)
126                    | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
127                    | Op::ToDevice(node)
128                    | Op::Transpose(node, _, _)
129                    | Op::Permute(node, _)
130                    | Op::Narrow(node, _, _, _)
131                    | Op::Unary(node, _)
132                    | Op::Elu(node, _)
133                    | Op::Powf(node, _)
134                    | Op::CustomOp1(node, _) => {
135                        let (tg, nodes) = walk(node, nodes, already_seen);
136                        track_grad |= tg;
137                        nodes
138                    }
139                    Op::ToDType(node) => {
140                        if node.dtype().is_float() {
141                            let (tg, nodes) = walk(node, nodes, already_seen);
142                            track_grad |= tg;
143                            nodes
144                        } else {
145                            nodes
146                        }
147                    }
148                    Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
149                }
150            } else {
151                nodes
152            };
153            already_seen.insert(node.id(), track_grad);
154            if track_grad {
155                nodes.push(node);
156            }
157            (track_grad, nodes)
158        }
159        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
160        nodes.reverse();
161        nodes
162    }
163
164    pub fn backward(&self) -> Result<GradStore> {
165        let sorted_nodes = self.sorted_nodes();
166        let mut grads = GradStore::new();
167        grads.insert(self, self.ones_like()?.contiguous()?);
168        for node in sorted_nodes.iter() {
169            if node.is_variable() {
170                continue;
171            }
172            let grad = grads
173                .remove(node)
174                .expect("candle internal error - grad not populated");
175            // https://github.com/huggingface/candle/issues/1241
176            // Ideally, we would make these operations in place where possible to ensure that we
177            // do not have to allocate too often. Here we just call `.detach` to avoid computing
178            // the backprop graph of the backprop itself. This would be an issue for second order
179            // derivatives but these are out of scope at the moment.
180            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
181            let grad = if do_not_detach { grad } else { grad.detach() };
182            if let Some(op) = node.op() {
183                match op {
184                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
185                        let lhs_sum_grad = grads.or_insert(lhs)?;
186                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
187                        let rhs_sum_grad = grads.or_insert(rhs)?;
188                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
189                    }
190                    Op::Binary(lhs, rhs, BinaryOp::Sub) => {
191                        let lhs_sum_grad = grads.or_insert(lhs)?;
192                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
193                        let rhs_sum_grad = grads.or_insert(rhs)?;
194                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
195                    }
196                    Op::Binary(lhs, rhs, BinaryOp::Mul) => {
197                        let lhs_grad = grad.mul(rhs)?;
198                        let lhs_sum_grad = grads.or_insert(lhs)?;
199                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
200                        let rhs_grad = grad.mul(lhs)?;
201                        let rhs_sum_grad = grads.or_insert(rhs)?;
202                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
203                    }
204                    Op::Binary(lhs, rhs, BinaryOp::Div) => {
205                        let lhs_grad = grad.div(rhs)?;
206                        let lhs_sum_grad = grads.or_insert(lhs)?;
207                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
208                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
209                        let rhs_sum_grad = grads.or_insert(rhs)?;
210                        *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
211                    }
212                    Op::Binary(lhs, rhs, BinaryOp::Minimum)
213                    | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
214                        let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
215                        let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
216
217                        // If both masks are 1 one the same point, we want to scale the
218                        // gradient by 0.5 rather than 1.
219                        let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
220                        let lhs_sum_grad = grads.or_insert(lhs)?;
221                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
222
223                        let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
224                        let rhs_sum_grad = grads.or_insert(rhs)?;
225                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
226                    }
227                    Op::WhereCond(pred, t, f) => {
228                        let zeros = grad.zeros_like()?;
229                        let t_sum_grad = grads.or_insert(t)?;
230                        let t_grad = pred.where_cond(&grad, &zeros)?;
231                        *t_sum_grad = t_sum_grad.add(&t_grad)?;
232                        let f_sum_grad = grads.or_insert(f)?;
233                        let f_grad = pred.where_cond(&zeros, &grad)?;
234                        *f_sum_grad = f_sum_grad.add(&f_grad)?;
235                    }
236                    Op::Conv1D {
237                        arg,
238                        kernel,
239                        padding,
240                        stride,
241                        dilation,
242                    } => {
243                        // The output height for conv_transpose1d is:
244                        // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1
245                        let grad_l_in = grad.dim(2)?;
246                        let k_size = kernel.dim(2)?;
247                        let out_size =
248                            (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
249                        let out_padding = arg.dim(2)? - out_size;
250                        let grad_arg = grad.conv_transpose1d(
251                            kernel,
252                            *padding,
253                            out_padding,
254                            *stride,
255                            *dilation,
256                            /* groups */ 1,
257                        )?;
258                        let sum_grad = grads.or_insert(arg)?;
259                        *sum_grad = sum_grad.add(&grad_arg)?;
260
261                        let grad_kernel = arg
262                            .transpose(0, 1)?
263                            .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
264                            .transpose(0, 1)?;
265                        let sum_grad = grads.or_insert(kernel)?;
266                        let (_, _, k0) = kernel.dims3()?;
267                        let (_, _, g_k0) = grad_kernel.dims3()?;
268                        let grad_kernel = if g_k0 != k0 {
269                            grad_kernel.narrow(2, 0, k0)?
270                        } else {
271                            grad_kernel
272                        };
273                        *sum_grad = sum_grad.add(&grad_kernel)?;
274                    }
275                    Op::Conv2D {
276                        arg,
277                        kernel,
278                        padding,
279                        stride,
280                        dilation,
281                    } => {
282                        // The output height for conv_transpose2d is:
283                        // (i_h - 1) * stride - 2 * padding + dilation * (k_h - 1) + out_padding + 1
284                        let grad_h = grad.dim(2)?;
285                        let k_h = kernel.dim(2)?;
286                        let out_size =
287                            (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
288                        let out_padding = arg.dim(2)? - out_size;
289                        let grad_arg = grad.conv_transpose2d(
290                            kernel,
291                            *padding,
292                            out_padding,
293                            *stride,
294                            *dilation,
295                        )?;
296                        let sum_grad = grads.or_insert(arg)?;
297                        *sum_grad = sum_grad.add(&grad_arg)?;
298
299                        let grad_kernel = arg
300                            .transpose(0, 1)?
301                            .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
302                            .transpose(0, 1)?;
303                        let sum_grad = grads.or_insert(kernel)?;
304                        let (_, _, k0, k1) = kernel.dims4()?;
305                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
306                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
307                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
308                        } else {
309                            grad_kernel
310                        };
311                        *sum_grad = sum_grad.add(&grad_kernel)?;
312                    }
313                    Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
314                        op: "conv-transpose1d",
315                    })?,
316                    Op::ConvTranspose2D {
317                        arg,
318                        kernel,
319                        padding,
320                        stride,
321                        dilation,
322                        output_padding: _output_padding,
323                    } => {
324                        let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
325                        let sum_grad = grads.or_insert(arg)?;
326                        *sum_grad = sum_grad.add(&grad_arg)?;
327
328                        let grad_kernel = grad
329                            .transpose(0, 1)?
330                            .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
331                            .transpose(0, 1)?;
332                        let sum_grad = grads.or_insert(kernel)?;
333                        let (_, _, k0, k1) = kernel.dims4()?;
334                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
335                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
336                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
337                        } else {
338                            grad_kernel
339                        };
340                        *sum_grad = sum_grad.add(&grad_kernel)?;
341                    }
342                    Op::AvgPool2D {
343                        arg,
344                        kernel_size,
345                        stride,
346                    } => {
347                        if kernel_size != stride {
348                            crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
349                        }
350                        let (_n, _c, h, w) = arg.dims4()?;
351                        let grad_arg = grad.upsample_nearest2d(h, w)?;
352                        let grad_arg =
353                            (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
354                        let sum_grad = grads.or_insert(arg)?;
355                        *sum_grad = sum_grad.add(&grad_arg)?;
356                    }
357                    Op::MaxPool2D {
358                        arg,
359                        kernel_size,
360                        stride,
361                    } => {
362                        if kernel_size != stride {
363                            crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
364                        }
365                        let (_n, _c, h, w) = arg.dims4()?;
366                        // For computing the max-pool gradient, we compute a mask where a 1 means
367                        // that the element is the maximum, then we apply this mask to the
368                        // upsampled gradient (taking into account that multiple max may exist so
369                        // we scale the gradient for this case).
370                        let node_upsampled = node.upsample_nearest2d(h, w)?;
371                        let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
372                        let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
373                        let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
374                        let sum_grad = grads.or_insert(arg)?;
375                        *sum_grad = sum_grad.add(&grad_arg)?;
376                    }
377                    Op::UpsampleNearest1D { arg, target_size } => {
378                        let (_n, c, size) = arg.dims3()?;
379                        if target_size % size != 0 {
380                            crate::bail!("backward not supported for non integer upscaling factors")
381                        }
382                        let scale = target_size / size;
383
384                        let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
385                        let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
386                        let sum_grad = grads.or_insert(arg)?;
387                        *sum_grad = conv_sum;
388                    }
389                    Op::UpsampleNearest2D {
390                        arg,
391                        target_h,
392                        target_w,
393                    } => {
394                        let (_n, c, h, w) = arg.dims4()?;
395                        if target_h % h != 0 || target_w % w != 0 {
396                            crate::bail!("backward not supported for non integer upscaling factors")
397                        }
398                        let scale_h = target_h / h;
399                        let scale_w = target_w / w;
400
401                        if scale_h != scale_w {
402                            crate::bail!("backward not supported for non uniform upscaling factors")
403                        };
404                        let kernel =
405                            Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
406                        let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
407                        let sum_grad = grads.or_insert(arg)?;
408                        *sum_grad = conv_sum;
409                    }
410                    Op::SliceScatter0(lhs, rhs, start_rhs) => {
411                        let rhs_sum_grad = grads.or_insert(rhs)?;
412                        let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
413                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
414
415                        let lhs_sum_grad = grads.or_insert(lhs)?;
416                        let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
417                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
418                    }
419                    Op::Gather(arg, indexes, dim) => {
420                        let sum_grad = grads.or_insert(arg)?;
421                        *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
422                    }
423                    Op::Scatter(init, indexes, src, dim) => {
424                        let init_sum_grad = grads.or_insert(init)?;
425                        *init_sum_grad = init_sum_grad.add(&grad)?;
426
427                        let src_grad = grad.gather(indexes, *dim)?;
428                        let src_sum_grad = grads.or_insert(src)?;
429                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
430                    }
431                    Op::ScatterAdd(init, indexes, src, dim) => {
432                        let init_sum_grad = grads.or_insert(init)?;
433                        let mask = init.ones_like()?;
434                        let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
435                        *init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
436
437                        let src_grad = grad.gather(indexes, *dim)?;
438                        let src_sum_grad = grads.or_insert(src)?;
439                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
440                    }
441                    Op::IndexAdd(init, indexes, src, dim) => {
442                        let init_sum_grad = grads.or_insert(init)?;
443                        *init_sum_grad = init_sum_grad.add(&grad)?;
444
445                        let src_grad = grad.index_select(indexes, *dim)?;
446                        let src_sum_grad = grads.or_insert(src)?;
447                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
448                    }
449                    Op::IndexSelect(arg, indexes, dim) => {
450                        let sum_grad = grads.or_insert(arg)?;
451                        *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
452                    }
453                    Op::Matmul(lhs, rhs) => {
454                        // Skipping checks, the op went ok, we can skip
455                        // the matmul size checks for now.
456
457                        let lhs_grad = grad.matmul(&rhs.t()?)?;
458                        let lhs_sum_grad = grads.or_insert(lhs)?;
459                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
460
461                        let rhs_grad = lhs.t()?.matmul(&grad)?;
462                        let rhs_sum_grad = grads.or_insert(rhs)?;
463                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
464                    }
465                    Op::Cat(args, dim) => {
466                        let mut start_idx = 0;
467                        for arg in args {
468                            let len = arg.dims()[*dim];
469                            let arg_grad = grad.narrow(*dim, start_idx, len)?;
470                            let sum_grad = grads.or_insert(arg)?;
471                            *sum_grad = sum_grad.add(&arg_grad)?;
472                            start_idx += len;
473                        }
474                    }
475                    Op::Broadcast(arg) => {
476                        let arg_dims = arg.dims();
477                        let node_dims = node.dims();
478                        // The number of dims that have been inserted on the left.
479                        let left_dims = node_dims.len() - arg_dims.len();
480                        let mut sum_dims: Vec<usize> = (0..left_dims).collect();
481                        for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
482                            .iter()
483                            .zip(arg_dims.iter())
484                            .enumerate()
485                        {
486                            if node_dim != arg_dim {
487                                sum_dims.push(dim + left_dims)
488                            }
489                        }
490
491                        let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
492                        for _i in 0..left_dims {
493                            arg_grad = arg_grad.squeeze(0)?
494                        }
495                        let sum_grad = grads.or_insert(arg)?;
496                        *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
497                    }
498                    Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
499                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
500                        let sum_grad = grads.or_insert(arg)?;
501                        *sum_grad = sum_grad.add(&grad)?;
502                    }
503                    Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
504                        let node = broadcast_back(arg, node, reduced_dims)?;
505                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
506                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
507                        let sum_grad = grads.or_insert(arg)?;
508                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
509                    }
510                    Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
511                        let node = broadcast_back(arg, node, reduced_dims)?;
512                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
513                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
514                        let sum_grad = grads.or_insert(arg)?;
515                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
516                    }
517                    Op::ToDType(arg) => {
518                        let sum_grad = grads.or_insert(arg)?;
519                        *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
520                    }
521                    Op::Copy(arg) => {
522                        let sum_grad = grads.or_insert(arg)?;
523                        *sum_grad = sum_grad.add(&grad)?
524                    }
525                    Op::Affine { arg, mul, .. } => {
526                        let arg_grad = grad.affine(*mul, 0.)?;
527                        let sum_grad = grads.or_insert(arg)?;
528                        *sum_grad = sum_grad.add(&arg_grad)?
529                    }
530                    Op::Unary(arg, UnaryOp::Log) => {
531                        let sum_grad = grads.or_insert(arg)?;
532                        *sum_grad = sum_grad.add(&(grad / arg)?)?
533                    }
534                    Op::Unary(arg, UnaryOp::Sin) => {
535                        let sum_grad = grads.or_insert(arg)?;
536                        *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
537                    }
538                    Op::Unary(arg, UnaryOp::Cos) => {
539                        let sum_grad = grads.or_insert(arg)?;
540                        *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
541                    }
542                    Op::Unary(arg, UnaryOp::Tanh) => {
543                        let sum_grad = grads.or_insert(arg)?;
544                        let minus_dtanh = (node.sqr()? - 1.)?;
545                        *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
546                    }
547                    Op::Unary(arg, UnaryOp::Abs) => {
548                        let sum_grad = grads.or_insert(arg)?;
549                        let ones = arg.ones_like()?;
550                        let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
551                        *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
552                    }
553                    Op::Unary(arg, UnaryOp::Exp) => {
554                        let sum_grad = grads.or_insert(arg)?;
555                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
556                    }
557                    Op::Unary(arg, UnaryOp::Neg) => {
558                        let sum_grad = grads.or_insert(arg)?;
559                        *sum_grad = sum_grad.sub(&grad)?
560                    }
561                    Op::Unary(arg, UnaryOp::Recip) => {
562                        let sum_grad = grads.or_insert(arg)?;
563                        let grad = (grad / arg.sqr()?)?;
564                        *sum_grad = sum_grad.sub(&grad)?
565                    }
566                    &Op::Narrow(ref arg, dim, start_idx, len) => {
567                        let arg_dims = arg.dims();
568                        let left_pad = if start_idx == 0 {
569                            None
570                        } else {
571                            let mut dims = arg_dims.to_vec();
572                            dims[dim] = start_idx;
573                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
574                        };
575                        let right_pad = arg_dims[dim] - start_idx - len;
576                        let right_pad = if right_pad == 0 {
577                            None
578                        } else {
579                            let mut dims = arg_dims.to_vec();
580                            dims[dim] = right_pad;
581                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
582                        };
583                        let arg_grad = match (left_pad, right_pad) {
584                            (None, None) => grad,
585                            (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
586                            (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
587                            (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
588                        };
589                        let sum_grad = grads.or_insert(arg)?;
590                        *sum_grad = sum_grad.add(&arg_grad)?
591                    }
592                    Op::Unary(_, UnaryOp::Floor)
593                    | Op::Unary(_, UnaryOp::Round)
594                    | Op::Reduce(_, ReduceOp::ArgMin, _)
595                    | Op::Reduce(_, ReduceOp::ArgMax, _)
596                    | Op::Unary(_, UnaryOp::Sign)
597                    | Op::Cmp(_, _) => {}
598                    Op::Reshape(arg) => {
599                        let arg_grad = grad.reshape(arg.dims())?;
600                        let sum_grad = grads.or_insert(arg)?;
601                        *sum_grad = sum_grad.add(&arg_grad)?
602                    }
603                    Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
604                    Op::Unary(arg, UnaryOp::Gelu) => {
605                        let sum_grad = grads.or_insert(arg)?;
606                        let cube = arg.powf(3.)?;
607                        let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
608                        let gelu_grad = (((0.5 * &tanh)?
609                            + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
610                            + 0.5)?;
611                        *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
612                    }
613                    Op::Unary(arg, UnaryOp::Erf) => {
614                        let sum_grad = grads.or_insert(arg)?;
615                        // d/dx erf(x) = 2/sqrt(pi) * e^(-x^2)
616                        let erf_grad =
617                            (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
618                        *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
619                    }
620                    Op::Unary(arg, UnaryOp::GeluErf) => {
621                        let sum_grad = grads.or_insert(arg)?;
622                        // d/dx gelu_erf(x) = 0.5 + 0.398942 e^(-x^2/2) x + 0.5 erf(x/sqrt(2))
623                        let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
624                        let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
625                        let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
626                        let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
627                        let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
628                        *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
629                    }
630                    Op::Unary(arg, UnaryOp::Relu) => {
631                        let sum_grad = grads.or_insert(arg)?;
632                        let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
633                        *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
634                    }
635                    Op::Unary(arg, UnaryOp::Silu) => {
636                        let sum_grad = grads.or_insert(arg)?;
637                        // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x))) = sigmoid(x) * (1 - node) + node
638                        let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
639                        let silu_grad = &sigmoid_arg * (1. - *node) + *node;
640                        *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
641                    }
642                    Op::Elu(arg, alpha) => {
643                        // d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
644                        let sum_grad = grads.or_insert(arg)?;
645                        let zeros = arg.zeros_like()?;
646                        let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
647                        let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
648                        // node == alpha * (e^x - 1) for x <= 0, reuse it
649                        let negative_exp_mask = (negative_mask * (*node + *alpha))?;
650                        let combined_mask = (positive_mask + negative_exp_mask)?;
651                        *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
652                    }
653                    Op::Powf(arg, e) => {
654                        let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
655                        let sum_grad = grads.or_insert(arg)?;
656                        *sum_grad = sum_grad.add(&arg_grad)?
657                    }
658                    Op::CustomOp1(arg, c) => {
659                        if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
660                            let sum_grad = grads.or_insert(arg)?;
661                            *sum_grad = sum_grad.add(&arg_grad)?
662                        }
663                    }
664                    Op::CustomOp2(arg1, arg2, c) => {
665                        let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
666                        if let Some(arg_grad1) = arg_grad1 {
667                            let sum_grad = grads.or_insert(arg1)?;
668                            *sum_grad = sum_grad.add(&arg_grad1)?
669                        }
670                        if let Some(arg_grad2) = arg_grad2 {
671                            let sum_grad = grads.or_insert(arg2)?;
672                            *sum_grad = sum_grad.add(&arg_grad2)?
673                        }
674                    }
675                    Op::CustomOp3(arg1, arg2, arg3, c) => {
676                        let (arg_grad1, arg_grad2, arg_grad3) =
677                            c.bwd(arg1, arg2, arg3, node, &grad)?;
678                        if let Some(arg_grad1) = arg_grad1 {
679                            let sum_grad = grads.or_insert(arg1)?;
680                            *sum_grad = sum_grad.add(&arg_grad1)?
681                        }
682                        if let Some(arg_grad2) = arg_grad2 {
683                            let sum_grad = grads.or_insert(arg2)?;
684                            *sum_grad = sum_grad.add(&arg_grad2)?
685                        }
686                        if let Some(arg_grad3) = arg_grad3 {
687                            let sum_grad = grads.or_insert(arg3)?;
688                            *sum_grad = sum_grad.add(&arg_grad3)?
689                        }
690                    }
691                    Op::Unary(arg, UnaryOp::Sqr) => {
692                        let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
693                        let sum_grad = grads.or_insert(arg)?;
694                        *sum_grad = sum_grad.add(&arg_grad)?
695                    }
696                    Op::Unary(arg, UnaryOp::Sqrt) => {
697                        let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
698                        let sum_grad = grads.or_insert(arg)?;
699                        *sum_grad = sum_grad.add(&arg_grad)?
700                    }
701                    Op::ToDevice(arg) => {
702                        let sum_grad = grads.or_insert(arg)?;
703                        let arg_grad = grad.to_device(sum_grad.device())?;
704                        *sum_grad = sum_grad.add(&arg_grad)?
705                    }
706                    Op::Transpose(arg, dim1, dim2) => {
707                        let arg_grad = grad.transpose(*dim1, *dim2)?;
708                        let sum_grad = grads.or_insert(arg)?;
709                        *sum_grad = sum_grad.add(&arg_grad)?
710                    }
711                    Op::Permute(arg, dims) => {
712                        let mut inv_dims = vec![0; dims.len()];
713                        for (i, &dim_idx) in dims.iter().enumerate() {
714                            inv_dims[dim_idx] = i
715                        }
716                        let arg_grad = grad.permute(inv_dims)?;
717                        let sum_grad = grads.or_insert(arg)?;
718                        *sum_grad = sum_grad.add(&arg_grad)?
719                    }
720                };
721            }
722        }
723        Ok(grads)
724    }
725}
726
727/// A store for gradients, associating a tensor id to the corresponding gradient tensor, used for back propagation.
728#[derive(Debug)]
729pub struct GradStore(HashMap<TensorId, Tensor>);
730
731impl GradStore {
732    /// Create a new gradient store
733    fn new() -> Self {
734        GradStore(HashMap::new())
735    }
736
737    /// Get the gradient tensor corresponding to the given tensor id
738    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
739        self.0.get(&id)
740    }
741
742    /// Get the gradient tensor associated with the given tensor
743    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
744        self.0.get(&tensor.id())
745    }
746
747    /// Remove the gradient tensor associated with the given tensor, returning it if it exists
748    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
749        self.0.remove(&tensor.id())
750    }
751
752    /// Insert a gradient tensor associated with the given tensor, returning the previous gradient tensor if it existed
753    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
754        self.0.insert(tensor.id(), grad)
755    }
756
757    /// Get the gradient tensor associated with the given tensor, or, if it does not exist,
758    /// insert a tensor of zeroes, with the same shape and type as the given tensors and return it
759    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
760        use std::collections::hash_map::Entry;
761        let grad = match self.0.entry(tensor.id()) {
762            Entry::Occupied(entry) => entry.into_mut(),
763            Entry::Vacant(entry) => {
764                let grad = tensor.zeros_like()?;
765                entry.insert(grad)
766            }
767        };
768        Ok(grad)
769    }
770
771    /// Get the tensor ids of the stored gradient tensors
772    pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
773        self.0.keys()
774    }
775}