1use yscv_kernels::{matmul_2d, mul as kernel_mul};
2use yscv_tensor::Tensor;
3
4use super::checkpoint::CheckpointConfig;
5use super::error::AutogradError;
6use super::graph::Graph;
7use super::node::{NodeId, Op};
8
9mod activation;
10mod attention;
11mod elementwise;
12mod linalg;
13mod norm;
14mod pool;
15mod recurrent;
16mod reduce;
17mod shape;
18
19pub(super) use super::error;
21pub(super) use super::graph;
22pub(super) use super::node;
23
24impl Graph {
    /// Runs the reverse pass from `target`, leaving gradients on leaf nodes.
    ///
    /// Steps, as implemented below:
    /// 1. Reject `target` unless its value has an empty shape (rank 0).
    /// 2. Call `zero_grads()` and seed `target`'s gradient with the scalar `1.0`.
    /// 3. Walk node indices from `target.0` down to `0`. A node with no pending
    ///    gradient is skipped. Otherwise its gradient is moved out and handed to
    ///    the backward handler matching the node's `Op`.
    ///
    /// Leaf nodes get their gradient stored back. For every other op this
    /// function does not restore the moved-out gradient itself; what remains on
    /// intermediate nodes afterwards depends on the handlers.
    ///
    /// NOTE(review): the descending walk presumably relies on every op's inputs
    /// having a smaller index than the op itself — confirm against graph
    /// construction.
    ///
    /// # Errors
    /// - Whatever `node` / `node_mut` return for an invalid `target`.
    /// - `AutogradError::NonScalarTarget` when `target`'s value shape is non-empty.
    /// - Any error returned by a per-op backward handler.
    pub fn backward(&mut self, target: NodeId) -> Result<(), AutogradError> {
        // Only a rank-0 value (empty shape) is accepted as the target.
        if !self.node(target)?.value.shape().is_empty() {
            return Err(AutogradError::NonScalarTarget {
                shape: self.node(target)?.value.shape().to_vec(),
            });
        }

        self.zero_grads();
        // Seed of the reverse pass: d(target)/d(target) = 1.
        self.node_mut(target)?.grad = Some(Tensor::scalar(1.0));

        for index in (0..=target.0).rev() {
            // The op is cloned so that `self` can be borrowed mutably by the handlers.
            let op = self.nodes[index].op.clone();
            // Move the pending gradient out; nodes that received none are skipped.
            let upstream = match self.nodes[index].grad.take() {
                Some(grad) => grad,
                None => continue,
            };

            match op {
                Op::Leaf => {
                    // Leaves keep their gradient: put back what was just taken.
                    self.nodes[index].grad = Some(upstream);
                }
                // --- Element-wise arithmetic (`elementwise` module) ---
                Op::Add(left, right) => {
                    elementwise::add_backward(self, upstream, left, right)?;
                }
                Op::Sub(left, right) => {
                    elementwise::sub_backward(self, upstream, left, right)?;
                }
                Op::Mul(left, right) => {
                    elementwise::mul_backward(self, upstream, left, right)?;
                }
                Op::Div(left, right) => {
                    elementwise::div_backward(self, upstream, left, right)?;
                }
                Op::Neg(input) => {
                    elementwise::neg_backward(self, upstream, input)?;
                }
                Op::Pow(base, exponent) => {
                    elementwise::pow_backward(self, upstream, index, base, exponent)?;
                }
                Op::Abs(input) => {
                    elementwise::abs_backward(self, upstream, input)?;
                }
                Op::Clamp {
                    input,
                    min_bits,
                    max_bits,
                } => {
                    elementwise::clamp_backward(self, upstream, input, min_bits, max_bits)?;
                }

                // --- Activations (`activation` module) ---
                // Some handlers also receive `index`, the position of the node being
                // differentiated.
                Op::Relu(input) => {
                    activation::relu_backward(self, upstream, index, input)?;
                }
                Op::LeakyRelu {
                    input,
                    negative_slope,
                } => {
                    activation::leaky_relu_backward(self, upstream, input, negative_slope)?;
                }
                Op::Sigmoid(input) => {
                    activation::sigmoid_backward(self, upstream, index, input)?;
                }
                Op::Tanh(input) => {
                    activation::tanh_backward(self, upstream, index, input)?;
                }
                Op::Exp(input) => {
                    activation::exp_backward(self, upstream, index, input)?;
                }
                Op::Log(input) => {
                    activation::log_backward(self, upstream, input)?;
                }
                Op::Sqrt(input) => {
                    activation::sqrt_backward(self, upstream, index, input)?;
                }
                Op::Gelu(input) => {
                    activation::gelu_backward(self, upstream, input)?;
                }
                Op::Silu(input) => {
                    activation::silu_backward(self, upstream, input)?;
                }
                Op::Mish(input) => {
                    activation::mish_backward(self, upstream, input)?;
                }
                Op::Softmax(input) => {
                    activation::softmax_backward(self, upstream, index, input)?;
                }
                Op::LogSoftmax(input) => {
                    activation::log_softmax_backward(self, upstream, index, input)?;
                }

                // --- Matrix products and convolutions (`linalg` module) ---
                // Stride/padding fields are cast to `usize` before being passed on.
                Op::MatMul2D(left, right) => {
                    linalg::matmul2d_backward(self, upstream, left, right)?;
                }
                Op::Transpose2D(input) => {
                    linalg::transpose2d_backward(self, upstream, input)?;
                }
                Op::Conv2dNhwc {
                    input,
                    weight,
                    bias,
                    stride_h,
                    stride_w,
                } => {
                    linalg::conv2d_nhwc_backward(
                        self,
                        &upstream,
                        input,
                        weight,
                        bias,
                        stride_h as usize,
                        stride_w as usize,
                    )?;
                }
                Op::DepthwiseConv2dNhwc {
                    input,
                    weight,
                    bias,
                    stride_h,
                    stride_w,
                } => {
                    linalg::depthwise_conv2d_nhwc_backward(
                        self,
                        &upstream,
                        input,
                        weight,
                        bias,
                        stride_h as usize,
                        stride_w as usize,
                    )?;
                }
                Op::ConvTranspose2dNhwc {
                    input,
                    weight,
                    bias,
                    stride_h,
                    stride_w,
                } => {
                    linalg::conv_transpose2d_nhwc_backward(
                        self,
                        &upstream,
                        input,
                        weight,
                        bias,
                        stride_h as usize,
                        stride_w as usize,
                    )?;
                }
                Op::Conv1dNlc {
                    input,
                    weight,
                    bias,
                    stride,
                } => {
                    linalg::conv1d_nlc_backward(
                        self,
                        &upstream,
                        input,
                        weight,
                        bias,
                        stride as usize,
                    )?;
                }
                Op::Conv3dNdhwc {
                    input,
                    weight,
                    bias,
                    stride_d,
                    stride_h,
                    stride_w,
                } => {
                    linalg::conv3d_ndhwc_backward(
                        self,
                        &upstream,
                        input,
                        weight,
                        bias,
                        stride_d as usize,
                        stride_h as usize,
                        stride_w as usize,
                    )?;
                }
                Op::DeformableConv2dNhwc {
                    input,
                    weight,
                    offsets,
                    bias,
                    stride,
                    padding,
                } => {
                    linalg::deformable_conv2d_nhwc_backward(
                        self,
                        &upstream,
                        input,
                        weight,
                        offsets,
                        bias,
                        stride as usize,
                        padding as usize,
                    )?;
                }

                // --- Shape, indexing and resampling ops (`shape` module) ---
                Op::ReshapeView { input } => {
                    shape::reshape_backward(self, upstream, input)?;
                }
                Op::Flatten(input) => {
                    shape::flatten_backward(self, upstream, input)?;
                }
                Op::UnsqueezeView { input, axis } => {
                    shape::unsqueeze_backward(self, upstream, input, axis)?;
                }
                Op::SqueezeView { input, axis } => {
                    shape::squeeze_backward(self, upstream, input, axis)?;
                }
                Op::Cat { ref inputs, axis } => {
                    shape::cat_backward(self, upstream, inputs, axis)?;
                }
                Op::Select { input, axis, index } => {
                    // `index` here is the op's own field and shadows the loop variable
                    // within this arm.
                    shape::select_backward(self, upstream, input, axis, index)?;
                }
                Op::Narrow {
                    input,
                    axis,
                    start,
                    len,
                } => {
                    shape::narrow_backward(self, upstream, input, axis, start, len)?;
                }
                Op::Gather { input, axis, index } => {
                    // As above, `index` is the op field, not the node position.
                    shape::gather_backward(self, upstream, input, axis, index)?;
                }
                Op::ScatterAdd {
                    input,
                    axis,
                    index,
                    src,
                } => {
                    // As above, `index` is the op field, not the node position.
                    shape::scatter_add_backward(self, upstream, input, axis, index, src)?;
                }
                // NOTE(review): no range loop is visible in this arm; the `allow` looks
                // like a leftover from an earlier inline implementation — confirm
                // before removing.
                #[allow(clippy::needless_range_loop)]
                Op::Pad {
                    input,
                    ref pad_before,
                    pad_after: _,
                } => {
                    // Only `pad_before` is forwarded; `pad_after` is ignored here.
                    shape::pad_backward(self, upstream, input, pad_before)?;
                }
                // NOTE(review): same remark as for `Op::Pad` — this `allow` appears stale.
                #[allow(clippy::needless_range_loop)]
                Op::Repeat { input, repeats: _ } => {
                    // `repeats` is not forwarded to the handler.
                    shape::repeat_backward(self, upstream, input)?;
                }
                Op::Scatter {
                    input,
                    indices,
                    src,
                } => {
                    shape::scatter_backward(self, upstream, input, indices, src)?;
                }
                Op::EmbeddingLookup { weight, indices } => {
                    shape::embedding_lookup_backward(self, upstream, weight, indices)?;
                }
                Op::PixelShuffle {
                    input,
                    upscale_factor,
                } => {
                    shape::pixel_shuffle_backward(self, &upstream, input, upscale_factor as usize)?;
                }
                Op::UpsampleNearest {
                    input,
                    scale_factor,
                } => {
                    shape::upsample_nearest_backward(
                        self,
                        &upstream,
                        input,
                        scale_factor as usize,
                    )?;
                }

                // --- Reductions (`reduce` module) ---
                Op::Sum(input) => {
                    reduce::sum_backward(self, upstream, index, input)?;
                }
                Op::Mean(input) => {
                    reduce::mean_backward(self, upstream, index, input)?;
                }
                Op::SumAxis { input, axis } => {
                    reduce::sum_axis_backward(self, upstream, input, axis)?;
                }
                Op::MeanAxis { input, axis } => {
                    reduce::mean_axis_backward(self, upstream, input, axis)?;
                }

                // --- Recurrent layers (`recurrent` module) ---
                Op::Rnn {
                    input,
                    w_ih,
                    w_hh,
                    bias,
                } => {
                    recurrent::rnn_backward(self, &upstream, index, input, w_ih, w_hh, bias)?;
                }
                Op::Lstm {
                    input,
                    w_ih,
                    w_hh,
                    bias,
                } => {
                    recurrent::lstm_backward(self, &upstream, index, input, w_ih, w_hh, bias)?;
                }
                Op::Gru {
                    input,
                    w_ih,
                    w_hh,
                    bias_ih,
                    bias_hh,
                } => {
                    recurrent::gru_backward(
                        self, &upstream, index, input, w_ih, w_hh, bias_ih, bias_hh,
                    )?;
                }

                // --- Normalisation layers (`norm` module) ---
                // Epsilon is carried in the op as raw `f32` bits and decoded here.
                Op::BatchNorm2dNhwc {
                    input,
                    gamma,
                    beta,
                    running_mean: _,
                    running_var,
                    epsilon,
                } => {
                    // `running_mean` is not forwarded; only `running_var` is.
                    let eps = f32::from_bits(epsilon);
                    norm::batch_norm2d_nhwc_backward(
                        self,
                        index,
                        &upstream,
                        input,
                        gamma,
                        beta,
                        running_var,
                        eps,
                    )?;
                }
                Op::LayerNorm {
                    input,
                    gamma,
                    beta,
                    eps_bits,
                } => {
                    let eps = f32::from_bits(eps_bits);
                    norm::layer_norm_backward(self, index, &upstream, input, gamma, beta, eps)?;
                }
                Op::GroupNorm {
                    input,
                    gamma,
                    beta,
                    num_groups,
                    eps_bits,
                } => {
                    let eps = f32::from_bits(eps_bits);
                    norm::group_norm_backward(
                        self,
                        index,
                        &upstream,
                        input,
                        gamma,
                        beta,
                        num_groups as usize,
                        eps,
                    )?;
                }
                Op::InstanceNormNhwc {
                    input,
                    gamma,
                    beta,
                    eps_bits,
                } => {
                    let eps = f32::from_bits(eps_bits);
                    norm::instance_norm_nhwc_backward(
                        self, index, &upstream, input, gamma, beta, eps,
                    )?;
                }

                // --- Pooling (`pool` module) ---
                Op::MaxPool2dNhwc {
                    input,
                    kernel_h: _,
                    kernel_w: _,
                    stride_h: _,
                    stride_w: _,
                } => {
                    // Kernel and stride fields are not forwarded; the handler gets
                    // `index` and `input` only.
                    pool::max_pool2d_nhwc_backward(self, upstream, index, input)?;
                }
                Op::AvgPool2dNhwc {
                    input,
                    kernel_h,
                    kernel_w,
                    stride_h,
                    stride_w,
                } => {
                    pool::avg_pool2d_nhwc_backward(
                        self,
                        &upstream,
                        input,
                        kernel_h as usize,
                        kernel_w as usize,
                        stride_h as usize,
                        stride_w as usize,
                    )?;
                }
                Op::AdaptiveAvgPool2dNhwc {
                    input,
                    out_h,
                    out_w,
                } => {
                    pool::adaptive_avg_pool2d_nhwc_backward(
                        self,
                        &upstream,
                        input,
                        out_h as usize,
                        out_w as usize,
                    )?;
                }
                Op::AdaptiveMaxPool2dNhwc {
                    input,
                    out_h: _,
                    out_w: _,
                } => {
                    // Output sizes are not forwarded to the handler.
                    pool::adaptive_max_pool2d_nhwc_backward(self, upstream, index, input)?;
                }

                // --- Handlers hosted in the `attention` module (PReLU included) ---
                Op::PRelu { input, alpha } => {
                    attention::prelu_backward(self, &upstream, input, alpha)?;
                }
                Op::ScaledDotProductAttention { query, key, value } => {
                    attention::scaled_dot_product_attention_backward(
                        self, &upstream, index, query, key, value,
                    )?;
                }
            }
        }

        Ok(())
    }
474
475 pub fn backward_with_checkpoints(
477 &mut self,
478 target: NodeId,
479 config: &CheckpointConfig,
480 ) -> Result<(), AutogradError> {
481 self.backward(target)?;
482
483 for index in 0..self.nodes.len() {
484 if config.should_checkpoint(index) && !matches!(self.nodes[index].op, Op::Leaf) {
485 self.nodes[index].value = Tensor::scalar(0.0);
486 self.nodes[index].aux = None;
487 }
488 }
489
490 Ok(())
491 }
492
493 pub(crate) fn accumulate_grad(
494 &mut self,
495 node_id: NodeId,
496 contribution: Tensor,
497 ) -> Result<(), AutogradError> {
498 if !self.node(node_id)?.requires_grad {
499 return Ok(());
500 }
501
502 let node = self.node_mut(node_id)?;
503 match &mut node.grad {
504 Some(existing) => {
505 if existing.shape() != contribution.shape() {
506 return Err(AutogradError::InvalidGradientShape {
507 node: node_id.0,
508 expected: existing.shape().to_vec(),
509 got: contribution.shape().to_vec(),
510 });
511 }
512 existing.add_inplace(&contribution);
513 }
514 None => node.grad = Some(contribution),
515 }
516 Ok(())
517 }
518
519 pub(crate) fn dispatch_mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, AutogradError> {
522 if let Some(ref backend) = self.backend {
523 Ok(backend.mul(lhs, rhs)?)
524 } else {
525 Ok(kernel_mul(lhs, rhs)?)
526 }
527 }
528
529 pub(crate) fn dispatch_matmul_2d(
530 &self,
531 lhs: &Tensor,
532 rhs: &Tensor,
533 ) -> Result<Tensor, AutogradError> {
534 if let Some(ref backend) = self.backend {
535 Ok(backend.matmul_2d(lhs, rhs)?)
536 } else {
537 Ok(matmul_2d(lhs, rhs)?)
538 }
539 }
540
541 pub(crate) fn dispatch_transpose_2d(&self, input: &Tensor) -> Result<Tensor, AutogradError> {
542 if let Some(ref backend) = self.backend {
543 Ok(backend.transpose_2d(input)?)
544 } else {
545 transpose_2d(input)
546 }
547 }
548
549 pub(crate) fn dispatch_neg(&self, input: &Tensor) -> Tensor {
550 if let Some(ref backend) = self.backend {
551 backend.neg(input)
552 } else {
553 input.neg()
554 }
555 }
556}
557
558fn transpose_2d(input: &Tensor) -> Result<Tensor, AutogradError> {
559 if input.rank() != 2 {
560 return Err(AutogradError::InvalidRankForOperation {
561 op: "transpose_2d",
562 expected: 2,
563 got: input.rank(),
564 });
565 }
566 let rows = input.shape()[0];
567 let cols = input.shape()[1];
568 let mut data = vec![0.0f32; input.len()];
569 for row in 0..rows {
570 for col in 0..cols {
571 data[col * rows + row] = input.data()[row * cols + col];
572 }
573 }
574 Tensor::from_vec(vec![cols, rows], data).map_err(Into::into)
575}
576
577fn reduce_broadcast_gradient(
578 upstream: &Tensor,
579 target_shape: &[usize],
580) -> Result<Tensor, AutogradError> {
581 if upstream.shape() == target_shape {
582 return Ok(upstream.clone());
583 }
584 if target_shape.len() > upstream.rank() {
585 return Err(AutogradError::BroadcastGradientIncompatible {
586 upstream: upstream.shape().to_vec(),
587 target: target_shape.to_vec(),
588 });
589 }
590
591 let leading_axes = upstream.rank() - target_shape.len();
592 let mut reduced: Option<Tensor> = None;
593 for axis in (0..leading_axes).rev() {
594 reduced = Some(match reduced {
595 Some(r) => r.sum_axis(axis)?,
596 None => upstream.sum_axis(axis)?,
597 });
598 }
599
600 let mut axes_to_reduce = Vec::new();
601 let check_shape = reduced.as_ref().unwrap_or(upstream);
602 for (axis, target_dim) in target_shape.iter().enumerate() {
603 let current_dim = check_shape.shape()[axis];
604 if current_dim == *target_dim {
605 continue;
606 }
607 if *target_dim == 1 && current_dim > 1 {
608 axes_to_reduce.push(axis);
609 continue;
610 }
611 return Err(AutogradError::BroadcastGradientIncompatible {
612 upstream: upstream.shape().to_vec(),
613 target: target_shape.to_vec(),
614 });
615 }
616
617 if !axes_to_reduce.is_empty() && reduced.is_none() {
618 reduced = Some(upstream.clone());
619 }
620 let mut reduced = reduced.unwrap_or_else(|| upstream.clone());
621
622 for axis in axes_to_reduce.into_iter().rev() {
623 reduced = reduced.sum_axis(axis)?;
624 }
625
626 if reduced.shape() != target_shape {
627 reduced = reduced.reshape(target_shape.to_vec())?;
628 }
629 Ok(reduced)
630}