Skip to main content

axonml_nn/layers/
conv.rs

1//! Convolutional Layers - 1D and 2D Convolutions
2//!
3//! # File
4//! `crates/axonml-nn/src/layers/conv.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_autograd::functions::{
21    Conv1dBackward, Conv2dBackward, ConvTranspose2dBackward, GroupedConv2dBackward,
22};
23use axonml_autograd::grad_fn::GradFn;
24use axonml_autograd::no_grad::is_grad_enabled;
25use axonml_tensor::Tensor;
26use rayon::prelude::*;
27
28use crate::init::{kaiming_uniform, zeros};
29use crate::module::Module;
30use crate::parameter::Parameter;
31
32// =============================================================================
33// Conv1d
34// =============================================================================
35
/// Applies a 1D convolution over an input signal.
///
/// # Shape
/// - Input: (N, C_in, L)
/// - Output: (N, C_out, L_out)
///
/// where L_out = (L + 2*padding - kernel_size) / stride + 1
pub struct Conv1d {
    /// Weight tensor of shape (out_channels, in_channels, kernel_size),
    /// Kaiming-uniform initialized.
    pub weight: Parameter,
    /// Bias tensor of shape (out_channels); `None` when the layer was
    /// constructed with `bias = false`.
    pub bias: Option<Parameter>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Size of the convolving kernel.
    kernel_size: usize,
    /// Stride of the convolution.
    stride: usize,
    /// Zero-padding added to both sides.
    padding: usize,
}
59
60impl Conv1d {
61    /// Creates a new Conv1d layer.
62    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
63        Self::with_options(in_channels, out_channels, kernel_size, 1, 0, true)
64    }
65
66    /// Creates a Conv1d layer with all options.
67    pub fn with_options(
68        in_channels: usize,
69        out_channels: usize,
70        kernel_size: usize,
71        stride: usize,
72        padding: usize,
73        bias: bool,
74    ) -> Self {
75        // Initialize weights
76        let fan_in = in_channels * kernel_size;
77        let weight_data = kaiming_uniform(out_channels, fan_in);
78        let weight_reshaped = weight_data
79            .reshape(&[
80                out_channels as isize,
81                in_channels as isize,
82                kernel_size as isize,
83            ])
84            .unwrap();
85        let weight = Parameter::named("weight", weight_reshaped, true);
86
87        let bias_param = if bias {
88            Some(Parameter::named("bias", zeros(&[out_channels]), true))
89        } else {
90            None
91        };
92
93        Self {
94            weight,
95            bias: bias_param,
96            in_channels,
97            out_channels,
98            kernel_size,
99            stride,
100            padding,
101        }
102    }
103}
104
105impl Module for Conv1d {
106    fn forward(&self, input: &Variable) -> Variable {
107        let input_shape = input.shape();
108        let batch_size = input_shape[0];
109        let in_length = input_shape[2];
110
111        let out_length = (in_length + 2 * self.padding - self.kernel_size) / self.stride + 1;
112
113        let input_data = input.data();
114        let weight_data = self.weight.data();
115
116        // GPU-resident fast path: reshape [B,C,L] → [B,C,L,1], use Conv2d CUDA pipeline,
117        // then reshape output [B,Cout,Lout,1] → [B,Cout,Lout].
118        #[cfg(feature = "cuda")]
119        if input_data.device().is_gpu() {
120            // Auto-migrate weights to GPU if needed
121            let input_dev = input_data.device();
122            if !weight_data.device().is_gpu() {
123                self.weight.to_device(input_dev);
124                if let Some(ref b) = self.bias {
125                    b.to_device(input_dev);
126                }
127            }
128            let weight_data = self.weight.data();
129
130            // Reshape input [B, Cin, L] → [B, Cin, L, 1]
131            let input_4d = input_data
132                .reshape(&[
133                    batch_size as isize,
134                    self.in_channels as isize,
135                    in_length as isize,
136                    1,
137                ])
138                .unwrap();
139
140            // Reshape weight [Cout, Cin, K] → [Cout, Cin, K, 1]
141            let weight_4d = weight_data
142                .reshape(&[
143                    self.out_channels as isize,
144                    self.in_channels as isize,
145                    self.kernel_size as isize,
146                    1,
147                ])
148                .unwrap();
149
150            let bias_tensor = self.bias.as_ref().map(|b| b.data());
151            let gpu_output = input_4d.conv2d_cuda(
152                &weight_4d,
153                bias_tensor.as_ref(),
154                (self.stride, 1),
155                (self.padding, 0),
156            );
157
158            if let Some(output_4d) = gpu_output {
159                // Reshape output [B, Cout, Lout, 1] → [B, Cout, Lout]
160                let output_tensor = output_4d
161                    .reshape(&[
162                        batch_size as isize,
163                        self.out_channels as isize,
164                        out_length as isize,
165                    ])
166                    .unwrap();
167
168                let requires_grad =
169                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
170                if requires_grad {
171                    let weight_var = self.weight.variable();
172                    let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
173
174                    let grad_fn = GradFn::new(Conv1dBackward::new(
175                        input.grad_fn().cloned(),
176                        weight_var.grad_fn().cloned(),
177                        bias_grad_fn,
178                        input_data,
179                        weight_data,
180                        input_shape,
181                        self.in_channels,
182                        self.out_channels,
183                        self.kernel_size,
184                        self.stride,
185                        self.padding,
186                        self.bias.is_some(),
187                    ));
188                    return Variable::from_operation(output_tensor, grad_fn, true);
189                } else {
190                    return Variable::new(output_tensor, false);
191                }
192            }
193            // Fall through to CPU path if GPU conv failed
194        }
195
196        let input_vec = input_data.to_vec();
197        let weight_vec = weight_data.to_vec();
198
199        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_length];
200
201        for b in 0..batch_size {
202            for oc in 0..self.out_channels {
203                for ol in 0..out_length {
204                    let mut sum = 0.0f32;
205                    let in_start = ol * self.stride;
206
207                    for ic in 0..self.in_channels {
208                        for k in 0..self.kernel_size {
209                            let in_idx = in_start + k;
210                            if in_idx < self.padding || in_idx >= in_length + self.padding {
211                                continue;
212                            }
213                            let actual_idx = in_idx - self.padding;
214
215                            let input_idx =
216                                b * self.in_channels * in_length + ic * in_length + actual_idx;
217                            let weight_idx = oc * self.in_channels * self.kernel_size
218                                + ic * self.kernel_size
219                                + k;
220
221                            sum += input_vec[input_idx] * weight_vec[weight_idx];
222                        }
223                    }
224
225                    if let Some(ref bias) = self.bias {
226                        sum += bias.data().to_vec()[oc];
227                    }
228
229                    let output_idx = b * self.out_channels * out_length + oc * out_length + ol;
230                    output_data[output_idx] = sum;
231                }
232            }
233        }
234
235        let output_tensor =
236            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_length])
237                .expect("tensor creation failed");
238
239        let requires_grad =
240            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
241
242        if requires_grad {
243            let weight_var = self.weight.variable();
244            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
245
246            let grad_fn = GradFn::new(Conv1dBackward::new(
247                input.grad_fn().cloned(),
248                weight_var.grad_fn().cloned(),
249                bias_grad_fn,
250                input_data,
251                weight_data,
252                input_shape,
253                self.in_channels,
254                self.out_channels,
255                self.kernel_size,
256                self.stride,
257                self.padding,
258                self.bias.is_some(),
259            ));
260            Variable::from_operation(output_tensor, grad_fn, true)
261        } else {
262            Variable::new(output_tensor, false)
263        }
264    }
265
266    fn parameters(&self) -> Vec<Parameter> {
267        let mut params = vec![self.weight.clone()];
268        if let Some(ref bias) = self.bias {
269            params.push(bias.clone());
270        }
271        params
272    }
273
274    fn named_parameters(&self) -> HashMap<String, Parameter> {
275        let mut params = HashMap::new();
276        params.insert("weight".to_string(), self.weight.clone());
277        if let Some(ref bias) = self.bias {
278            params.insert("bias".to_string(), bias.clone());
279        }
280        params
281    }
282
283    fn name(&self) -> &'static str {
284        "Conv1d"
285    }
286}
287
288// =============================================================================
289// Conv2d
290// =============================================================================
291
/// Applies a 2D convolution over an input image.
///
/// # Shape
/// - Input: (N, C_in, H, W)
/// - Output: (N, C_out, H_out, W_out)
///
/// where H_out = (H + 2*padding - kernel_size) / stride + 1
pub struct Conv2d {
    /// Weight tensor of shape (out_channels, in_channels / groups, kernel_h,
    /// kernel_w) — see the reshape in `with_groups`. For `groups == 1` this is
    /// the standard (out_channels, in_channels, kernel_h, kernel_w) layout.
    pub weight: Parameter,
    /// Bias tensor of shape (out_channels); `None` when built with `bias = false`.
    pub bias: Option<Parameter>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Size of the convolving kernel (height, width).
    kernel_size: (usize, usize),
    /// Stride of the convolution (height, width).
    stride: (usize, usize),
    /// Zero-padding added to both sides (height, width).
    padding: (usize, usize),
    /// Number of groups for grouped convolution (1 = standard convolution).
    groups: usize,
}
317
318impl Conv2d {
319    /// Creates a new Conv2d layer with square kernel.
320    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
321        Self::with_options(
322            in_channels,
323            out_channels,
324            (kernel_size, kernel_size),
325            (1, 1),
326            (0, 0),
327            true,
328        )
329    }
330
331    /// Creates a Conv2d layer with all options.
332    pub fn with_options(
333        in_channels: usize,
334        out_channels: usize,
335        kernel_size: (usize, usize),
336        stride: (usize, usize),
337        padding: (usize, usize),
338        bias: bool,
339    ) -> Self {
340        Self::with_groups(
341            in_channels,
342            out_channels,
343            kernel_size,
344            stride,
345            padding,
346            bias,
347            1,
348        )
349    }
350
351    /// Creates a Conv2d layer with grouped convolution support.
352    ///
353    /// When `groups == in_channels` and `out_channels == in_channels`, this is
354    /// a depthwise convolution.
355    pub fn with_groups(
356        in_channels: usize,
357        out_channels: usize,
358        kernel_size: (usize, usize),
359        stride: (usize, usize),
360        padding: (usize, usize),
361        bias: bool,
362        groups: usize,
363    ) -> Self {
364        assert!(
365            in_channels % groups == 0,
366            "in_channels must be divisible by groups"
367        );
368        assert!(
369            out_channels % groups == 0,
370            "out_channels must be divisible by groups"
371        );
372
373        let (kh, kw) = kernel_size;
374        let in_channels_per_group = in_channels / groups;
375        let fan_in = in_channels_per_group * kh * kw;
376
377        let weight_data = kaiming_uniform(out_channels, fan_in);
378        let weight_reshaped = weight_data
379            .reshape(&[
380                out_channels as isize,
381                in_channels_per_group as isize,
382                kh as isize,
383                kw as isize,
384            ])
385            .unwrap();
386        let weight = Parameter::named("weight", weight_reshaped, true);
387
388        let bias_param = if bias {
389            Some(Parameter::named("bias", zeros(&[out_channels]), true))
390        } else {
391            None
392        };
393
394        Self {
395            weight,
396            bias: bias_param,
397            in_channels,
398            out_channels,
399            kernel_size,
400            stride,
401            padding,
402            groups,
403        }
404    }
405
406    /// Creates a depthwise convolution (groups = in_channels).
407    pub fn depthwise(channels: usize, kernel_size: usize) -> Self {
408        Self::with_groups(
409            channels,
410            channels,
411            (kernel_size, kernel_size),
412            (1, 1),
413            (kernel_size / 2, kernel_size / 2),
414            true,
415            channels,
416        )
417    }
418}
419
420// =============================================================================
421// im2col + GEMM Conv2d Implementation
422// =============================================================================
423
/// Unfold input patches into a column matrix (im2col).
///
/// Input: `[C_in, H, W]` (one batch element, one group's channels)
/// Output: `[C_in * kH * kW, out_H * out_W]`
///
/// Positions that fall into the zero padding are left as 0.0 in the
/// pre-zeroed output buffer.
fn im2col(
    input: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    pad_h: usize,
    pad_w: usize,
    stride_h: usize,
    stride_w: usize,
    out_h: usize,
    out_w: usize,
) -> Vec<f32> {
    let rows = channels * kernel_h * kernel_w;
    let cols = out_h * out_w;
    let mut col = vec![0.0f32; rows * cols];
    let plane = height * width;
    let kernel_area = kernel_h * kernel_w;
    let h_range = 0..height as isize;
    let w_range = 0..width as isize;

    // Row index decomposes as: row = c * kH * kW + kh_off * kW + kw_off.
    // Column index decomposes as: col = oh * out_w + ow.
    for row in 0..rows {
        let c = row / kernel_area;
        let within = row % kernel_area;
        let kh_off = within / kernel_w;
        let kw_off = within % kernel_w;
        let channel_base = c * plane;
        let dst_base = row * cols;

        for oh in 0..out_h {
            // Source row in the (virtually) padded input; negative or
            // past-the-end rows lie entirely in padding, so skip them.
            let h_in = (oh * stride_h + kh_off) as isize - pad_h as isize;
            if !h_range.contains(&h_in) {
                continue;
            }
            let src_row = channel_base + h_in as usize * width;
            let dst_row = dst_base + oh * out_w;

            for ow in 0..out_w {
                let w_in = (ow * stride_w + kw_off) as isize - pad_w as isize;
                if w_range.contains(&w_in) {
                    col[dst_row + ow] = input[src_row + w_in as usize];
                }
            }
        }
    }

    col
}
496
497/// Conv2d forward using im2col + matmul. Supports groups.
498fn conv2d_im2col(
499    input: &[f32],
500    weight: &[f32],
501    bias: Option<&[f32]>,
502    batch_size: usize,
503    in_channels: usize,
504    in_height: usize,
505    in_width: usize,
506    out_channels: usize,
507    kh: usize,
508    kw: usize,
509    sh: usize,
510    sw: usize,
511    ph: usize,
512    pw: usize,
513    groups: usize,
514) -> Vec<f32> {
515    let out_h = (in_height + 2 * ph - kh) / sh + 1;
516    let out_w = (in_width + 2 * pw - kw) / sw + 1;
517    let in_channels_per_group = in_channels / groups;
518    let out_channels_per_group = out_channels / groups;
519    let col_h = in_channels_per_group * kh * kw;
520    let col_w = out_h * out_w;
521    let spatial = out_h * out_w;
522    let in_spatial = in_height * in_width;
523
524    // Parallel: each batch element produces its own output slice
525    let out_per_batch = out_channels * spatial;
526    let per_batch: Vec<Vec<f32>> = (0..batch_size)
527        .into_par_iter()
528        .map(|b| {
529            let mut batch_out = vec![0.0f32; out_per_batch];
530
531            for g in 0..groups {
532                let ic_start = g * in_channels_per_group;
533                let oc_start = g * out_channels_per_group;
534
535                // Extract input for this batch+group
536                let in_offset = b * in_channels * in_spatial + ic_start * in_spatial;
537                let input_slice = &input[in_offset..in_offset + in_channels_per_group * in_spatial];
538
539                // im2col
540                let col = im2col(
541                    input_slice,
542                    in_channels_per_group,
543                    in_height,
544                    in_width,
545                    kh,
546                    kw,
547                    ph,
548                    pw,
549                    sh,
550                    sw,
551                    out_h,
552                    out_w,
553                );
554
555                // Weight for this group
556                let w_offset = oc_start * in_channels_per_group * kh * kw;
557                let w_size = out_channels_per_group * col_h;
558                let weight_slice = &weight[w_offset..w_offset + w_size];
559
560                // GEMM via Tensor::matmul
561                let w_tensor =
562                    Tensor::from_vec(weight_slice.to_vec(), &[out_channels_per_group, col_h])
563                        .unwrap();
564                let col_tensor =
565                    Tensor::from_vec(col, &[col_h, col_w]).expect("tensor creation failed");
566                let result = w_tensor.matmul(&col_tensor).expect("matmul failed");
567                let result_vec = result.to_vec();
568
569                // Copy to output with bias
570                let out_offset = oc_start * spatial;
571                for oc_local in 0..out_channels_per_group {
572                    let oc = oc_start + oc_local;
573                    let bias_val = bias.map_or(0.0, |bv| bv[oc]);
574                    let src_start = oc_local * col_w;
575                    let dst_start = out_offset + oc_local * spatial;
576                    if bias_val == 0.0 {
577                        batch_out[dst_start..dst_start + spatial]
578                            .copy_from_slice(&result_vec[src_start..src_start + spatial]);
579                    } else {
580                        for i in 0..spatial {
581                            batch_out[dst_start + i] = result_vec[src_start + i] + bias_val;
582                        }
583                    }
584                }
585            }
586
587            batch_out
588        })
589        .collect();
590
591    // Flatten per-batch results into single output
592    let mut output = Vec::with_capacity(batch_size * out_per_batch);
593    for batch_out in per_batch {
594        output.extend_from_slice(&batch_out);
595    }
596    output
597}
598
impl Module for Conv2d {
    /// Forward pass: `[N, C_in, H, W]` → `[N, C_out, H_out, W_out]`.
    ///
    /// Dispatch order (each stage falls through to the next when it
    /// returns `None`):
    /// 1. GPU-resident input (feature `cuda`): cuDNN (feature `cudnn`) →
    ///    CUDA im2col+GEMM (`groups == 1`) → grouped CUDA kernel.
    /// 2. CPU-resident input with a large single-group conv: GPU-compute
    ///    helper fed host buffers.
    /// 3. Pure-CPU `conv2d_im2col`.
    ///
    /// When gradients are enabled, attaches `Conv2dBackward` for
    /// `groups == 1` and `GroupedConv2dBackward` otherwise.
    fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let in_height = input_shape[2];
        let in_width = input_shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        // Standard conv output size: floor((in + 2*pad - k) / stride) + 1.
        let out_height = (in_height + 2 * ph - kh) / sh + 1;
        let out_width = (in_width + 2 * pw - kw) / sw + 1;

