Skip to main content

yscv_autograd/backward/
mod.rs

1use yscv_kernels::{matmul_2d, mul as kernel_mul};
2use yscv_tensor::Tensor;
3
4use super::checkpoint::CheckpointConfig;
5use super::error::AutogradError;
6use super::graph::Graph;
7use super::node::{NodeId, Op};
8
9mod activation;
10mod attention;
11mod elementwise;
12mod linalg;
13mod norm;
14mod pool;
15mod recurrent;
16mod reduce;
17mod shape;
18
19// Re-export for sub-modules.
20pub(super) use super::error;
21pub(super) use super::graph;
22pub(super) use super::node;
23
24impl Graph {
25    /// Backpropagates gradients from a scalar target.
26    pub fn backward(&mut self, target: NodeId) -> Result<(), AutogradError> {
27        if !self.node(target)?.value.shape().is_empty() {
28            return Err(AutogradError::NonScalarTarget {
29                shape: self.node(target)?.value.shape().to_vec(),
30            });
31        }
32
33        self.zero_grads();
34        self.node_mut(target)?.grad = Some(Tensor::scalar(1.0));
35
36        for index in (0..=target.0).rev() {
37            let op = self.nodes[index].op.clone();
38            let upstream = match self.nodes[index].grad.take() {
39                Some(grad) => grad,
40                None => continue,
41            };
42
43            match op {
44                Op::Leaf => {
45                    // Restore gradient for leaf nodes — users query these after backward.
46                    self.nodes[index].grad = Some(upstream);
47                }
48                Op::Add(left, right) => {
49                    elementwise::add_backward(self, upstream, left, right)?;
50                }
51                Op::Sub(left, right) => {
52                    elementwise::sub_backward(self, upstream, left, right)?;
53                }
54                Op::Mul(left, right) => {
55                    elementwise::mul_backward(self, upstream, left, right)?;
56                }
57                Op::Div(left, right) => {
58                    elementwise::div_backward(self, upstream, left, right)?;
59                }
60                Op::Neg(input) => {
61                    elementwise::neg_backward(self, upstream, input)?;
62                }
63                Op::Pow(base, exponent) => {
64                    elementwise::pow_backward(self, upstream, index, base, exponent)?;
65                }
66                Op::Abs(input) => {
67                    elementwise::abs_backward(self, upstream, input)?;
68                }
69                Op::Clamp {
70                    input,
71                    min_bits,
72                    max_bits,
73                } => {
74                    elementwise::clamp_backward(self, upstream, input, min_bits, max_bits)?;
75                }
76
77                // Activations
78                Op::Relu(input) => {
79                    activation::relu_backward(self, upstream, index, input)?;
80                }
81                Op::LeakyRelu {
82                    input,
83                    negative_slope,
84                } => {
85                    activation::leaky_relu_backward(self, upstream, input, negative_slope)?;
86                }
87                Op::Sigmoid(input) => {
88                    activation::sigmoid_backward(self, upstream, index, input)?;
89                }
90                Op::Tanh(input) => {
91                    activation::tanh_backward(self, upstream, index, input)?;
92                }
93                Op::Exp(input) => {
94                    activation::exp_backward(self, upstream, index, input)?;
95                }
96                Op::Log(input) => {
97                    activation::log_backward(self, upstream, input)?;
98                }
99                Op::Sqrt(input) => {
100                    activation::sqrt_backward(self, upstream, index, input)?;
101                }
102                Op::Gelu(input) => {
103                    activation::gelu_backward(self, upstream, input)?;
104                }
105                Op::Silu(input) => {
106                    activation::silu_backward(self, upstream, input)?;
107                }
108                Op::Mish(input) => {
109                    activation::mish_backward(self, upstream, input)?;
110                }
111                Op::Softmax(input) => {
112                    activation::softmax_backward(self, upstream, index, input)?;
113                }
114                Op::LogSoftmax(input) => {
115                    activation::log_softmax_backward(self, upstream, index, input)?;
116                }
117
118                // Linear algebra
119                Op::MatMul2D(left, right) => {
120                    linalg::matmul2d_backward(self, upstream, left, right)?;
121                }
122                Op::Transpose2D(input) => {
123                    linalg::transpose2d_backward(self, upstream, input)?;
124                }
125                Op::Conv2dNhwc {
126                    input,
127                    weight,
128                    bias,
129                    stride_h,
130                    stride_w,
131                } => {
132                    linalg::conv2d_nhwc_backward(
133                        self,
134                        &upstream,
135                        input,
136                        weight,
137                        bias,
138                        stride_h as usize,
139                        stride_w as usize,
140                    )?;
141                }
142                Op::DepthwiseConv2dNhwc {
143                    input,
144                    weight,
145                    bias,
146                    stride_h,
147                    stride_w,
148                } => {
149                    linalg::depthwise_conv2d_nhwc_backward(
150                        self,
151                        &upstream,
152                        input,
153                        weight,
154                        bias,
155                        stride_h as usize,
156                        stride_w as usize,
157                    )?;
158                }
159                Op::ConvTranspose2dNhwc {
160                    input,
161                    weight,
162                    bias,
163                    stride_h,
164                    stride_w,
165                } => {
166                    linalg::conv_transpose2d_nhwc_backward(
167                        self,
168                        &upstream,
169                        input,
170                        weight,
171                        bias,
172                        stride_h as usize,
173                        stride_w as usize,
174                    )?;
175                }
176                Op::Conv1dNlc {
177                    input,
178                    weight,
179                    bias,
180                    stride,
181                } => {
182                    linalg::conv1d_nlc_backward(
183                        self,
184                        &upstream,
185                        input,
186                        weight,
187                        bias,
188                        stride as usize,
189                    )?;
190                }
191                Op::Conv3dNdhwc {
192                    input,
193                    weight,
194                    bias,
195                    stride_d,
196                    stride_h,
197                    stride_w,
198                } => {
199                    linalg::conv3d_ndhwc_backward(
200                        self,
201                        &upstream,
202                        input,
203                        weight,
204                        bias,
205                        stride_d as usize,
206                        stride_h as usize,
207                        stride_w as usize,
208                    )?;
209                }
210                Op::DeformableConv2dNhwc {
211                    input,
212                    weight,
213                    offsets,
214                    bias,
215                    stride,
216                    padding,
217                } => {
218                    linalg::deformable_conv2d_nhwc_backward(
219                        self,
220                        &upstream,
221                        input,
222                        weight,
223                        offsets,
224                        bias,
225                        stride as usize,
226                        padding as usize,
227                    )?;
228                }
229
230                // Shape operations
231                Op::ReshapeView { input } => {
232                    shape::reshape_backward(self, upstream, input)?;
233                }
234                Op::Flatten(input) => {
235                    shape::flatten_backward(self, upstream, input)?;
236                }
237                Op::UnsqueezeView { input, axis } => {
238                    shape::unsqueeze_backward(self, upstream, input, axis)?;
239                }
240                Op::SqueezeView { input, axis } => {
241                    shape::squeeze_backward(self, upstream, input, axis)?;
242                }
243                Op::Cat { ref inputs, axis } => {
244                    shape::cat_backward(self, upstream, inputs, axis)?;
245                }
246                Op::Select { input, axis, index } => {
247                    shape::select_backward(self, upstream, input, axis, index)?;
248                }
249                Op::Narrow {
250                    input,
251                    axis,
252                    start,
253                    len,
254                } => {
255                    shape::narrow_backward(self, upstream, input, axis, start, len)?;
256                }
257                Op::Gather { input, axis, index } => {
258                    shape::gather_backward(self, upstream, input, axis, index)?;
259                }
260                Op::ScatterAdd {
261                    input,
262                    axis,
263                    index,
264                    src,
265                } => {
266                    shape::scatter_add_backward(self, upstream, input, axis, index, src)?;
267                }
268                #[allow(clippy::needless_range_loop)]
269                Op::Pad {
270                    input,
271                    ref pad_before,
272                    pad_after: _,
273                } => {
274                    shape::pad_backward(self, upstream, input, pad_before)?;
275                }
276                #[allow(clippy::needless_range_loop)]
277                Op::Repeat { input, repeats: _ } => {
278                    shape::repeat_backward(self, upstream, input)?;
279                }
280                Op::Scatter {
281                    input,
282                    indices,
283                    src,
284                } => {
285                    shape::scatter_backward(self, upstream, input, indices, src)?;
286                }
287                Op::EmbeddingLookup { weight, indices } => {
288                    shape::embedding_lookup_backward(self, upstream, weight, indices)?;
289                }
290                Op::PixelShuffle {
291                    input,
292                    upscale_factor,
293                } => {
294                    shape::pixel_shuffle_backward(self, &upstream, input, upscale_factor as usize)?;
295                }
296                Op::UpsampleNearest {
297                    input,
298                    scale_factor,
299                } => {
300                    shape::upsample_nearest_backward(
301                        self,
302                        &upstream,
303                        input,
304                        scale_factor as usize,
305                    )?;
306                }
307
308                // Reductions
309                Op::Sum(input) => {
310                    reduce::sum_backward(self, upstream, index, input)?;
311                }
312                Op::Mean(input) => {
313                    reduce::mean_backward(self, upstream, index, input)?;
314                }
315                Op::SumAxis { input, axis } => {
316                    reduce::sum_axis_backward(self, upstream, input, axis)?;
317                }
318                Op::MeanAxis { input, axis } => {
319                    reduce::mean_axis_backward(self, upstream, input, axis)?;
320                }
321
322                // Recurrent
323                Op::Rnn {
324                    input,
325                    w_ih,
326                    w_hh,
327                    bias,
328                } => {
329                    recurrent::rnn_backward(self, &upstream, index, input, w_ih, w_hh, bias)?;
330                }
331                Op::Lstm {
332                    input,
333                    w_ih,
334                    w_hh,
335                    bias,
336                } => {
337                    recurrent::lstm_backward(self, &upstream, index, input, w_ih, w_hh, bias)?;
338                }
339                Op::Gru {
340                    input,
341                    w_ih,
342                    w_hh,
343                    bias_ih,
344                    bias_hh,
345                } => {
346                    recurrent::gru_backward(
347                        self, &upstream, index, input, w_ih, w_hh, bias_ih, bias_hh,
348                    )?;
349                }
350
351                // Normalization
352                Op::BatchNorm2dNhwc {
353                    input,
354                    gamma,
355                    beta,
356                    running_mean: _,
357                    running_var,
358                    epsilon,
359                } => {
360                    let eps = f32::from_bits(epsilon);
361                    norm::batch_norm2d_nhwc_backward(
362                        self,
363                        index,
364                        &upstream,
365                        input,
366                        gamma,
367                        beta,
368                        running_var,
369                        eps,
370                    )?;
371                }
372                Op::LayerNorm {
373                    input,
374                    gamma,
375                    beta,
376                    eps_bits,
377                } => {
378                    let eps = f32::from_bits(eps_bits);
379                    norm::layer_norm_backward(self, index, &upstream, input, gamma, beta, eps)?;
380                }
381                Op::GroupNorm {
382                    input,
383                    gamma,
384                    beta,
385                    num_groups,
386                    eps_bits,
387                } => {
388                    let eps = f32::from_bits(eps_bits);
389                    norm::group_norm_backward(
390                        self,
391                        index,
392                        &upstream,
393                        input,
394                        gamma,
395                        beta,
396                        num_groups as usize,
397                        eps,
398                    )?;
399                }
400                Op::InstanceNormNhwc {
401                    input,
402                    gamma,
403                    beta,
404                    eps_bits,
405                } => {
406                    let eps = f32::from_bits(eps_bits);
407                    norm::instance_norm_nhwc_backward(
408                        self, index, &upstream, input, gamma, beta, eps,
409                    )?;
410                }
411
412                // Pooling
413                Op::MaxPool2dNhwc {
414                    input,
415                    kernel_h: _,
416                    kernel_w: _,
417                    stride_h: _,
418                    stride_w: _,
419                } => {
420                    pool::max_pool2d_nhwc_backward(self, upstream, index, input)?;
421                }
422                Op::AvgPool2dNhwc {
423                    input,
424                    kernel_h,
425                    kernel_w,
426                    stride_h,
427                    stride_w,
428                } => {
429                    pool::avg_pool2d_nhwc_backward(
430                        self,
431                        &upstream,
432                        input,
433                        kernel_h as usize,
434                        kernel_w as usize,
435                        stride_h as usize,
436                        stride_w as usize,
437                    )?;
438                }
439                Op::AdaptiveAvgPool2dNhwc {
440                    input,
441                    out_h,
442                    out_w,
443                } => {
444                    pool::adaptive_avg_pool2d_nhwc_backward(
445                        self,
446                        &upstream,
447                        input,
448                        out_h as usize,
449                        out_w as usize,
450                    )?;
451                }
452                Op::AdaptiveMaxPool2dNhwc {
453                    input,
454                    out_h: _,
455                    out_w: _,
456                } => {
457                    pool::adaptive_max_pool2d_nhwc_backward(self, upstream, index, input)?;
458                }
459
460                // Attention & PReLU
461                Op::PRelu { input, alpha } => {
462                    attention::prelu_backward(self, &upstream, input, alpha)?;
463                }
464                Op::ScaledDotProductAttention { query, key, value } => {
465                    attention::scaled_dot_product_attention_backward(
466                        self, &upstream, index, query, key, value,
467                    )?;
468                }
469            }
470        }
471
472        Ok(())
473    }
474
475    /// Backward pass with activation checkpointing.
476    pub fn backward_with_checkpoints(
477        &mut self,
478        target: NodeId,
479        config: &CheckpointConfig,
480    ) -> Result<(), AutogradError> {
481        self.backward(target)?;
482
483        for index in 0..self.nodes.len() {
484            if config.should_checkpoint(index) && !matches!(self.nodes[index].op, Op::Leaf) {
485                self.nodes[index].value = Tensor::scalar(0.0);
486                self.nodes[index].aux = None;
487            }
488        }
489
490        Ok(())
491    }
492
493    pub(crate) fn accumulate_grad(
494        &mut self,
495        node_id: NodeId,
496        contribution: Tensor,
497    ) -> Result<(), AutogradError> {
498        if !self.node(node_id)?.requires_grad {
499            return Ok(());
500        }
501
502        let node = self.node_mut(node_id)?;
503        match &mut node.grad {
504            Some(existing) => {
505                if existing.shape() != contribution.shape() {
506                    return Err(AutogradError::InvalidGradientShape {
507                        node: node_id.0,
508                        expected: existing.shape().to_vec(),
509                        got: contribution.shape().to_vec(),
510                    });
511                }
512                existing.add_inplace(&contribution);
513            }
514            None => node.grad = Some(contribution),
515        }
516        Ok(())
517    }
518
519    // ── Backend dispatch helpers for backward pass ──────────────────
520
521    pub(crate) fn dispatch_mul(&self, lhs: &Tensor, rhs: &Tensor) -> Result<Tensor, AutogradError> {
522        if let Some(ref backend) = self.backend {
523            Ok(backend.mul(lhs, rhs)?)
524        } else {
525            Ok(kernel_mul(lhs, rhs)?)
526        }
527    }
528
529    pub(crate) fn dispatch_matmul_2d(
530        &self,
531        lhs: &Tensor,
532        rhs: &Tensor,
533    ) -> Result<Tensor, AutogradError> {
534        if let Some(ref backend) = self.backend {
535            Ok(backend.matmul_2d(lhs, rhs)?)
536        } else {
537            Ok(matmul_2d(lhs, rhs)?)
538        }
539    }
540
541    pub(crate) fn dispatch_transpose_2d(&self, input: &Tensor) -> Result<Tensor, AutogradError> {
542        if let Some(ref backend) = self.backend {
543            Ok(backend.transpose_2d(input)?)
544        } else {
545            transpose_2d(input)
546        }
547    }
548
549    pub(crate) fn dispatch_neg(&self, input: &Tensor) -> Tensor {
550        if let Some(ref backend) = self.backend {
551            backend.neg(input)
552        } else {
553            input.neg()
554        }
555    }
556}
557
558fn transpose_2d(input: &Tensor) -> Result<Tensor, AutogradError> {
559    if input.rank() != 2 {
560        return Err(AutogradError::InvalidRankForOperation {
561            op: "transpose_2d",
562            expected: 2,
563            got: input.rank(),
564        });
565    }
566    let rows = input.shape()[0];
567    let cols = input.shape()[1];
568    let mut data = vec![0.0f32; input.len()];
569    for row in 0..rows {
570        for col in 0..cols {
571            data[col * rows + row] = input.data()[row * cols + col];
572        }
573    }
574    Tensor::from_vec(vec![cols, rows], data).map_err(Into::into)
575}
576
577fn reduce_broadcast_gradient(
578    upstream: &Tensor,
579    target_shape: &[usize],
580) -> Result<Tensor, AutogradError> {
581    if upstream.shape() == target_shape {
582        return Ok(upstream.clone());
583    }
584    if target_shape.len() > upstream.rank() {
585        return Err(AutogradError::BroadcastGradientIncompatible {
586            upstream: upstream.shape().to_vec(),
587            target: target_shape.to_vec(),
588        });
589    }
590
591    let leading_axes = upstream.rank() - target_shape.len();
592    let mut reduced: Option<Tensor> = None;
593    for axis in (0..leading_axes).rev() {
594        reduced = Some(match reduced {
595            Some(r) => r.sum_axis(axis)?,
596            None => upstream.sum_axis(axis)?,
597        });
598    }
599
600    let mut axes_to_reduce = Vec::new();
601    let check_shape = reduced.as_ref().unwrap_or(upstream);
602    for (axis, target_dim) in target_shape.iter().enumerate() {
603        let current_dim = check_shape.shape()[axis];
604        if current_dim == *target_dim {
605            continue;
606        }
607        if *target_dim == 1 && current_dim > 1 {
608            axes_to_reduce.push(axis);
609            continue;
610        }
611        return Err(AutogradError::BroadcastGradientIncompatible {
612            upstream: upstream.shape().to_vec(),
613            target: target_shape.to_vec(),
614        });
615    }
616
617    if !axes_to_reduce.is_empty() && reduced.is_none() {
618        reduced = Some(upstream.clone());
619    }
620    let mut reduced = reduced.unwrap_or_else(|| upstream.clone());
621
622    for axis in axes_to_reduce.into_iter().rev() {
623        reduced = reduced.sum_axis(axis)?;
624    }
625
626    if reduced.shape() != target_shape {
627        reduced = reduced.reshape(target_shape.to_vec())?;
628    }
629    Ok(reduced)
630}