use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;

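// `node` is the result of reducing `arg` over `reduced_dims`. Broadcast it back to the shape of
// `arg`, covering both the keepdim case (same rank, reduced dims kept with size 1) and the case
// where the reduced dims were squeezed out and need to be reintroduced before broadcasting.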
fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
    if arg.rank() == node.rank() {
        node.broadcast_as(arg.shape())
    } else {
        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
    }
}

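// When CANDLE_GRAD_DO_NOT_DETACH is set to a non-empty value other than "0", the gradients
// computed during the backward pass are kept attached to the computation graph rather than
// detached.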
thread_local! {
    static CANDLE_GRAD_DO_NOT_DETACH: bool = {
        match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}

impl Tensor {
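    /// Return the nodes of the computation graph leading to this tensor that track gradients,
    /// in a topologically sorted vec: each node appears before the nodes it depends on, so
    /// `self` comes first whenever it tracks gradients.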
    pub fn sorted_nodes(&self) -> Vec<&Tensor> {
        fn walk<'a>(
            node: &'a Tensor,
            nodes: Vec<&'a Tensor>,
            already_seen: &mut HashMap<TensorId, bool>,
        ) -> (bool, Vec<&'a Tensor>) {
            if let Some(&tg) = already_seen.get(&node.id()) {
                return (tg, nodes);
            }
            let mut track_grad = false;
            let mut nodes = if node.is_variable() {
                track_grad = true;
                nodes
            } else if node.dtype().is_int() {
                nodes
            } else if let Some(op) = node.op() {
                match op {
                    Op::IndexAdd(t1, t2, t3, _)
                    | Op::Scatter(t1, t2, t3, _)
                    | Op::ScatterAdd(t1, t2, t3, _)
                    | Op::CustomOp3(t1, t2, t3, _)
                    | Op::WhereCond(t1, t2, t3) => {
                        let (tg, nodes) = walk(t1, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(t2, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(t3, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::Conv1D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::ConvTranspose1D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::Conv2D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::ConvTranspose2D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::CustomOp2(lhs, rhs, _)
                    | Op::Binary(lhs, rhs, _)
                    | Op::Gather(lhs, rhs, _)
                    | Op::IndexSelect(lhs, rhs, _)
                    | Op::Matmul(lhs, rhs)
                    | Op::SliceScatter0(lhs, rhs, _) => {
                        let (tg, nodes) = walk(lhs, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(rhs, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
                        let (tg, nodes) = walk(arg, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }),
                    Op::Affine { arg, mul, .. } => {
                        if *mul == 0. {
                            nodes
                        } else {
                            let (tg, nodes) = walk(arg, nodes, already_seen);
                            track_grad |= tg;
                            nodes
                        }
                    }
                    Op::Unary(_node, UnaryOp::Ceil)
                    | Op::Unary(_node, UnaryOp::Floor)
                    | Op::Unary(_node, UnaryOp::Round)
                    | Op::Unary(_node, UnaryOp::Sign) => nodes,
                    Op::Reshape(node)
                    | Op::UpsampleNearest1D { arg: node, .. }
                    | Op::UpsampleNearest2D { arg: node, .. }
                    | Op::UpsampleBilinear2D { arg: node, .. }
                    | Op::AvgPool2D { arg: node, .. }
                    | Op::MaxPool2D { arg: node, .. }
                    | Op::Copy(node)
                    | Op::Broadcast(node)
                    | Op::Cmp(node, _)
                    | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
                    | Op::ToDevice(node)
                    | Op::Transpose(node, _, _)
                    | Op::Permute(node, _)
                    | Op::Narrow(node, _, _, _)
                    | Op::Unary(node, _)
                    | Op::Elu(node, _)
                    | Op::Powf(node, _)
                    | Op::CustomOp1(node, _) => {
                        let (tg, nodes) = walk(node, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::ToDType(node) => {
                        if node.dtype().is_float() {
                            let (tg, nodes) = walk(node, nodes, already_seen);
                            track_grad |= tg;
                            nodes
                        } else {
                            nodes
                        }
                    }
                    Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
                }
            } else {
                nodes
            };
            already_seen.insert(node.id(), track_grad);
            if track_grad {
                nodes.push(node);
            }
            (track_grad, nodes)
        }
        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
        nodes.reverse();
        nodes
    }

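    /// Run the backward pass over the computation graph ending at this tensor, returning a
    /// [`GradStore`] that maps each gradient-tracking node (in particular the variables) to
    /// its gradient. The gradient of `self` is initialized to ones.
    ///
    /// A minimal usage sketch, assuming a scalar `Var` built with this crate's public API:
    ///
    /// ```ignore
    /// let x = Var::new(3f32, &Device::Cpu)?;
    /// let loss = x.sqr()?; // loss = x^2
    /// let grads = loss.backward()?;
    /// let dx = grads.get(&x); // Some(tensor holding 2 * x = 6)
    /// ```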
    pub fn backward(&self) -> Result<GradStore> {
        let sorted_nodes = self.sorted_nodes();
        let mut grads = GradStore::new();
        grads.insert(self, self.ones_like()?.contiguous()?);
        for node in sorted_nodes.iter() {
            if node.is_variable() {
                continue;
            }
            let grad = grads
                .remove(node)
                .expect("candle internal error - grad not populated");
            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
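            // Unless CANDLE_GRAD_DO_NOT_DETACH is set, detach the gradient so that the backward
            // pass itself is not recorded in the computation graph (second-order derivatives are
            // not supported here).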
            let grad = if do_not_detach { grad } else { grad.detach() };
            if let Some(op) = node.op() {
                match op {
                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Sub) => {
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Mul) => {
                        let lhs_grad = grad.mul(rhs)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                        let rhs_grad = grad.mul(lhs)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Div) => {
                        let lhs_grad = grad.div(rhs)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Minimum)
                    | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
                        let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
                        let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;

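                        // If both masks are 1 at the same position (lhs == rhs == node), dividing
                        // by 1 + other_mask scales the gradient by 0.5 so it is shared between the
                        // two arguments.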
                        let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;

                        let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::WhereCond(pred, t, f) => {
                        let zeros = grad.zeros_like()?;
                        let t_sum_grad = grads.or_insert(t)?;
                        let t_grad = pred.where_cond(&grad, &zeros)?;
                        *t_sum_grad = t_sum_grad.add(&t_grad)?;
                        let f_sum_grad = grads.or_insert(f)?;
                        let f_grad = pred.where_cond(&zeros, &grad)?;
                        *f_sum_grad = f_sum_grad.add(&f_grad)?;
                    }
                    Op::Conv1D {
                        arg,
                        kernel,
                        padding,
                        stride,
                        dilation,
                    } => {
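                        // The gradient w.r.t. the input is a transposed convolution of the output
                        // gradient with the same kernel; out_padding accounts for input positions
                        // that the transposed convolution would otherwise not recover.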
                        let grad_l_in = grad.dim(2)?;
                        let k_size = kernel.dim(2)?;
                        let out_size =
                            (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
                        let out_padding = arg.dim(2)? - out_size;
                        let grad_arg = grad.conv_transpose1d(
                            kernel,
                            *padding,
                            out_padding,
                            *stride,
                            *dilation,
                            1,
                        )?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;

                        let grad_kernel = arg
                            .transpose(0, 1)?
                            .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                            .transpose(0, 1)?;
                        let sum_grad = grads.or_insert(kernel)?;
                        let (_, _, k0) = kernel.dims3()?;
                        let (_, _, g_k0) = grad_kernel.dims3()?;
                        let grad_kernel = if g_k0 != k0 {
                            grad_kernel.narrow(2, 0, k0)?
                        } else {
                            grad_kernel
                        };
                        *sum_grad = sum_grad.add(&grad_kernel)?;
                    }
                    Op::Conv2D {
                        arg,
                        kernel,
                        padding,
                        stride,
                        dilation,
                    } => {
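                        // Same scheme as the 1d case: a transposed convolution recovers the input
                        // gradient, and convolving the transposed tensors gives the kernel gradient.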
                        let grad_h = grad.dim(2)?;
                        let k_h = kernel.dim(2)?;
                        let out_size =
                            (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
                        let out_padding = arg.dim(2)? - out_size;
                        let grad_arg = grad.conv_transpose2d(
                            kernel,
                            *padding,
                            out_padding,
                            *stride,
                            *dilation,
                        )?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;

                        let grad_kernel = arg
                            .transpose(0, 1)?
                            .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                            .transpose(0, 1)?;
                        let sum_grad = grads.or_insert(kernel)?;
                        let (_, _, k0, k1) = kernel.dims4()?;
                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
                        } else {
                            grad_kernel
                        };
                        *sum_grad = sum_grad.add(&grad_kernel)?;
                    }
                    Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
                        op: "conv-transpose1d",
                    })?,
                    Op::ConvTranspose2D {
                        arg,
                        kernel,
                        padding,
                        stride,
                        dilation,
                        output_padding: _output_padding,
                    } => {
                        let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;

                        let grad_kernel = grad
                            .transpose(0, 1)?
                            .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                            .transpose(0, 1)?;
                        let sum_grad = grads.or_insert(kernel)?;
                        let (_, _, k0, k1) = kernel.dims4()?;
                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
                        } else {
                            grad_kernel
                        };
                        *sum_grad = sum_grad.add(&grad_kernel)?;
                    }
                    Op::AvgPool2D {
                        arg,
                        kernel_size,
                        stride,
                    } => {
                        if kernel_size != stride {
                            crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
                        }
                        let (_n, _c, h, w) = arg.dims4()?;
                        let grad_arg = grad.upsample_nearest2d(h, w)?;
                        let grad_arg =
                            (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;
                    }
                    Op::MaxPool2D {
                        arg,
                        kernel_size,
                        stride,
                    } => {
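                        // Only positions matching the pooled maximum receive gradient: build a mask
                        // by comparing `arg` with the upsampled pooled output, then use an average
                        // pool of that mask to scale the gradient before spreading it back.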
                        if kernel_size != stride {
                            crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
                        }
                        let (_n, _c, h, w) = arg.dims4()?;
                        let node_upsampled = node.upsample_nearest2d(h, w)?;
                        let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
                        let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
                        let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;
                    }
                    Op::UpsampleNearest1D { arg, target_size } => {
                        let (_n, c, size) = arg.dims3()?;
                        if target_size % size != 0 {
                            crate::bail!("backward not supported for non integer upscaling factors")
                        }
                        let scale = target_size / size;

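                        // The adjoint of nearest upsampling sums the gradient over each block of
                        // `scale` elements; this is done with a grouped conv1d using an all-ones
                        // kernel and a stride of `scale`.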
                        let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
                        let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = conv_sum;
                    }
                    Op::UpsampleNearest2D {
                        arg,
                        target_h,
                        target_w,
                    } => {
                        let (_n, c, h, w) = arg.dims4()?;
                        if target_h % h != 0 || target_w % w != 0 {
                            crate::bail!("backward not supported for non integer upscaling factors")
                        }
                        let scale_h = target_h / h;
                        let scale_w = target_w / w;

                        if scale_h != scale_w {
                            crate::bail!("backward not supported for non uniform upscaling factors")
                        };
                        let kernel =
                            Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
                        let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = conv_sum;
                    }
                    Op::UpsampleBilinear2D { .. } => {
                        crate::bail!("backward not supported for upsample_bilinear2d")
                    }
                    Op::SliceScatter0(lhs, rhs, start_rhs) => {
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;

                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
                    }
                    Op::Gather(arg, indexes, dim) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
                    }
                    Op::Scatter(init, indexes, src, dim) => {
                        // Scatter overwrites the indexed positions, so the gradient w.r.t. `init`
                        // only flows to the positions that were not written to.
                        let init_sum_grad = grads.or_insert(init)?;
                        let mask = init.ones_like()?;
                        let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
                        *init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;

                        let src_grad = grad.gather(indexes, *dim)?;
                        let src_sum_grad = grads.or_insert(src)?;
                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
                    }
                    Op::ScatterAdd(init, indexes, src, dim) => {
                        // Scatter-add only adds to `init`, so its gradient is the incoming
                        // gradient, unmasked.
                        let init_sum_grad = grads.or_insert(init)?;
                        *init_sum_grad = init_sum_grad.add(&grad)?;

                        let src_grad = grad.gather(indexes, *dim)?;
                        let src_sum_grad = grads.or_insert(src)?;
                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
                    }
                    Op::IndexAdd(init, indexes, src, dim) => {
                        let init_sum_grad = grads.or_insert(init)?;
                        *init_sum_grad = init_sum_grad.add(&grad)?;

                        let src_grad = grad.index_select(indexes, *dim)?;
                        let src_sum_grad = grads.or_insert(src)?;
                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
                    }
                    Op::IndexSelect(arg, indexes, dim) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
                    }
                    Op::Matmul(lhs, rhs) => {
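                        // For c = lhs @ rhs: d_lhs = d_c @ rhs^T and d_rhs = lhs^T @ d_c.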
                        let lhs_grad = grad.matmul(&rhs.t()?)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;

                        let rhs_grad = lhs.t()?.matmul(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::Cat(args, dim) => {
                        let mut start_idx = 0;
                        for arg in args {
                            let len = arg.dims()[*dim];
                            let arg_grad = grad.narrow(*dim, start_idx, len)?;
                            let sum_grad = grads.or_insert(arg)?;
                            *sum_grad = sum_grad.add(&arg_grad)?;
                            start_idx += len;
                        }
                    }
                    Op::Broadcast(arg) => {
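                        // Undo the broadcast by summing the gradient over the dimensions that were
                        // added on the left as well as the ones that were expanded from size 1.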
                        let arg_dims = arg.dims();
                        let node_dims = node.dims();
                        let left_dims = node_dims.len() - arg_dims.len();
                        let mut sum_dims: Vec<usize> = (0..left_dims).collect();
                        for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
                            .iter()
                            .zip(arg_dims.iter())
                            .enumerate()
                        {
                            if node_dim != arg_dim {
                                sum_dims.push(dim + left_dims)
                            }
                        }

                        let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
                        for _i in 0..left_dims {
                            arg_grad = arg_grad.squeeze(0)?
                        }
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad)?;
                    }
                    Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
                        let node = broadcast_back(arg, node, reduced_dims)?;
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
                        let node = broadcast_back(arg, node, reduced_dims)?;
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::ToDType(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
                    }
                    Op::Copy(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad)?
                    }
                    Op::Affine { arg, mul, .. } => {
                        let arg_grad = grad.affine(*mul, 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(arg, UnaryOp::Log) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(grad / arg)?)?
                    }
                    Op::Unary(arg, UnaryOp::Sin) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
                    }
                    Op::Unary(arg, UnaryOp::Cos) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                    }
                    Op::Unary(arg, UnaryOp::Tanh) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let minus_dtanh = (node.sqr()? - 1.)?;
                        *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
                    }
                    Op::Unary(arg, UnaryOp::Abs) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let ones = arg.ones_like()?;
                        let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
                        *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
                    }
                    Op::Unary(arg, UnaryOp::Exp) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
                    }
                    Op::Unary(arg, UnaryOp::Neg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.sub(&grad)?
                    }
                    Op::Unary(arg, UnaryOp::Recip) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let grad = (grad / arg.sqr()?)?;
                        *sum_grad = sum_grad.sub(&grad)?
                    }
                    &Op::Narrow(ref arg, dim, start_idx, len) => {
                        let arg_dims = arg.dims();
                        let left_pad = if start_idx == 0 {
                            None
                        } else {
                            let mut dims = arg_dims.to_vec();
                            dims[dim] = start_idx;
                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
                        };
                        let right_pad = arg_dims[dim] - start_idx - len;
                        let right_pad = if right_pad == 0 {
                            None
                        } else {
                            let mut dims = arg_dims.to_vec();
                            dims[dim] = right_pad;
                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
                        };
                        let arg_grad = match (left_pad, right_pad) {
                            (None, None) => grad,
                            (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
                            (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
                            (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
                        };
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(_, UnaryOp::Floor)
                    | Op::Unary(_, UnaryOp::Round)
                    | Op::Reduce(_, ReduceOp::ArgMin, _)
                    | Op::Reduce(_, ReduceOp::ArgMax, _)
                    | Op::Unary(_, UnaryOp::Sign)
                    | Op::Cmp(_, _) => {}
                    Op::Reshape(arg) => {
                        let arg_grad = grad.reshape(arg.dims())?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
                    Op::Unary(arg, UnaryOp::Gelu) => {
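                        // Derivative of the tanh approximation of gelu:
                        // 0.5 * tanh(u) + (0.0535161 x^3 + 0.398942 x) * sech^2(u) + 0.5
                        // with u = 0.0356774 x^3 + 0.797885 x.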
                        let sum_grad = grads.or_insert(arg)?;
                        let cube = arg.powf(3.)?;
                        let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
                        let gelu_grad = (((0.5 * &tanh)?
                            + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
                            + 0.5)?;
                        *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
                    }
                    Op::Unary(arg, UnaryOp::Erf) => {
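                        // d/dx erf(x) = 2 / sqrt(pi) * e^(-x^2)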
                        let sum_grad = grads.or_insert(arg)?;
                        let erf_grad =
                            (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
                        *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
                    }
                    Op::Unary(arg, UnaryOp::GeluErf) => {
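                        // d/dx gelu_erf(x) = 0.5 + 0.398942 * x * e^(-x^2 / 2) + 0.5 * erf(x / sqrt(2))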
                        let sum_grad = grads.or_insert(arg)?;
                        let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
                        let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
                        let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
                        let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
                        let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
                        *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
                    }
                    Op::Unary(arg, UnaryOp::Relu) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                        *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                    }
                    Op::Unary(arg, UnaryOp::Silu) => {
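                        // d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
                        //             = sigmoid(x) * (1 - silu(x)) + silu(x), with node = silu(x).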
                        let sum_grad = grads.or_insert(arg)?;
                        let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
                        let silu_grad = &sigmoid_arg * (1. - *node) + *node;
                        *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
                    }
                    Op::Elu(arg, alpha) => {
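                        // d/dx elu(x) = 1 for x > 0, alpha * e^x = elu(x) + alpha for x <= 0.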
                        let sum_grad = grads.or_insert(arg)?;
                        let zeros = arg.zeros_like()?;
                        let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
                        let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
                        let negative_exp_mask = (negative_mask * (*node + *alpha))?;
                        let combined_mask = (positive_mask + negative_exp_mask)?;
                        *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
                    }
                    Op::Powf(arg, e) => {
                        let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::CustomOp1(arg, c) => {
                        if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
                            let sum_grad = grads.or_insert(arg)?;
                            *sum_grad = sum_grad.add(&arg_grad)?
                        }
                    }
                    Op::CustomOp2(arg1, arg2, c) => {
                        let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
                        if let Some(arg_grad1) = arg_grad1 {
                            let sum_grad = grads.or_insert(arg1)?;
                            *sum_grad = sum_grad.add(&arg_grad1)?
                        }
                        if let Some(arg_grad2) = arg_grad2 {
                            let sum_grad = grads.or_insert(arg2)?;
                            *sum_grad = sum_grad.add(&arg_grad2)?
                        }
                    }
                    Op::CustomOp3(arg1, arg2, arg3, c) => {
                        let (arg_grad1, arg_grad2, arg_grad3) =
                            c.bwd(arg1, arg2, arg3, node, &grad)?;
                        if let Some(arg_grad1) = arg_grad1 {
                            let sum_grad = grads.or_insert(arg1)?;
                            *sum_grad = sum_grad.add(&arg_grad1)?
                        }
                        if let Some(arg_grad2) = arg_grad2 {
                            let sum_grad = grads.or_insert(arg2)?;
                            *sum_grad = sum_grad.add(&arg_grad2)?
                        }
                        if let Some(arg_grad3) = arg_grad3 {
                            let sum_grad = grads.or_insert(arg3)?;
                            *sum_grad = sum_grad.add(&arg_grad3)?
                        }
                    }
                    Op::Unary(arg, UnaryOp::Sqr) => {
                        let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(arg, UnaryOp::Sqrt) => {
                        let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::ToDevice(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let arg_grad = grad.to_device(sum_grad.device())?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Transpose(arg, dim1, dim2) => {
                        let arg_grad = grad.transpose(*dim1, *dim2)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Permute(arg, dims) => {
                        let mut inv_dims = vec![0; dims.len()];
                        for (i, &dim_idx) in dims.iter().enumerate() {
                            inv_dims[dim_idx] = i
                        }
                        let arg_grad = grad.permute(inv_dims)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                };
            }
        }
        Ok(grads)
    }
}

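/// A store mapping tensor ids to their gradient, as populated by `Tensor::backward`.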
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);

impl GradStore {
    /// Create a new, empty gradient store.
    fn new() -> Self {
        GradStore(HashMap::new())
    }

    /// Get the gradient tensor associated with the given tensor id.
    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
        self.0.get(&id)
    }

    /// Get the gradient tensor associated with the given tensor.
    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
        self.0.get(&tensor.id())
    }

    /// Remove and return the gradient tensor associated with the given tensor, if any.
    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
        self.0.remove(&tensor.id())
    }

    /// Insert a gradient tensor for the given tensor, returning the previous gradient if any.
    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
        self.0.insert(tensor.id(), grad)
    }

    /// Insert a gradient tensor for the given tensor id, returning the previous gradient if any.
    pub fn insert_id(&mut self, id: TensorId, grad: Tensor) -> Option<Tensor> {
        self.0.insert(id, grad)
    }

    /// Return a mutable reference to the gradient of the given tensor, inserting a zero-filled
    /// gradient if none is present yet.
    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
        use std::collections::hash_map::Entry;
        let grad = match self.0.entry(tensor.id()) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => {
                let grad = tensor.zeros_like()?;
                entry.insert(grad)
            }
        };
        Ok(grad)
    }

    /// Return the tensor ids for which a gradient is stored.
    pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
        self.0.keys()
    }
}