        let input_data = input.data();
        let weight_data = self.weight.data();

        // GPU-resident fast path: when input is already on GPU, do everything on GPU
        // without any CPU↔GPU copies.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            // Auto-migrate weights to GPU if needed (one-time cost, cached via Arc)
            let input_dev = input_data.device();
            if !weight_data.device().is_gpu() {
                self.weight.to_device(input_dev);
                if let Some(ref b) = self.bias {
                    b.to_device(input_dev);
                }
            }
            // Shadows the handle read above so the GPU copy is used below.
            let weight_data = self.weight.data();

            // Try cuDNN first (fastest path), fall back to im2col+GEMM
            #[cfg(feature = "cudnn")]
            let cudnn_output = {
                let bias_tensor = self.bias.as_ref().map(|b| b.data());
                input_data.conv2d_cudnn(
                    &weight_data,
                    bias_tensor.as_ref(),
                    self.stride,
                    self.padding,
                    self.groups,
                )
            };
            // Without cuDNN compiled in, pretend the cuDNN attempt failed.
            #[cfg(not(feature = "cudnn"))]
            let cudnn_output: Option<axonml_tensor::Tensor<f32>> = None;

            let gpu_output = if cudnn_output.is_some() {
                cudnn_output
            } else if self.groups == 1 {
                // Standard convolution: single im2col + GEMM
                let bias_tensor = self.bias.as_ref().map(|b| b.data());
                input_data.conv2d_cuda(
                    &weight_data,
                    bias_tensor.as_ref(),
                    self.stride,
                    self.padding,
                )
            } else {
                // Grouped convolution: run per-group im2col + GEMM on GPU
                input_data.conv2d_grouped_cuda(
                    &weight_data,
                    self.bias.as_ref().map(|b| b.data()).as_ref(),
                    self.stride,
                    self.padding,
                    self.groups,
                )
            };

