use crate::op::{BinaryOp, Op, ReduceOp, UnaryOp};
use crate::{Error, Result, Tensor, TensorId};
use std::collections::HashMap;

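// `node` is the result of reducing `arg` over `reduced_dims`; broadcast it back to the shape of
// `arg`, handling both the keepdim case (same rank) and the squeezed case (lower rank).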
fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result<Tensor> {
    if arg.rank() == node.rank() {
        node.broadcast_as(arg.shape())
    } else {
        node.reshape(reduced_dims)?.broadcast_as(arg.shape())
    }
}

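// Setting the CANDLE_GRAD_DO_NOT_DETACH environment variable to a non-empty value other than "0"
// keeps the gradients computed during the backward pass attached to the computation graph instead
// of detaching them.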
thread_local! {
    static CANDLE_GRAD_DO_NOT_DETACH: bool = {
        match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}

impl Tensor {
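    /// Return the nodes of the computation graph that lead to this tensor, in topological order:
    /// a node appears before the nodes it depends on. Only nodes that can carry a gradient, i.e.
    /// that reach a variable through differentiable floating-point ops, are included.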
    pub fn sorted_nodes(&self) -> Vec<&Tensor> {
        fn walk<'a>(
            node: &'a Tensor,
            nodes: Vec<&'a Tensor>,
            already_seen: &mut HashMap<TensorId, bool>,
        ) -> (bool, Vec<&'a Tensor>) {
            if let Some(&tg) = already_seen.get(&node.id()) {
                return (tg, nodes);
            }
            let mut track_grad = false;
            let mut nodes = if node.is_variable() {
                track_grad = true;
                nodes
            } else if node.dtype().is_int() {
                nodes
            } else if let Some(op) = node.op() {
                match op {
                    Op::IndexAdd(t1, t2, t3, _)
                    | Op::ScatterAdd(t1, t2, t3, _)
                    | Op::CustomOp3(t1, t2, t3, _)
                    | Op::WhereCond(t1, t2, t3) => {
                        let (tg, nodes) = walk(t1, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(t2, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(t3, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::Conv1D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::ConvTranspose1D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::Conv2D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::ConvTranspose2D {
                        arg: lhs,
                        kernel: rhs,
                        ..
                    }
                    | Op::CustomOp2(lhs, rhs, _)
                    | Op::Binary(lhs, rhs, _)
                    | Op::Gather(lhs, rhs, _)
                    | Op::IndexSelect(lhs, rhs, _)
                    | Op::Matmul(lhs, rhs)
                    | Op::SliceScatter0(lhs, rhs, _) => {
                        let (tg, nodes) = walk(lhs, nodes, already_seen);
                        track_grad |= tg;
                        let (tg, nodes) = walk(rhs, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
                        let (tg, nodes) = walk(arg, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }),
                    Op::Affine { arg, mul, .. } => {
                        if *mul == 0. {
                            nodes
                        } else {
                            let (tg, nodes) = walk(arg, nodes, already_seen);
                            track_grad |= tg;
                            nodes
                        }
                    }
                    Op::Unary(_node, UnaryOp::Ceil)
                    | Op::Unary(_node, UnaryOp::Floor)
                    | Op::Unary(_node, UnaryOp::Round)
                    | Op::Unary(_node, UnaryOp::Sign) => nodes,
                    Op::Reshape(node)
                    | Op::UpsampleNearest1D { arg: node, .. }
                    | Op::UpsampleNearest2D { arg: node, .. }
                    | Op::AvgPool2D { arg: node, .. }
                    | Op::MaxPool2D { arg: node, .. }
                    | Op::Copy(node)
                    | Op::Broadcast(node)
                    | Op::Cmp(node, _)
                    | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
                    | Op::ToDevice(node)
                    | Op::Transpose(node, _, _)
                    | Op::Permute(node, _)
                    | Op::Narrow(node, _, _, _)
                    | Op::Unary(node, _)
                    | Op::Elu(node, _)
                    | Op::Powf(node, _)
                    | Op::CustomOp1(node, _) => {
                        let (tg, nodes) = walk(node, nodes, already_seen);
                        track_grad |= tg;
                        nodes
                    }
                    Op::ToDType(node) => {
                        if node.dtype().is_float() {
                            let (tg, nodes) = walk(node, nodes, already_seen);
                            track_grad |= tg;
                            nodes
                        } else {
                            nodes
                        }
                    }
                    Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
                }
            } else {
                nodes
            };
            already_seen.insert(node.id(), track_grad);
            if track_grad {
                nodes.push(node);
            }
            (track_grad, nodes)
        }
        let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
        nodes.reverse();
        nodes
    }

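    /// Run the backward pass from this tensor, returning a [`GradStore`] holding the gradient of
    /// this tensor with respect to each variable in its computation graph, keyed by tensor id.
    ///
    /// A minimal usage sketch; the variable `w` and the squared loss are illustrative only and
    /// rely on the `Var` type from this crate:
    ///
    /// ```ignore
    /// use candle_core::{Device, Var};
    ///
    /// let w = Var::new(3f32, &Device::Cpu)?;
    /// let loss = w.as_tensor().sqr()?; // loss = w^2
    /// let grads = loss.backward()?;    // d(loss)/dw = 2 * w
    /// let grad_w = grads.get(w.as_tensor());
    /// ```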
    pub fn backward(&self) -> Result<GradStore> {
        let sorted_nodes = self.sorted_nodes();
        let mut grads = GradStore::new();
        grads.insert(self, self.ones_like()?.contiguous()?);
        for node in sorted_nodes.iter() {
            if node.is_variable() {
                continue;
            }
            let grad = grads
                .remove(node)
                .expect("candle internal error - grad not populated");
            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
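            // Unless CANDLE_GRAD_DO_NOT_DETACH is set, detach the gradient from the graph so that
            // the tensors produced during the backward pass do not themselves get tracked.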
            let grad = if do_not_detach { grad } else { grad.detach() };
            if let Some(op) = node.op() {
                match op {
                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Sub) => {
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Mul) => {
                        let lhs_grad = grad.mul(rhs)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                        let rhs_grad = grad.mul(lhs)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Div) => {
                        let lhs_grad = grad.div(rhs)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
                    }
                    Op::Binary(lhs, rhs, BinaryOp::Minimum)
                    | Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
                        let mask_lhs = node.eq(lhs)?.to_dtype(grad.dtype())?;
                        let mask_rhs = node.eq(rhs)?.to_dtype(grad.dtype())?;

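                        // If a position attains the extremum in both lhs and rhs, both masks are
                        // 1 and the gradient is split evenly between them, hence the division by
                        // `1 + other_mask`.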
                        let lhs_grad = mask_lhs.mul(&grad)?.div(&(&mask_rhs + 1.)?)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;

                        let rhs_grad = mask_rhs.mul(&grad)?.div(&(&mask_lhs + 1.)?)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::WhereCond(pred, t, f) => {
                        let zeros = grad.zeros_like()?;
                        let t_sum_grad = grads.or_insert(t)?;
                        let t_grad = pred.where_cond(&grad, &zeros)?;
                        *t_sum_grad = t_sum_grad.add(&t_grad)?;
                        let f_sum_grad = grads.or_insert(f)?;
                        let f_grad = pred.where_cond(&zeros, &grad)?;
                        *f_sum_grad = f_sum_grad.add(&f_grad)?;
                    }
                    Op::Conv1D {
                        arg,
                        kernel,
                        padding,
                        stride,
                        dilation,
                    } => {
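                        // The gradient w.r.t. the input is a conv_transpose1d of the output
                        // gradient with the same kernel. Its output length is
                        // (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1,
                        // so pick out_padding to recover the original input length.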
                        let grad_l_in = grad.dim(2)?;
                        let k_size = kernel.dim(2)?;
                        let out_size =
                            (grad_l_in - 1) * stride + dilation * (k_size - 1) + 1 - 2 * padding;
                        let out_padding = arg.dim(2)? - out_size;
                        let grad_arg = grad.conv_transpose1d(
                            kernel,
                            *padding,
                            out_padding,
                            *stride,
                            *dilation,
                            1,
                        )?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;

                        let grad_kernel = arg
                            .transpose(0, 1)?
                            .conv1d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                            .transpose(0, 1)?;
                        let sum_grad = grads.or_insert(kernel)?;
                        let (_, _, k0) = kernel.dims3()?;
                        let (_, _, g_k0) = grad_kernel.dims3()?;
                        let grad_kernel = if g_k0 != k0 {
                            grad_kernel.narrow(2, 0, k0)?
                        } else {
                            grad_kernel
                        };
                        *sum_grad = sum_grad.add(&grad_kernel)?;
                    }
                    Op::Conv2D {
                        arg,
                        kernel,
                        padding,
                        stride,
                        dilation,
                    } => {
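                        // Same scheme as conv1d: run conv_transpose2d on the output gradient and
                        // choose out_padding so that the result matches the input spatial size.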
                        let grad_h = grad.dim(2)?;
                        let k_h = kernel.dim(2)?;
                        let out_size =
                            (grad_h - 1) * stride + dilation * (k_h - 1) + 1 - 2 * padding;
                        let out_padding = arg.dim(2)? - out_size;
                        let grad_arg = grad.conv_transpose2d(
                            kernel,
                            *padding,
                            out_padding,
                            *stride,
                            *dilation,
                        )?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;

                        let grad_kernel = arg
                            .transpose(0, 1)?
                            .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                            .transpose(0, 1)?;
                        let sum_grad = grads.or_insert(kernel)?;
                        let (_, _, k0, k1) = kernel.dims4()?;
                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
                        } else {
                            grad_kernel
                        };
                        *sum_grad = sum_grad.add(&grad_kernel)?;
                    }
                    Op::ConvTranspose1D { .. } => Err(Error::BackwardNotSupported {
                        op: "conv-transpose1d",
                    })?,
                    Op::ConvTranspose2D {
                        arg,
                        kernel,
                        padding,
                        stride,
                        dilation,
                        output_padding: _output_padding,
                    } => {
                        let grad_arg = grad.conv2d(kernel, *padding, *stride, *dilation, 1)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;

                        let grad_kernel = grad
                            .transpose(0, 1)?
                            .conv2d(&arg.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                            .transpose(0, 1)?;
                        let sum_grad = grads.or_insert(kernel)?;
                        let (_, _, k0, k1) = kernel.dims4()?;
                        let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
                        let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
                            grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
                        } else {
                            grad_kernel
                        };
                        *sum_grad = sum_grad.add(&grad_kernel)?;
                    }
                    Op::AvgPool2D {
                        arg,
                        kernel_size,
                        stride,
                    } => {
                        if kernel_size != stride {
                            crate::bail!("backward not supported for avgpool2d if ksize {kernel_size:?} != stride {stride:?}")
                        }
                        let (_n, _c, h, w) = arg.dims4()?;
                        let grad_arg = grad.upsample_nearest2d(h, w)?;
                        let grad_arg =
                            (grad_arg * (1f64 / (kernel_size.0 * kernel_size.1) as f64))?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;
                    }
                    Op::MaxPool2D {
                        arg,
                        kernel_size,
                        stride,
                    } => {
                        if kernel_size != stride {
                            crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
                        }
                        let (_n, _c, h, w) = arg.dims4()?;
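                        // Build a mask that is 1 where the input attains the window maximum,
                        // rescale the incoming gradient using the average pooling of that mask,
                        // and propagate it back through the mask.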
                        let node_upsampled = node.upsample_nearest2d(h, w)?;
                        let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
                        let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
                        let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad_arg)?;
                    }
                    Op::UpsampleNearest1D { arg, target_size } => {
                        let (_n, c, size) = arg.dims3()?;
                        if target_size % size != 0 {
                            crate::bail!("backward not supported for non integer upscaling factors")
                        }
                        let scale = target_size / size;

                        let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
                        let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = conv_sum;
                    }
                    Op::UpsampleNearest2D {
                        arg,
                        target_h,
                        target_w,
                    } => {
                        let (_n, c, h, w) = arg.dims4()?;
                        if target_h % h != 0 || target_w % w != 0 {
                            crate::bail!("backward not supported for non integer upscaling factors")
                        }
                        let scale_h = target_h / h;
                        let scale_w = target_w / w;

                        if scale_h != scale_w {
                            crate::bail!("backward not supported for non uniform upscaling factors")
                        };
                        let kernel =
                            Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
                        let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = conv_sum;
                    }
                    Op::SliceScatter0(lhs, rhs, start_rhs) => {
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;

                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
                    }
                    Op::Gather(arg, indexes, dim) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
                    }
                    Op::ScatterAdd(init, indexes, src, dim) => {
                        let init_sum_grad = grads.or_insert(init)?;
                        *init_sum_grad = init_sum_grad.add(&grad)?;

                        let src_grad = grad.gather(indexes, *dim)?;
                        let src_sum_grad = grads.or_insert(src)?;
                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
                    }
                    Op::IndexAdd(init, indexes, src, dim) => {
                        let init_sum_grad = grads.or_insert(init)?;
                        *init_sum_grad = init_sum_grad.add(&grad)?;

                        let src_grad = grad.index_select(indexes, *dim)?;
                        let src_sum_grad = grads.or_insert(src)?;
                        *src_sum_grad = src_sum_grad.add(&src_grad)?;
                    }
                    Op::IndexSelect(arg, indexes, dim) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
                    }
                    Op::Matmul(lhs, rhs) => {
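                        // d(lhs @ rhs)/dlhs = grad @ rhs^T and d(lhs @ rhs)/drhs = lhs^T @ grad.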
                        let lhs_grad = grad.matmul(&rhs.t()?)?;
                        let lhs_sum_grad = grads.or_insert(lhs)?;
                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;

                        let rhs_grad = lhs.t()?.matmul(&grad)?;
                        let rhs_sum_grad = grads.or_insert(rhs)?;
                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                    }
                    Op::Cat(args, dim) => {
                        let mut start_idx = 0;
                        for arg in args {
                            let len = arg.dims()[*dim];
                            let arg_grad = grad.narrow(*dim, start_idx, len)?;
                            let sum_grad = grads.or_insert(arg)?;
                            *sum_grad = sum_grad.add(&arg_grad)?;
                            start_idx += len;
                        }
                    }
                    Op::Broadcast(arg) => {
                        let arg_dims = arg.dims();
                        let node_dims = node.dims();
                        let left_dims = node_dims.len() - arg_dims.len();
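                        // Sum the gradient over every broadcast dimension: the leading dims that
                        // were added on the left plus any dim that was expanded from size 1.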
                        let mut sum_dims: Vec<usize> = (0..left_dims).collect();
                        for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
                            .iter()
                            .zip(arg_dims.iter())
                            .enumerate()
                        {
                            if node_dim != arg_dim {
                                sum_dims.push(dim + left_dims)
                            }
                        }

                        let mut arg_grad = grad.sum_keepdim(sum_dims.as_slice())?;
                        for _i in 0..left_dims {
                            arg_grad = arg_grad.squeeze(0)?
                        }
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::Reduce(arg, ReduceOp::Sum, reduced_dims) => {
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad)?;
                    }
                    Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
                        let node = broadcast_back(arg, node, reduced_dims)?;
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::Reduce(arg, ReduceOp::Min, reduced_dims) => {
                        let node = broadcast_back(arg, node, reduced_dims)?;
                        let grad = broadcast_back(arg, &grad, reduced_dims)?;
                        let grad = node.eq(arg)?.to_dtype(grad.dtype())?.mul(&grad)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.broadcast_as(sum_grad.dims())?)?;
                    }
                    Op::ToDType(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
                    }
                    Op::Copy(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&grad)?
                    }
                    Op::Affine { arg, mul, .. } => {
                        let arg_grad = grad.affine(*mul, 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(arg, UnaryOp::Log) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(grad / arg)?)?
                    }
                    Op::Unary(arg, UnaryOp::Sin) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
                    }
                    Op::Unary(arg, UnaryOp::Cos) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                    }
                    Op::Unary(arg, UnaryOp::Tanh) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let minus_dtanh = (node.sqr()? - 1.)?;
                        *sum_grad = sum_grad.sub(&(&grad * &minus_dtanh)?)?
                    }
                    Op::Unary(arg, UnaryOp::Abs) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let ones = arg.ones_like()?;
                        let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
                        *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
                    }
                    Op::Unary(arg, UnaryOp::Exp) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&(&grad * *node)?)?
                    }
                    Op::Unary(arg, UnaryOp::Neg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.sub(&grad)?
                    }
                    Op::Unary(arg, UnaryOp::Recip) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let grad = (grad / arg.sqr()?)?;
                        *sum_grad = sum_grad.sub(&grad)?
                    }
                    &Op::Narrow(ref arg, dim, start_idx, len) => {
                        let arg_dims = arg.dims();
                        let left_pad = if start_idx == 0 {
                            None
                        } else {
                            let mut dims = arg_dims.to_vec();
                            dims[dim] = start_idx;
                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
                        };
                        let right_pad = arg_dims[dim] - start_idx - len;
                        let right_pad = if right_pad == 0 {
                            None
                        } else {
                            let mut dims = arg_dims.to_vec();
                            dims[dim] = right_pad;
                            Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
                        };
                        let arg_grad = match (left_pad, right_pad) {
                            (None, None) => grad,
                            (Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
                            (None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
                            (Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
                        };
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(_, UnaryOp::Floor)
                    | Op::Unary(_, UnaryOp::Round)
                    | Op::Reduce(_, ReduceOp::ArgMin, _)
                    | Op::Reduce(_, ReduceOp::ArgMax, _)
                    | Op::Unary(_, UnaryOp::Sign)
                    | Op::Cmp(_, _) => {}
                    Op::Reshape(arg) => {
                        let arg_grad = grad.reshape(arg.dims())?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
                    Op::Unary(arg, UnaryOp::Gelu) => {
                        let sum_grad = grads.or_insert(arg)?;
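                        // Derivative of the tanh approximation of gelu:
                        // 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * u'
                        // with u = 0.797885 * x + 0.0356774 * x^3.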
                        let cube = arg.powf(3.)?;
                        let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
                        let gelu_grad = (((0.5 * &tanh)?
                            + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
                            + 0.5)?;
                        *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
                    }
                    Op::Unary(arg, UnaryOp::Erf) => {
                        let sum_grad = grads.or_insert(arg)?;
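                        // d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2)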
                        let erf_grad =
                            (2. / std::f64::consts::PI.sqrt()) * (arg.sqr()?.neg()?).exp()?;
                        *sum_grad = sum_grad.add(&(&grad * erf_grad)?)?
                    }
                    Op::Unary(arg, UnaryOp::GeluErf) => {
                        let sum_grad = grads.or_insert(arg)?;
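                        // d/dx gelu_erf(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi)
                        // where 0.398942 ~= 1 / sqrt(2 * pi).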
                        let neg_half_square = (arg.sqr()?.neg()? / 2.)?;
                        let scaled_exp_arg = (0.398942 * neg_half_square.exp()? * arg)?;
                        let arg_scaled_sqrt = (arg / 2f64.sqrt())?;
                        let erf_scaled_sqrt = (0.5 * arg_scaled_sqrt.erf()?)?;
                        let gelu_erf_grad = (0.5 + scaled_exp_arg + erf_scaled_sqrt)?;
                        *sum_grad = sum_grad.add(&(&grad * gelu_erf_grad)?)?;
                    }
                    Op::Unary(arg, UnaryOp::Relu) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
                        *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
                    }
                    Op::Unary(arg, UnaryOp::Silu) => {
                        let sum_grad = grads.or_insert(arg)?;
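                        // d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
                        //             = sigmoid(x) * (1 - silu(x)) + silu(x), with node = silu(x).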
                        let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
                        let silu_grad = &sigmoid_arg * (1. - *node) + *node;
                        *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
                    }
                    Op::Elu(arg, alpha) => {
                        let sum_grad = grads.or_insert(arg)?;
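                        // d/dx elu(x) = 1 for x > 0, alpha * exp(x) = elu(x) + alpha otherwise.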
                        let zeros = arg.zeros_like()?;
                        let positive_mask = arg.gt(&zeros)?.to_dtype(arg.dtype())?;
                        let negative_mask = arg.le(&zeros)?.to_dtype(arg.dtype())?;
                        let negative_exp_mask = (negative_mask * (*node + *alpha))?;
                        let combined_mask = (positive_mask + negative_exp_mask)?;
                        *sum_grad = sum_grad.add(&(grad * combined_mask)?)?
                    }
                    Op::Powf(arg, e) => {
                        let arg_grad = (&(grad * arg.powf(e - 1.)?)? * *e)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::CustomOp1(arg, c) => {
                        if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
                            let sum_grad = grads.or_insert(arg)?;
                            *sum_grad = sum_grad.add(&arg_grad)?
                        }
                    }
                    Op::CustomOp2(arg1, arg2, c) => {
                        let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
                        if let Some(arg_grad1) = arg_grad1 {
                            let sum_grad = grads.or_insert(arg1)?;
                            *sum_grad = sum_grad.add(&arg_grad1)?
                        }
                        if let Some(arg_grad2) = arg_grad2 {
                            let sum_grad = grads.or_insert(arg2)?;
                            *sum_grad = sum_grad.add(&arg_grad2)?
                        }
                    }
                    Op::CustomOp3(arg1, arg2, arg3, c) => {
                        let (arg_grad1, arg_grad2, arg_grad3) =
                            c.bwd(arg1, arg2, arg3, node, &grad)?;
                        if let Some(arg_grad1) = arg_grad1 {
                            let sum_grad = grads.or_insert(arg1)?;
                            *sum_grad = sum_grad.add(&arg_grad1)?
                        }
                        if let Some(arg_grad2) = arg_grad2 {
                            let sum_grad = grads.or_insert(arg2)?;
                            *sum_grad = sum_grad.add(&arg_grad2)?
                        }
                        if let Some(arg_grad3) = arg_grad3 {
                            let sum_grad = grads.or_insert(arg3)?;
                            *sum_grad = sum_grad.add(&arg_grad3)?
                        }
                    }
                    Op::Unary(arg, UnaryOp::Sqr) => {
                        let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(arg, UnaryOp::Sqrt) => {
                        let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::ToDevice(arg) => {
                        let sum_grad = grads.or_insert(arg)?;
                        let arg_grad = grad.to_device(sum_grad.device())?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Transpose(arg, dim1, dim2) => {
                        let arg_grad = grad.transpose(*dim1, *dim2)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Permute(arg, dims) => {
                        let mut inv_dims = vec![0; dims.len()];
                        for (i, &dim_idx) in dims.iter().enumerate() {
                            inv_dims[dim_idx] = i
                        }
                        let arg_grad = grad.permute(inv_dims)?;
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                };
            }
        }
        Ok(grads)
    }
}

#[derive(Debug)]
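/// A store mapping tensor ids to their gradient tensors, as populated by the backward pass.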
pub struct GradStore(HashMap<TensorId, Tensor>);

impl GradStore {
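    /// Create a new, empty gradient store.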
    fn new() -> Self {
        GradStore(HashMap::new())
    }

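    /// Get the gradient tensor associated with the given tensor id.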
    pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
        self.0.get(&id)
    }

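    /// Get the gradient tensor associated with the given tensor.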
    pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
        self.0.get(&tensor.id())
    }

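    /// Remove the gradient tensor associated with the given tensor, returning it if present.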
    pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
        self.0.remove(&tensor.id())
    }

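    /// Insert a gradient tensor for the given tensor, returning the previous gradient if any.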
    pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
        self.0.insert(tensor.id(), grad)
    }

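    /// Return a mutable reference to the gradient of the given tensor, inserting a zero-filled
    /// gradient first if none is present.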
    fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
        use std::collections::hash_map::Entry;
        let grad = match self.0.entry(tensor.id()) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => {
                let grad = tensor.zeros_like()?;
                entry.insert(grad)
            }
        };
        Ok(grad)
    }

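    /// Iterate over the tensor ids of the stored gradients.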
    pub fn get_ids(&self) -> impl Iterator<Item = &TensorId> {
        self.0.keys()
    }
}