1use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
3use crate::{Error, Result, Tensor, TensorId};
4use std::collections::HashMap;
5
6fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
9 if arg.rank() == node.rank() {
10 node.broadcast_as(arg.shape())
12 } else {
13 node.reshape(reduced_dims)?.broadcast_as(arg.shape())
16 }
17}
18
thread_local! {
    // Per-thread flag, computed on first access in each thread from the
    // `CANDLE_GRAD_DO_NOT_DETACH` environment variable. It is `true` only when the variable
    // is set to a non-empty value other than "0"; an unset or unreadable variable gives
    // `false`.
    static CANDLE_GRAD_DO_NOT_DETACH: bool = std::env::var("CANDLE_GRAD_DO_NOT_DETACH")
        .map(|value| !value.is_empty() && value != "0")
        .unwrap_or(false);
}
29
30impl Tensor {
31 pub fn sorted_nodes(&self) -> Vec<&Tensor> {
36 fn walk<'a>(
39 node: &'a Tensor,
40 nodes: Vec<&'a Tensor>,
41 already_seen: &mut HashMap<TensorId, bool>,
42 ) -> (bool, Vec<&'a Tensor>) {
43 if let Some(&tg) = already_seen.get(&node.id()) {
44 return (tg, nodes);
45 }
46 let mut track_grad = false;
47 let mut nodes = if node.is_variable() {
48 track_grad = true;
50 nodes
51 } else if node.dtype().is_int() {
52 nodes
53 } else if let Some(op) = node.op() {
54 match op {
55 Op::IndexAdd(t1, t2, t3, _)
56 | Op::Scatter(t1, t2, t3, _)
57 | Op::ScatterAdd(t1, t2, t3, _)
58 | Op::CustomOp3(t1, t2, t3, _)
59 | Op::WhereCond(t1, t2, t3) => {
60 let (tg, nodes) = walk(t1, nodes, already_seen);
61 track_grad |= tg;
62 let (tg, nodes) = walk(t2, nodes, already_seen);
63 track_grad |= tg;
64 let (tg, nodes) = walk(t3, nodes, already_seen);
65 track_grad |= tg;
66 nodes
67 }
68 Op::Conv1D {
69 arg: lhs,
70 kernel: rhs,
71 ..
72 }
73 | Op::ConvTranspose1D {
74 arg: lhs,
75 kernel: rhs,
76 ..
77 }
78 | Op::Conv2D {
79 arg: lhs,
80 kernel: rhs,
81 ..
82 }
83 | Op::ConvTranspose2D {
84 arg: lhs,
85 kernel: rhs,
86 ..
87 }
88 | Op::CustomOp2(lhs, rhs, _)
89 | Op::Binary(lhs, rhs, _)
90 | Op::Gather(lhs, rhs, _)
91 | Op::IndexSelect(lhs, rhs, _)
92 | Op::Matmul(lhs, rhs)
93 | Op::SliceScatter0(lhs, rhs, _) => {
94 let (tg, nodes) = walk(lhs, nodes, already_seen);
95 track_grad |= tg;
96 let (tg, nodes) = walk(rhs, nodes, already_seen);
97 track_grad |= tg;
98 nodes
99 }
100 Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
101 let (tg, nodes) = walk(arg, nodes, already_seen);
102 track_grad |= tg;
103 nodes
104 }),
105 Op::Affine { arg, mul, .. } => {
106 if *mul == 0. {
107 nodes
108 } else {
109 let (tg, nodes) = walk(arg, nodes, already_seen);
110 track_grad |= tg;
111 nodes
112 }
113 }
114 Op::Unary(_node, UnaryOp::Ceil)
115 | Op::Unary(_node, UnaryOp::Floor)
116 | Op::Unary(_node, UnaryOp::Round)
117 | Op::Unary(_node, UnaryOp::Sign) => nodes,
118 Op::Reshape(node)
119 | Op::UpsampleNearest1D { arg: node, .. }
120 | Op::UpsampleNearest2D { arg: node, .. }
121 | Op::AvgPool2D { arg: node, .. }
122 | Op::MaxPool2D { arg: node, .. }
123 | Op::Copy(node)
124 | Op::Broadcast(node)
125 | Op::Cmp(node, _)
126 | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
127 | Op::ToDevice(node)
128 | Op::Transpose(node, _, _)
129 | Op::Permute(node, _)
130 | Op::Narrow(node, _, _, _)
131 | Op::Unary(node, _)
132 | Op::Elu(node, _)
133 | Op::Powf(node, _)
134 | Op::CustomOp1(node, _) => {
135 let (tg, nodes) = walk(node, nodes, already_seen);
136 track_grad |= tg;
137 nodes
138 }
139 Op::ToDType(node) => {
140 if node.dtype().is_float() {
141 let (tg, nodes) = walk(node, nodes, already_seen);
142 track_grad |= tg;
143 nodes
144 } else {
145 nodes
146 }
147 }
148 Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
149 }
150 } else {
151 nodes
152 };
153 already_seen.insert(node.id(), track_grad);
154 if track_grad {
155 nodes.push(node);
156 }
157 (track_grad, nodes)
158 }
159 let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
160 nodes.reverse();
161 nodes
162 }
163
164 pub fn backward(&self) -> Result<GradStore> {
165 let sorted_nodes = self.sorted_nodes();
166 let mut grads = GradStore::new();
167 grads.insert(self, self.ones_like()?.contiguous()?);
168 for node in sorted_nodes.iter() {
169 if node.is_variable() {
170 continue;
171 }
172 let grad = grads
173 .remove(node)
174 .expect("candle internal error - grad not populated");
175 let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
181 let grad = if do_not_detach { grad } else { grad.detach() };
182 if let Some(op) = node.op() {
183 match op {
184 Op::Binary(lhs, rhs, BinaryOp::Add) => {
185 let lhs_sum_grad = grads.or_insert(lhs)?;
186 *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
187 let rhs_sum_grad = grads.or_insert(rhs)?;
188 *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
189 }
190 Op::Binary(lhs, rhs, BinaryOp::Sub) => {
191 let lhs_sum_grad = grads.or_insert(lhs)?;
192 *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
193 let rhs_sum_grad = grads.or_insert(rhs)?;
194 *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
195 }
196 Op::Binary(lhs, rhs, BinaryOp::Mul) => {
197 let lhs_grad = grad.mul(rhs)?;
198 let lhs_sum_grad = grads.or_insert(lhs)?;
199 *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
200 let rhs_grad = grad.mul(lhs)?;
201 let rhs_sum_grad = grads.or_insert(rhs)?;
202 *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
203 }
204 Op::Binary(lhs, rhs, BinaryOp::Div) => {
205 let lhs_grad = grad.div(rhs)?;
206 let lhs_sum_grad = grads.or_insert(lhs)?;
207 *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
208 let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
209 let rhs_sum_grad = grads.or_insert(rhs)?;
210 *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
211 }
212 Op::Binary(lhs, rhs, BinaryOp::Minimum)
213 | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
214 let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
215 let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;
216
217 let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
220 let lhs_sum_grad = grads.or_insert(lhs)?;
221 *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
222
223 let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
224 let rhs_sum_grad = grads.or_insert(rhs)?;
225 *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
226 }
227 Op::WhereCond(pred, t, f) => {
228 let zeros = grad.zeros_like()?;
229 let t_sum_grad = grads.or_insert(t)?;
230 let t_grad = pred.where_cond(&grad, &zeros)?;
231 *t_sum_grad = t_sum_grad.add(&t_grad)?;
232 let f_sum_grad = grads.or_insert(f)?;
233 let f_grad = pred.where_cond(&zeros, &grad)?;
234 *f_sum_grad = f_sum_grad.add(&f_grad)?;
235 }
236 Op::Conv1D {
237 arg,
238 kernel,
239 padding,
240 stride,
241 dilation,
242 } => {
243 let grad_l_in = grad.dim(2)?;
246 let k_size = kernel.dim(2)?;
247 let out_size =
248 (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
249 let out_padding = arg.dim(2)? - out_size;
250 let grad_arg = grad.conv_transpose1d(
251 kernel,
252 *padding,
253 out_padding,
254 *stride,
255 *dilation,
256 1,
257 )?;
258 let sum_grad = grads.or_insert(arg)?;
259 *sum_grad = sum_grad.add(&grad_arg)?;
260
261 let grad_kernel = arg
262 .transpose(0, 1)?
263 .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
264 .transpose(0, 1)?;
265 let sum_grad = grads.or_insert(kernel)?;
266 let (_, _, k0) = kernel.dims3()?;
267 let (_, _, g_k0) = grad_kernel.dims3()?;
268 let grad_kernel = if g_k0 != k0 {
269 grad_kernel.narrow(2, 0, k0)?
270 } else {
271 grad_kernel
272 };
273 *sum_grad = sum_grad.add(&grad_kernel)?;
274 }
275 Op::Conv2D {
276 arg,
277 kernel,
278 padding,
279 stride,
280 dilation,
281 } => {
282 let grad_h = grad.dim(2)?;
285 let k_h = kernel.dim(2)?;
286 let out_size =
287 (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
288 let out_padding = arg.dim(2)? - out_size;
289 let grad_arg = grad.conv_transpose2d(
290 kernel,
291 *padding,
292 out_padding,
293 *stride,
294 *dilation,
295 )?;
296 let sum_grad = grads.or_insert(arg)?;
297 *sum_grad = sum_grad.add(&grad_arg)?;
298
299 let grad_kernel = arg
300 .transpose(0, 1)?
301 .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
302 .transpose(0, 1)?;
303 let sum_grad = grads.or_insert(kernel)?;
304 let (_, _, k0, k1) = kernel.dims4()?;
305 let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
306 let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
307 grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
308 } else {
309 grad_kernel
310 };
311 *sum_grad = sum_grad.add(&grad_kernel)?;
312 }
313 Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
314 op: "conv-transpose1d",
315 })?,
316 Op::ConvTranspose2D {
317 arg,
318 kernel,
319 padding,
320 stride,
321 dilation,
322 output_padding: _output_padding,
323 } => {
324 let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
325 let sum_grad = grads.or_insert(arg)?;
326 *sum_grad = sum_grad.add(&grad_arg)?;
327
328 let grad_kernel = grad
329 .transpose(0, 1)?
330 .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
331 .transpose(0, 1)?;
332 let sum_grad = grads.or_insert(kernel)?;
333 let (_, _, k0, k1) = kernel.dims4()?;
334 let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
335 let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
336 grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
337 } else {
338 grad_kernel
339 };
340 *sum_grad = sum_grad.add(&grad_kernel)?;
341 }
342 Op::AvgPool2D {
343 arg,
344 kernel_size,
345 stride,
346 } => {
347 if kernel_size != stride {
348 crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
349 }
350 let (_n, _c, h, w) = arg.dims4()?;
351 let grad_arg = grad.upsample_nearest2d(h, w)?;
352 let grad_arg =
353 (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
354 let sum_grad = grads.or_insert(arg)?;
355 *sum_grad = sum_grad.add(&grad_arg)?;
356 }
357 Op::MaxPool2D {
358 arg,
359 kernel_size,
360 stride,
361 } => {
362 if kernel_size != stride {
363 crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
364 }
365 let (_n, _c, h, w) = arg.dims4()?;
366 let node_upsampled = node.upsample_nearest2d(h, w)?;
371 let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
372 let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
373 let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
374 let sum_grad = grads.or_insert(arg)?;
375 *sum_grad = sum_grad.add(&grad_arg)?;
376 }
377 Op::UpsampleNearest1D { arg, target_size } => {
378 let (_n, c, size) = arg.dims3()?;
379 if target_size % size != 0 {
380 crate::bail!("backward not supported for non integer upscaling factors")
381 }
382 let scale = target_size / size;
383
384 let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
385 let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
386 let sum_grad = grads.or_insert(arg)?;
387 *sum_grad = conv_sum;
388 }
389 Op::UpsampleNearest2D {
390 arg,
391 target_h,
392 target_w,
393 } => {
394 let (_n, c, h, w) = arg.dims4()?;
395 if target_h % h != 0 || target_w % w != 0 {
396 crate::bail!("backward not supported for non integer upscaling factors")
397 }
398 let scale_h = target_h / h;
399 let scale_w = target_w / w;
400
401 if scale_h != scale_w {
402 crate::bail!("backward not supported for non uniform upscaling factors")
403 };
404 let kernel =
405 Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
406 let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
407 let sum_grad = grads.or_insert(arg)?;
408 *sum_grad = conv_sum;
409 }
410 Op::SliceScatter0(lhs, rhs, start_rhs) => {
411 let rhs_sum_grad = grads.or_insert(rhs)?;
412 let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
413 *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
414
415 let lhs_sum_grad = grads.or_insert(lhs)?;
416 let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
417 *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
418 }
419 Op::Gather(arg, indexes, dim) => {
420 let sum_grad = grads.or_insert(arg)?;
421 *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
422 }
423 Op::Scatter(init, indexes, src, dim) => {
424 let init_sum_grad = grads.or_insert(init)?;
425 *init_sum_grad = init_sum_grad.add(&grad)?;
426
427 let src_grad = grad.gather(indexes, *dim)?;
428 let src_sum_grad = grads.or_insert(src)?;
429 *src_sum_grad = src_sum_grad.add(&src_grad)?;
430 }
431 Op::ScatterAdd(init, indexes, src, dim) => {
432 let init_sum_grad = grads.or_insert(init)?;
433 let mask = init.ones_like()?;
434 let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
435 *init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
436
437 let src_grad = grad.gather(indexes, *dim)?;
438 let src_sum_grad = grads.or_insert(src)?;
439 *src_sum_grad = src_sum_grad.add(&src_grad)?;
440 }
441 Op::IndexAdd(init, indexes, src, dim) => {
442 let init_sum_grad = grads.or_insert(init)?;
443 *init_sum_grad = init_sum_grad.add(&grad)?;
444
445 let src_grad = grad.index_select(indexes, *dim)?;
446 let src_sum_grad = grads.or_insert(src)?;
447 *src_sum_grad = src_sum_grad.add(&src_grad)?;
448 }
449 Op::IndexSelect(arg, indexes, dim) => {
450 let sum_grad = grads.or_insert(arg)?;
451 *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
452 }
453 Op::Matmul(lhs, rhs) => {
454 let lhs_grad = grad.matmul(&rhs.t()?)?;
458 let lhs_sum_grad = grads.or_insert(lhs)?;
459 *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
460
461 let rhs_grad = lhs.t()?.matmul(&grad)?;
462 let rhs_sum_grad = grads.or_insert(rhs)?;
463 *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
464 }
465 Op::Cat(args, dim) => {
466 let mut start_idx = 0;
467 for arg in args {
468 let len = arg.dims()[*dim];
469 let arg_grad = grad.narrow(*dim, start_idx, len)?;
470 let sum_grad = grads.or_insert(arg)?;
471 *sum_grad = sum_grad.add(&arg_grad)?;
472 start_idx += len;
473 }
474 }
475 Op::Broadcast(arg) => {
476 let arg_dims = arg.dims();
477 let node_dims = node.dims();
478 let left_dims = node_dims.len() - arg_dims.len();
480 let mut sum_dims: Vec<usize> = (0..left_dims).collect();
481 for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
482 .iter()
483 .zip(arg_dims.iter())
484 .enumerate()
485 {
486 if node_dim != arg_dim {
487 sum_dims.push(dim + left_dims)
488 }
489 }
490
491 let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
492 for _i in 0..left_dims {
493 arg_grad = arg_grad.squeeze(0)?
494 }
495 let sum_grad = grads.or_insert(arg)?;
496 *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
497 }
498 Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
499 let grad = broadcast_back(arg, &grad, reduced_dims)?;
500 let sum_grad = grads.or_insert(arg)?;
501 *sum_grad = sum_grad.add(&grad)?;
502 }
503 Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
504 let node = broadcast_back(arg, node, reduced_dims)?;
505 let grad = broadcast_back(arg, &grad, reduced_dims)?;
506 let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
507 let sum_grad = grads.or_insert(arg)?;
508 *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
509 }
510 Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
511 let node = broadcast_back(arg, node, reduced_dims)?;
512 let grad = broadcast_back(arg, &grad, reduced_dims)?;
513 let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
514 let sum_grad = grads.or_insert(arg)?;
515 *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
516 }
517 Op::ToDType(arg) => {
518 let sum_grad = grads.or_insert(arg)?;
519 *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
520 }
521 Op::Copy(arg) => {
522 let sum_grad = grads.or_insert(arg)?;
523 *sum_grad = sum_grad.add(&grad)?
524 }
525 Op::Affine { arg, mul, .. } => {
526 let arg_grad = grad.affine(*mul, 0.)?;
527 let sum_grad = grads.or_insert(arg)?;
528 *sum_grad = sum_grad.add(&arg_grad)?
529 }
530 Op::Unary(arg, UnaryOp::Log) => {
531 let sum_grad = grads.or_insert(arg)?;
532 *sum_grad = sum_grad.add(&(grad / arg)?)?
533 }
534 Op::Unary(arg, UnaryOp::Sin) => {
535 let sum_grad = grads.or_insert(arg)?;
536 *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
537 }
538 Op::Unary(arg, UnaryOp::Cos) => {
539 let sum_grad = grads.or_insert(arg)?;
540 *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
541 }
542 Op::Unary(arg, UnaryOp::Tanh) => {
543 let sum_grad = grads.or_insert(arg)?;
544 let minus_dtanh = (node.sqr()? - 1.)?;
545 *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
546 }
547 Op::Unary(arg, UnaryOp::Abs) => {
548 let sum_grad = grads.or_insert(arg)?;
549 let ones = arg.ones_like()?;
550 let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
551 *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
552 }
553 Op::Unary(arg, UnaryOp::Exp) => {
554 let sum_grad = grads.or_insert(arg)?;
555 *sum_grad = sum_grad.add(&(&grad * *node)?)?
556 }
557 Op::Unary(arg, UnaryOp::Neg) => {
558 let sum_grad = grads.or_insert(arg)?;
559 *sum_grad = sum_grad.sub(&grad)?
560 }
561 Op::Unary(arg, UnaryOp::Recip) => {
562 let sum_grad = grads.or_insert(arg)?;
563 let grad = (grad / arg.sqr()?)?;
564 *sum_grad = sum_grad.sub(&grad)?
565 }
566 &Op::Narrow(ref arg, dim, start_idx, len) => {
567 let arg_dims = arg.dims();
568 let left_pad = if start_idx == 0 {
569 None
570 } else {
571 let mut dims = arg_dims.to_vec();
572 dims[dim] = start_idx;
573 Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
574 };
575 let right_pad = arg_dims[dim] - start_idx - len;
576 let right_pad = if right_pad == 0 {
577 None
578 } else {
579 let mut dims = arg_dims.to_vec();
580 dims[dim] = right_pad;
581 Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
582 };
583 let arg_grad = match (left_pad, right_pad) {
584 (None, None) => grad,
585 (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
586 (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
587 (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
588 };
589 let sum_grad = grads.or_insert(arg)?;
590 *sum_grad = sum_grad.add(&arg_grad)?
591 }
592 Op::Unary(_, UnaryOp::Floor)
593 | Op::Unary(_, UnaryOp::Round)
594 | Op::Reduce(_, ReduceOp::ArgMin, _)
595 | Op::Reduce(_, ReduceOp::ArgMax, _)
596 | Op::Unary(_, UnaryOp::Sign)
597 | Op::Cmp(_, _) => {}
598 Op::Reshape(arg) => {
599 let arg_grad = grad.reshape(arg.dims())?;
600 let sum_grad = grads.or_insert(arg)?;
601 *sum_grad = sum_grad.add(&arg_grad)?
602 }
603 Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
604 Op::Unary(arg, UnaryOp::Gelu) => {
605 let sum_grad = grads.or_insert(arg)?;
606 let cube = arg.powf(3.)?;
607 let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
608 let gelu_grad = (((0.5 * &tanh)?
609 + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
610 + 0.5)?;
611 *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
612 }
613 Op::Unary(arg, UnaryOp::Erf) => {
614 let sum_grad = grads.or_insert(arg)?;
615 let erf_grad =
617 (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
618 *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
619 }
620 Op::Unary(arg, UnaryOp::GeluErf) => {
621 let sum_grad = grads.or_insert(arg)?;
622 let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
624 let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
625 let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
626 let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
627 let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
628 *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
629 }
630 Op::Unary(arg, UnaryOp::Relu) => {
631 let sum_grad = grads.or_insert(arg)?;
632 let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
633 *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
634 }
635 Op::Unary(arg, UnaryOp::Silu) => {
636 let sum_grad = grads.or_insert(arg)?;
637 let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
639 let silu_grad = &sigmoid_arg * (1. - *node) + *node;
640 *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
641 }
642 Op::Elu(arg, alpha) => {
643 let sum_grad = grads.or_insert(arg)?;
645 let zeros = arg.zeros_like()?;
646 let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
647 let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
648 let negative_exp_mask = (negative_mask * (*node + *alpha))?;
650 let combined_mask = (positive_mask + negative_exp_mask)?;
651 *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
652 }
653 Op::Powf(arg, e) => {
654 let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
655 let sum_grad = grads.or_insert(arg)?;
656 *sum_grad = sum_grad.add(&arg_grad)?
657 }
658 Op::CustomOp1(arg, c) => {
659 if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
660 let sum_grad = grads.or_insert(arg)?;
661 *sum_grad = sum_grad.add(&arg_grad)?
662 }
663 }
664 Op::CustomOp2(arg1, arg2, c) => {
665 let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
666 if let Some(arg_grad1) = arg_grad1 {
667 let sum_grad = grads.or_insert(arg1)?;
668 *sum_grad = sum_grad.add(&arg_grad1)?
669 }
670 if let Some(arg_grad2) = arg_grad2 {
671 let sum_grad = grads.or_insert(arg2)?;
672 *sum_grad = sum_grad.add(&arg_grad2)?
673 }
674 }
675 Op::CustomOp3(arg1, arg2, arg3, c) => {
676 let (arg_grad1, arg_grad2, arg_grad3) =
677 c.bwd(arg1, arg2, arg3, node, &grad)?;
678 if let Some(arg_grad1) = arg_grad1 {
679 let sum_grad = grads.or_insert(arg1)?;
680 *sum_grad = sum_grad.add(&arg_grad1)?
681 }
682 if let Some(arg_grad2) = arg_grad2 {
683 let sum_grad = grads.or_insert(arg2)?;
684 *sum_grad = sum_grad.add(&arg_grad2)?
685 }
686 if let Some(arg_grad3) = arg_grad3 {
687 let sum_grad = grads.or_insert(arg3)?;
688 *sum_grad = sum_grad.add(&arg_grad3)?
689 }
690 }
691 Op::Unary(arg, UnaryOp::Sqr) => {
692 let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
693 let sum_grad = grads.or_insert(arg)?;
694 *sum_grad = sum_grad.add(&arg_grad)?
695 }
696 Op::Unary(arg, UnaryOp::Sqrt) => {
697 let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
698 let sum_grad = grads.or_insert(arg)?;
699 *sum_grad = sum_grad.add(&arg_grad)?
700 }
701 Op::ToDevice(arg) => {
702 let sum_grad = grads.or_insert(arg)?;
703 let arg_grad = grad.to_device(sum_grad.device())?;
704 *sum_grad = sum_grad.add(&arg_grad)?
705 }
706 Op::Transpose(arg, dim1, dim2) => {
707 let arg_grad = grad.transpose(*dim1, *dim2)?;
708 let sum_grad = grads.or_insert(arg)?;
709 *sum_grad = sum_grad.add(&arg_grad)?
710 }
711 Op::Permute(arg, dims) => {
712 let mut inv_dims = vec![0; dims.len()];
713 for (i, &dim_idx) in dims.iter().enumerate() {
714 inv_dims[dim_idx] = i
715 }
716 let arg_grad = grad.permute(inv_dims)?;
717 let sum_grad = grads.or_insert(arg)?;
718 *sum_grad = sum_grad.add(&arg_grad)?
719 }
720 };
721 }
722 }
723 Ok(grads)
724 }
725}
726
727#[derive(Debug)]
729pub struct GradStore(HashMap<TensorId, Tensor>);
730
731impl GradStore {
732 fn new() -> Self {
734 GradStore(HashMap::new())
735 }
736
737 pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
739 self.0.get(&id)
740 }
741
742 pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
744 self.0.get(&tensor.id())
745 }
746
747 pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
749 self.0.remove(&tensor.id())
750 }
751
752 pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
754 self.0.insert(tensor.id(), grad)
755 }
756
757 fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
760 use std::collections::hash_map::Entry;
761 let grad = match self.0.entry(tensor.id()) {
762 Entry::Occupied(entry) => entry.into_mut(),
763 Entry::Vacant(entry) => {
764 let grad = tensor.zeros_like()?;
765 entry.insert(grad)
766 }
767 };
768 Ok(grad)
769 }
770
771 pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
773 self.0.keys()
774 }
775}