            if let Some(output_tensor) = gpu_output {
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
                    // Grouped and ungrouped convs use distinct backward nodes.
                    if self.groups == 1 {
                        let grad_fn = GradFn::new(Conv2dBackward::new(
                            input.grad_fn().cloned(),
                            weight_var.grad_fn().cloned(),
                            bias_grad_fn,
                            input_data,
                            weight_data,
                            input_shape,
                            self.in_channels,
                            self.out_channels,
                            self.kernel_size,
                            self.stride,
                            self.padding,
                            self.bias.is_some(),
                        ));
                        return Variable::from_operation(output_tensor, grad_fn, true);
                    } else {
                        let grad_fn = GradFn::new(GroupedConv2dBackward::new(
                            input.grad_fn().cloned(),
                            weight_var.grad_fn().cloned(),
                            bias_grad_fn,
                            input_data,
                            weight_data,
                            input_shape,
                            self.in_channels,
                            self.out_channels,
                            self.kernel_size,
                            self.stride,
                            self.padding,
                            self.groups,
                            self.bias.is_some(),
                        ));
                        return Variable::from_operation(output_tensor, grad_fn, true);
                    }
                } else {
                    return Variable::new(output_tensor, false);
                }
            }
            // Fall through to CPU path if GPU conv failed
        }

        let input_vec = input_data.to_vec();
        let weight_vec = weight_data.to_vec();

        // Try GPU im2col+GEMM for groups=1 when data is on CPU but GPU is available
        // NOTE(review): the 500_000 "flops" threshold (really MAC count) gates the
        // host→device round trip; presumably tuned empirically — confirm before changing.
        let conv_flops = self.out_channels * self.in_channels * kh * kw * out_height * out_width;
        let output_data = if self.groups == 1 && conv_flops >= 500_000 {
            let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());
            let gpu_result = axonml_core::backends::cuda::cuda_conv2d_forward(
                &input_vec,
                &weight_vec,
                bias_vec.as_deref(),
                batch_size,
                self.in_channels,
                in_height,
                in_width,
                self.out_channels,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
            );

            // GPU helper returns None on failure; fall back to the CPU kernel.
            if let Some(result) = gpu_result {
                result
            } else {
                conv2d_im2col(
                    &input_vec,
                    &weight_vec,
                    self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
                    batch_size,
                    self.in_channels,
                    in_height,
                    in_width,
                    self.out_channels,
                    kh,
                    kw,
                    sh,
                    sw,
                    ph,
                    pw,
                    self.groups,
                )
            }
        } else {
            // Small or grouped convolution: pure-CPU im2col + GEMM.
            conv2d_im2col(
                &input_vec,
                &weight_vec,
                self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
                batch_size,
                self.in_channels,
                in_height,
                in_width,
                self.out_channels,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
                self.groups,
            )
        };

        let output_tensor = Tensor::from_vec(
            output_data,
            &[batch_size, self.out_channels, out_height, out_width],
        )
        .unwrap();

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        if requires_grad && self.groups == 1 {
            // Full backward pass for standard convolution
            let weight_var = self.weight.variable();
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            let grad_fn = GradFn::new(Conv2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else if requires_grad {
            // Grouped convolution backward (depthwise separable, etc.)
            let weight_var = self.weight.variable();
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            let grad_fn = GradFn::new(GroupedConv2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.groups,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }

    /// Returns the trainable parameters: weight, plus bias if present.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    /// Returns parameters keyed by their conventional names.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "Conv2d"
    }
}
855
856// =============================================================================
857// ConvTranspose2d
858// =============================================================================
859
/// Applies a 2D transposed convolution (deconvolution) for upsampling.
///
/// # Shape
/// - Input: (N, C_in, H, W)
/// - Output: (N, C_out, H_out, W_out)
///
/// where H_out = (H - 1) * stride - 2*padding + kernel_size + output_padding
pub struct ConvTranspose2d {
    /// Weight tensor of shape (in_channels, out_channels, kernel_h, kernel_w).
    pub weight: Parameter,
    /// Bias tensor of shape (out_channels); `None` when built with `bias = false`.
    pub bias: Option<Parameter>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Size of the convolving kernel (height, width).
    kernel_size: (usize, usize),
    /// Stride of the convolution (height, width); scales the output size
    /// per the formula above.
    stride: (usize, usize),
    /// Padding (height, width); subtracted twice in the output-size formula.
    padding: (usize, usize),
    /// Extra size added to the output (height, width); see formula above.
    output_padding: (usize, usize),
}
879
880impl ConvTranspose2d {
881    /// Creates a new ConvTranspose2d layer with square kernel.
882    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
883        Self::with_options(
884            in_channels,
885            out_channels,
886            (kernel_size, kernel_size),
887            (1, 1),
888            (0, 0),
889            (0, 0),
890            true,
891        )
892    }
893
894    /// Creates a ConvTranspose2d layer with all options.
895    pub fn with_options(
896        in_channels: usize,
897        out_channels: usize,
898        kernel_size: (usize, usize),
899        stride: (usize, usize),
900        padding: (usize, usize),
901        output_padding: (usize, usize),
902        bias: bool,
903    ) -> Self {
904        let (kh, kw) = kernel_size;
905        let fan_in = in_channels * kh * kw;
906
907        let weight_data = kaiming_uniform(out_channels, fan_in);
908        let weight_reshaped = weight_data
909            .reshape(&[
910                in_channels as isize,
911                out_channels as isize,
912                kh as isize,
913                kw as isize,
914            ])
915            .unwrap();
916        let weight = Parameter::named("weight", weight_reshaped, true);
917
918        let bias_param = if bias {
919            Some(Parameter::named("bias", zeros(&[out_channels]), true))
920        } else {
921            None
922        };
923
924        Self {
925            weight,
926            bias: bias_param,
927            in_channels,
928            out_channels,
929            kernel_size,
930            stride,
931            padding,
932            output_padding,
933        }
934    }
935}
936
937impl Module for ConvTranspose2d {
938    fn forward(&self, input: &Variable) -> Variable {
939        let input_shape = input.shape();
940        let batch_size = input_shape[0];
941        let in_h = input_shape[2];
942        let in_w = input_shape[3];
943
944        let (kh, kw) = self.kernel_size;
945        let (sh, sw) = self.stride;
946        let (ph, pw) = self.padding;
947        let (oph, opw) = self.output_padding;
948
949        let out_h = (in_h - 1) * sh - 2 * ph + kh + oph;
950        let out_w = (in_w - 1) * sw - 2 * pw + kw + opw;
951
952        let input_data = input.data();
953        let weight_data = self.weight.data();
954        let input_vec = input_data.to_vec();
955        let weight_vec = weight_data.to_vec();
956
957        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_h * out_w];
958
959        // Transposed convolution: scatter input values through the kernel
960        for b in 0..batch_size {
961            for ic in 0..self.in_channels {
962                for ih in 0..in_h {
963                    for iw in 0..in_w {
964                        let in_idx =
965                            b * self.in_channels * in_h * in_w + ic * in_h * in_w + ih * in_w + iw;
966                        let in_val = input_vec[in_idx];
967
968                        for oc in 0..self.out_channels {
969                            for ki in 0..kh {
970                                for kj in 0..kw {
971                                    let oh_signed = (ih * sh + ki) as isize - ph as isize;
972                                    let ow_signed = (iw * sw + kj) as isize - pw as isize;
973
974                                    if oh_signed >= 0
975                                        && (oh_signed as usize) < out_h
976                                        && ow_signed >= 0
977                                        && (ow_signed as usize) < out_w
978                                    {
979                                        let oh = oh_signed as usize;
980                                        let ow = ow_signed as usize;
981                                        let out_idx = b * self.out_channels * out_h * out_w
982                                            + oc * out_h * out_w
983                                            + oh * out_w
984                                            + ow;
985                                        // weight: (in_channels, out_channels, kh, kw)
986                                        let w_idx = ic * self.out_channels * kh * kw
987                                            + oc * kh * kw
988                                            + ki * kw
989                                            + kj;
990                                        output_data[out_idx] += in_val * weight_vec[w_idx];
991                                    }
992                                }
993                            }
994                        }
995                    }
996                }
997            }
998        }
999
1000        // Add bias
1001        if let Some(ref bias) = self.bias {
1002            let bias_vec = bias.data().to_vec();
1003            for b in 0..batch_size {
1004                for oc in 0..self.out_channels {
1005                    for oh in 0..out_h {
1006                        for ow in 0..out_w {
1007                            let out_idx = b * self.out_channels * out_h * out_w
1008                                + oc * out_h * out_w
1009                                + oh * out_w
1010                                + ow;
1011                            output_data[out_idx] += bias_vec[oc];
1012                        }
1013                    }
1014                }
1015            }
1016        }
1017
1018        let output_tensor =
1019            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_h, out_w])
1020                .expect("tensor creation failed");
1021
1022        let requires_grad =
1023            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
1024
1025        if requires_grad {
1026            let weight_var = self.weight.variable();
1027            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
1028
1029            let grad_fn = GradFn::new(ConvTranspose2dBackward::new(
1030                input.grad_fn().cloned(),
1031                weight_var.grad_fn().cloned(),
1032                bias_grad_fn,
1033                input_data,
1034                weight_data,
1035                input_shape,
1036                self.in_channels,
1037                self.out_channels,
1038                self.kernel_size,
1039                self.stride,
1040                self.padding,
1041                self.output_padding,
1042                self.bias.is_some(),
1043            ));
1044            Variable::from_operation(output_tensor, grad_fn, true)
1045        } else {
1046            Variable::new(output_tensor, false)
1047        }
1048    }
1049
1050    fn parameters(&self) -> Vec<Parameter> {
1051        let mut params = vec![self.weight.clone()];
1052        if let Some(ref bias) = self.bias {
1053            params.push(bias.clone());
1054        }
1055        params
1056    }
1057
1058    fn named_parameters(&self) -> HashMap<String, Parameter> {
1059        let mut params = HashMap::new();
1060        params.insert("weight".to_string(), self.weight.clone());
1061        if let Some(ref bias) = self.bias {
1062            params.insert("bias".to_string(), bias.clone());
1063        }
1064        params
1065    }
1066
1067    fn name(&self) -> &'static str {
1068        "ConvTranspose2d"
1069    }
1070}
1071
1072// =============================================================================
1073// Tests
1074// =============================================================================
1075
1076#[cfg(test)]
1077mod tests {
1078    use super::*;
1079
1080    #[test]
1081    fn test_conv1d_creation() {
1082        let conv = Conv1d::new(3, 16, 3);
1083        assert_eq!(conv.in_channels, 3);
1084        assert_eq!(conv.out_channels, 16);
1085        assert_eq!(conv.kernel_size, 3);
1086    }
1087
1088    #[test]
1089    fn test_conv1d_forward() {
1090        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
1091        let input = Variable::new(
1092            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5])
1093                .expect("tensor creation failed"),
1094            false,
1095        );
1096        let output = conv.forward(&input);
1097        assert_eq!(output.shape(), vec![1, 1, 5]);
1098    }
1099
1100    #[test]
1101    fn test_conv1d_backward() {
1102        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
1103        let input = Variable::new(
1104            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5])
1105                .expect("tensor creation failed"),
1106            true,
1107        );
1108        let output = conv.forward(&input);
1109        let loss = output.sum();
1110        loss.backward();
1111
1112        // Input should have gradient (not None)
1113        assert!(
1114            input.grad().is_some(),
1115            "Conv1d: input gradient should flow through backward pass"
1116        );
1117        let grad = input.grad().unwrap();
1118        assert_eq!(grad.shape(), &[1, 1, 5]);
1119    }
1120
1121    #[test]
1122    fn test_conv2d_creation() {
1123        let conv = Conv2d::new(3, 64, 3);
1124        assert_eq!(conv.in_channels, 3);
1125        assert_eq!(conv.out_channels, 64);
1126        assert_eq!(conv.kernel_size, (3, 3));
1127    }
1128
1129    #[test]
1130    fn test_conv2d_forward() {
1131        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
1132        let input = Variable::new(
1133            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
1134            false,
1135        );
1136        let output = conv.forward(&input);
1137        assert_eq!(output.shape(), vec![1, 1, 5, 5]);
1138    }
1139
1140    #[test]
1141    fn test_conv2d_backward() {
1142        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
1143        let input = Variable::new(
1144            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
1145            true,
1146        );
1147        let output = conv.forward(&input);
1148        let loss = output.sum();
1149        loss.backward();
1150
1151        assert!(
1152            input.grad().is_some(),
1153            "Conv2d: input gradient should flow through backward pass"
1154        );
1155        let grad = input.grad().unwrap();
1156        assert_eq!(grad.shape(), &[1, 1, 5, 5]);
1157
1158        // Weight should also have gradient
1159        let w_grad = conv.weight.grad();
1160        assert!(
1161            w_grad.is_some(),
1162            "Conv2d: weight gradient should be computed"
1163        );
1164    }
1165
1166    #[test]
1167    fn test_conv2d_parameters() {
1168        let conv = Conv2d::new(3, 64, 3);
1169        let params = conv.parameters();
1170        assert_eq!(params.len(), 2); // weight + bias
1171    }
1172
1173    #[test]
1174    fn test_conv2d_grouped() {
1175        // Depthwise: groups = in_channels = out_channels
1176        let conv = Conv2d::depthwise(4, 3);
1177        assert_eq!(conv.groups, 4);
1178        assert_eq!(conv.in_channels, 4);
1179        assert_eq!(conv.out_channels, 4);
1180
1181        let input = Variable::new(
1182            Tensor::from_vec(vec![1.0; 4 * 5 * 5], &[1, 4, 5, 5]).expect("tensor creation failed"),
1183            false,
1184        );
1185        let output = conv.forward(&input);
1186        assert_eq!(output.shape(), vec![1, 4, 5, 5]);
1187    }
1188
1189    #[test]
1190    fn test_conv_transpose2d_forward() {
1191        let conv_t = ConvTranspose2d::with_options(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false);
1192        let input = Variable::new(
1193            Tensor::from_vec(vec![1.0; 4], &[1, 1, 2, 2]).expect("tensor creation failed"),
1194            false,
1195        );
1196        let output = conv_t.forward(&input);
1197        // H_out = (2-1)*2 - 2*1 + 3 + 1 = 4
1198        assert_eq!(output.shape(), vec![1, 1, 4, 4]);
1199    }
1200
1201    #[test]
1202    fn test_conv_transpose2d_backward() {
1203        let conv_t = ConvTranspose2d::new(1, 1, 3);
1204        let input = Variable::new(
1205            Tensor::from_vec(vec![1.0; 9], &[1, 1, 3, 3]).expect("tensor creation failed"),
1206            true,
1207        );
1208        let output = conv_t.forward(&input);
1209        let loss = output.sum();
1210        loss.backward();
1211
1212        assert!(
1213            input.grad().is_some(),
1214            "ConvTranspose2d: input gradient should flow through backward"
1215        );
1216    }
1217
1218    // =========================================================================
1219    // Conv1d Comprehensive
1220    // =========================================================================
1221
1222    #[test]
1223    fn test_conv1d_with_padding_and_stride() {
1224        let conv = Conv1d::with_options(1, 4, 3, 2, 1, true);
1225        let input = Variable::new(
1226            Tensor::from_vec(vec![1.0; 1 * 1 * 16], &[1, 1, 16]).unwrap(),
1227            true,
1228        );
1229        let output = conv.forward(&input);
1230        // L_out = (16 + 2*1 - 3) / 2 + 1 = 8
1231        assert_eq!(output.shape(), vec![1, 4, 8]);
1232
1233        output.sum().backward();
1234        let grad = input.grad().expect("Conv1d should propagate gradients");
1235        assert_eq!(grad.shape(), &[1, 1, 16]);
1236        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));
1237    }
1238
1239    #[test]
1240    fn test_conv1d_multi_channel() {
1241        let conv = Conv1d::new(3, 8, 5); // 3 input channels, 8 output, kernel 5
1242        let input = Variable::new(
1243            Tensor::from_vec(vec![0.5; 2 * 3 * 20], &[2, 3, 20]).unwrap(),
1244            false,
1245        );
1246        let output = conv.forward(&input);
1247        // L_out = (20 - 5) / 1 + 1 = 16 (no padding)
1248        assert_eq!(output.shape(), vec![2, 8, 16]);
1249    }
1250
1251    // =========================================================================
1252    // Conv2d Grouped — Correctness
1253    // =========================================================================
1254
1255    #[test]
1256    fn test_conv2d_grouped_gradient_flow() {
1257        let conv = Conv2d::depthwise(4, 3);
1258        let input = Variable::new(
1259            Tensor::from_vec(vec![1.0; 1 * 4 * 8 * 8], &[1, 4, 8, 8]).unwrap(),
1260            true,
1261        );
1262        let output = conv.forward(&input);
1263        output.sum().backward();
1264
1265        let grad = input
1266            .grad()
1267            .expect("Grouped conv should propagate gradients");
1268        assert_eq!(grad.shape(), &[1, 4, 8, 8]);
1269        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));
1270
1271        // Parameters should also get gradients
1272        for p in conv.parameters() {
1273            let g = p.grad().expect("Conv params should have gradients");
1274            assert!(g.to_vec().iter().any(|v| v.abs() > 0.0));
1275        }
1276    }
1277
1278    #[test]
1279    fn test_conv2d_groups_two() {
1280        // 2 groups: 4 input channels split into 2 groups of 2
1281        let conv = Conv2d::with_groups(4, 8, (3, 3), (1, 1), (1, 1), true, 2);
1282        let input = Variable::new(
1283            Tensor::from_vec(vec![1.0; 1 * 4 * 6 * 6], &[1, 4, 6, 6]).unwrap(),
1284            false,
1285        );
1286        let output = conv.forward(&input);
1287        assert_eq!(output.shape(), vec![1, 8, 6, 6]);
1288    }
1289
1290    #[test]
1291    fn test_conv2d_depthwise_separable_pattern() {
1292        // Depthwise separable: depthwise conv + pointwise conv (standard MobileNet pattern)
1293        let dw = Conv2d::depthwise(16, 3); // 16 channels, 3x3 kernel
1294        let pw = Conv2d::with_options(16, 32, (1, 1), (1, 1), (0, 0), true); // pointwise
1295
1296        let input = Variable::new(
1297            Tensor::from_vec(vec![1.0; 1 * 16 * 8 * 8], &[1, 16, 8, 8]).unwrap(),
1298            true,
1299        );
1300        let dw_out = dw.forward(&input);
1301        assert_eq!(dw_out.shape(), vec![1, 16, 8, 8]);
1302
1303        let pw_out = pw.forward(&dw_out);
1304        assert_eq!(pw_out.shape(), vec![1, 32, 8, 8]);
1305
1306        // Full gradient flow through both
1307        pw_out.sum().backward();
1308        let grad = input
1309            .grad()
1310            .expect("Should propagate through depthwise separable");
1311        assert_eq!(grad.shape(), &[1, 16, 8, 8]);
1312    }
1313
1314    // =========================================================================
1315    // ConvTranspose2d Comprehensive
1316    // =========================================================================
1317
1318    #[test]
1319    fn test_conv_transpose2d_upsamples() {
1320        // ConvTranspose2d with stride=2 should roughly double spatial dims
1321        let conv_t = ConvTranspose2d::with_options(1, 1, (4, 4), (2, 2), (1, 1), (0, 0), true);
1322        let input = Variable::new(
1323            Tensor::from_vec(vec![1.0; 1 * 1 * 4 * 4], &[1, 1, 4, 4]).unwrap(),
1324            false,
1325        );
1326        let output = conv_t.forward(&input);
1327        // H_out = (4-1)*2 - 2*1 + 4 + 0 = 8
1328        assert_eq!(output.shape(), vec![1, 1, 8, 8]);
1329    }
1330
1331    #[test]
1332    fn test_conv_transpose2d_gradient_correctness() {
1333        let conv_t = ConvTranspose2d::new(2, 4, 3);
1334        let input = Variable::new(
1335            Tensor::from_vec(vec![0.5; 1 * 2 * 4 * 4], &[1, 2, 4, 4]).unwrap(),
1336            true,
1337        );
1338        let output = conv_t.forward(&input);
1339        output.sum().backward();
1340
1341        let grad = input.grad().unwrap();
1342        assert_eq!(grad.shape(), &[1, 2, 4, 4]);
1343        assert!(grad.to_vec().iter().all(|g| g.is_finite()));
1344        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));
1345
1346        // Weight params should also have gradients
1347        for p in conv_t.parameters() {
1348            assert!(p.grad().is_some(), "ConvTranspose2d params need gradients");
1349        }
1350    }
1351
1352    #[test]
1353    fn test_conv_transpose2d_multi_channel() {
1354        let conv_t = ConvTranspose2d::new(8, 16, 3);
1355        let input = Variable::new(
1356            Tensor::from_vec(vec![1.0; 2 * 8 * 4 * 4], &[2, 8, 4, 4]).unwrap(),
1357            false,
1358        );
1359        let output = conv_t.forward(&input);
1360        assert_eq!(output.shape()[0], 2); // batch
1361        assert_eq!(output.shape()[1], 16); // out_channels
1362    }
1363}