Skip to main content

axonml_nn/layers/
conv.rs

//! Convolutional Layers - 1D and 2D Convolutions
//!
//! # File
//! `crates/axonml-nn/src/layers/conv.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::functions::{
    Conv1dBackward, Conv2dBackward, ConvTranspose2dBackward, GroupedConv2dBackward,
};
use axonml_autograd::grad_fn::GradFn;
use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use rayon::prelude::*;

use crate::init::{kaiming_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

32// =============================================================================
33// Conv1d
34// =============================================================================
35
36/// Applies a 1D convolution over an input signal.
37///
38/// # Shape
39/// - Input: (N, C_in, L)
40/// - Output: (N, C_out, L_out)
41///
42/// where L_out = (L + 2*padding - kernel_size) / stride + 1
43pub struct Conv1d {
44    /// Weight tensor of shape (out_channels, in_channels, kernel_size).
45    pub weight: Parameter,
46    /// Bias tensor of shape (out_channels).
47    pub bias: Option<Parameter>,
48    /// Number of input channels.
49    in_channels: usize,
50    /// Number of output channels.
51    out_channels: usize,
52    /// Size of the convolving kernel.
53    kernel_size: usize,
54    /// Stride of the convolution.
55    stride: usize,
56    /// Zero-padding added to both sides.
57    padding: usize,
58}
59
60impl Conv1d {
61    /// Creates a new Conv1d layer.
62    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
63        Self::with_options(in_channels, out_channels, kernel_size, 1, 0, true)
64    }
65
66    /// Creates a Conv1d layer with all options.
67    pub fn with_options(
68        in_channels: usize,
69        out_channels: usize,
70        kernel_size: usize,
71        stride: usize,
72        padding: usize,
73        bias: bool,
74    ) -> Self {
75        // Initialize weights
76        let fan_in = in_channels * kernel_size;
77        let weight_data = kaiming_uniform(out_channels, fan_in);
78        let weight_reshaped = weight_data
79            .reshape(&[
80                out_channels as isize,
81                in_channels as isize,
82                kernel_size as isize,
83            ])
84            .unwrap();
85        let weight = Parameter::named("weight", weight_reshaped, true);
86
87        let bias_param = if bias {
88            Some(Parameter::named("bias", zeros(&[out_channels]), true))
89        } else {
90            None
91        };
92
93        Self {
94            weight,
95            bias: bias_param,
96            in_channels,
97            out_channels,
98            kernel_size,
99            stride,
100            padding,
101        }
102    }
103}
104
105impl Module for Conv1d {
106    fn forward(&self, input: &Variable) -> Variable {
107        let input_shape = input.shape();
108        let batch_size = input_shape[0];
109        let in_length = input_shape[2];
110
111        let out_length = (in_length + 2 * self.padding - self.kernel_size) / self.stride + 1;
112
113        let input_data = input.data();
114        let weight_data = self.weight.data();
115
116        // GPU-resident fast path: reshape [B,C,L] → [B,C,L,1], use Conv2d CUDA pipeline,
117        // then reshape output [B,Cout,Lout,1] → [B,Cout,Lout].
118        #[cfg(feature = "cuda")]
119        if input_data.device().is_gpu() {
120            // Auto-migrate weights to GPU if needed
121            let input_dev = input_data.device();
122            if !weight_data.device().is_gpu() {
123                self.weight.to_device(input_dev);
124                if let Some(ref b) = self.bias {
125                    b.to_device(input_dev);
126                }
127            }
128            let weight_data = self.weight.data();
129
130            // Reshape input [B, Cin, L] → [B, Cin, L, 1]
131            let input_4d = input_data
132                .reshape(&[
133                    batch_size as isize,
134                    self.in_channels as isize,
135                    in_length as isize,
136                    1,
137                ])
138                .unwrap();
139
140            // Reshape weight [Cout, Cin, K] → [Cout, Cin, K, 1]
141            let weight_4d = weight_data
142                .reshape(&[
143                    self.out_channels as isize,
144                    self.in_channels as isize,
145                    self.kernel_size as isize,
146                    1,
147                ])
148                .unwrap();
149
150            let bias_tensor = self.bias.as_ref().map(|b| b.data());
151            let gpu_output = input_4d.conv2d_cuda(
152                &weight_4d,
153                bias_tensor.as_ref(),
154                (self.stride, 1),
155                (self.padding, 0),
156            );
157
158            if let Some(output_4d) = gpu_output {
159                // Reshape output [B, Cout, Lout, 1] → [B, Cout, Lout]
160                let output_tensor = output_4d
161                    .reshape(&[
162                        batch_size as isize,
163                        self.out_channels as isize,
164                        out_length as isize,
165                    ])
166                    .unwrap();
167
168                let requires_grad =
169                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
170                if requires_grad {
171                    let weight_var = self.weight.variable();
172                    let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
173
174                    let grad_fn = GradFn::new(Conv1dBackward::new(
175                        input.grad_fn().cloned(),
176                        weight_var.grad_fn().cloned(),
177                        bias_grad_fn,
178                        input_data,
179                        weight_data,
180                        input_shape,
181                        self.in_channels,
182                        self.out_channels,
183                        self.kernel_size,
184                        self.stride,
185                        self.padding,
186                        self.bias.is_some(),
187                    ));
188                    return Variable::from_operation(output_tensor, grad_fn, true);
189                } else {
190                    return Variable::new(output_tensor, false);
191                }
192            }
193            // Fall through to CPU path if GPU conv failed
194        }
195
196        let input_vec = input_data.to_vec();
197        let weight_vec = weight_data.to_vec();
198
199        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_length];
200
201        for b in 0..batch_size {
202            for oc in 0..self.out_channels {
203                for ol in 0..out_length {
204                    let mut sum = 0.0f32;
205                    let in_start = ol * self.stride;
206
207                    for ic in 0..self.in_channels {
208                        for k in 0..self.kernel_size {
209                            let in_idx = in_start + k;
210                            if in_idx < self.padding || in_idx >= in_length + self.padding {
211                                continue;
212                            }
213                            let actual_idx = in_idx - self.padding;
214
215                            let input_idx =
216                                b * self.in_channels * in_length + ic * in_length + actual_idx;
217                            let weight_idx = oc * self.in_channels * self.kernel_size
218                                + ic * self.kernel_size
219                                + k;
220
221                            sum += input_vec[input_idx] * weight_vec[weight_idx];
222                        }
223                    }
224
225                    if let Some(ref bias) = self.bias {
226                        sum += bias.data().to_vec()[oc];
227                    }
228
229                    let output_idx = b * self.out_channels * out_length + oc * out_length + ol;
230                    output_data[output_idx] = sum;
231                }
232            }
233        }
234
235        let output_tensor =
236            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_length]).unwrap();
237
238        let requires_grad =
239            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
240
241        if requires_grad {
242            let weight_var = self.weight.variable();
243            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
244
245            let grad_fn = GradFn::new(Conv1dBackward::new(
246                input.grad_fn().cloned(),
247                weight_var.grad_fn().cloned(),
248                bias_grad_fn,
249                input_data,
250                weight_data,
251                input_shape,
252                self.in_channels,
253                self.out_channels,
254                self.kernel_size,
255                self.stride,
256                self.padding,
257                self.bias.is_some(),
258            ));
259            Variable::from_operation(output_tensor, grad_fn, true)
260        } else {
261            Variable::new(output_tensor, false)
262        }
263    }
264
265    fn parameters(&self) -> Vec<Parameter> {
266        let mut params = vec![self.weight.clone()];
267        if let Some(ref bias) = self.bias {
268            params.push(bias.clone());
269        }
270        params
271    }
272
273    fn named_parameters(&self) -> HashMap<String, Parameter> {
274        let mut params = HashMap::new();
275        params.insert("weight".to_string(), self.weight.clone());
276        if let Some(ref bias) = self.bias {
277            params.insert("bias".to_string(), bias.clone());
278        }
279        params
280    }
281
282    fn name(&self) -> &'static str {
283        "Conv1d"
284    }
285}
286
287// =============================================================================
288// Conv2d
289// =============================================================================
290
291/// Applies a 2D convolution over an input image.
292///
293/// # Shape
294/// - Input: (N, C_in, H, W)
295/// - Output: (N, C_out, H_out, W_out)
296///
297/// where H_out = (H + 2*padding - kernel_size) / stride + 1
298pub struct Conv2d {
299    /// Weight tensor of shape (out_channels, in_channels, kernel_h, kernel_w).
300    pub weight: Parameter,
301    /// Bias tensor of shape (out_channels).
302    pub bias: Option<Parameter>,
303    /// Number of input channels.
304    in_channels: usize,
305    /// Number of output channels.
306    out_channels: usize,
307    /// Size of the convolving kernel (height, width).
308    kernel_size: (usize, usize),
309    /// Stride of the convolution (height, width).
310    stride: (usize, usize),
311    /// Zero-padding added to both sides (height, width).
312    padding: (usize, usize),
313    /// Number of groups for grouped convolution.
314    groups: usize,
315}
316
317impl Conv2d {
318    /// Creates a new Conv2d layer with square kernel.
319    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
320        Self::with_options(
321            in_channels,
322            out_channels,
323            (kernel_size, kernel_size),
324            (1, 1),
325            (0, 0),
326            true,
327        )
328    }
329
330    /// Creates a Conv2d layer with all options.
331    pub fn with_options(
332        in_channels: usize,
333        out_channels: usize,
334        kernel_size: (usize, usize),
335        stride: (usize, usize),
336        padding: (usize, usize),
337        bias: bool,
338    ) -> Self {
339        Self::with_groups(
340            in_channels,
341            out_channels,
342            kernel_size,
343            stride,
344            padding,
345            bias,
346            1,
347        )
348    }
349
350    /// Creates a Conv2d layer with grouped convolution support.
351    ///
352    /// When `groups == in_channels` and `out_channels == in_channels`, this is
353    /// a depthwise convolution.
354    pub fn with_groups(
355        in_channels: usize,
356        out_channels: usize,
357        kernel_size: (usize, usize),
358        stride: (usize, usize),
359        padding: (usize, usize),
360        bias: bool,
361        groups: usize,
362    ) -> Self {
363        assert!(
364            in_channels % groups == 0,
365            "in_channels must be divisible by groups"
366        );
367        assert!(
368            out_channels % groups == 0,
369            "out_channels must be divisible by groups"
370        );
371
372        let (kh, kw) = kernel_size;
373        let in_channels_per_group = in_channels / groups;
374        let fan_in = in_channels_per_group * kh * kw;
375
376        let weight_data = kaiming_uniform(out_channels, fan_in);
377        let weight_reshaped = weight_data
378            .reshape(&[
379                out_channels as isize,
380                in_channels_per_group as isize,
381                kh as isize,
382                kw as isize,
383            ])
384            .unwrap();
385        let weight = Parameter::named("weight", weight_reshaped, true);
386
387        let bias_param = if bias {
388            Some(Parameter::named("bias", zeros(&[out_channels]), true))
389        } else {
390            None
391        };
392
393        Self {
394            weight,
395            bias: bias_param,
396            in_channels,
397            out_channels,
398            kernel_size,
399            stride,
400            padding,
401            groups,
402        }
403    }
404
405    /// Creates a depthwise convolution (groups = in_channels).
406    pub fn depthwise(channels: usize, kernel_size: usize) -> Self {
407        Self::with_groups(
408            channels,
409            channels,
410            (kernel_size, kernel_size),
411            (1, 1),
412            (kernel_size / 2, kernel_size / 2),
413            true,
414            channels,
415        )
416    }
417}
418
// =============================================================================
// im2col + GEMM Conv2d Implementation
// =============================================================================

/// Unfold input patches into a column matrix (im2col).
///
/// Input: `[C_in, H, W]` (one batch element, one group's channels)
/// Output: `[C_in * kH * kW, out_H * out_W]`
///
/// Positions that fall inside the zero padding are left as 0.0.
fn im2col(
    input: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    pad_h: usize,
    pad_w: usize,
    stride_h: usize,
    stride_w: usize,
    out_h: usize,
    out_w: usize,
) -> Vec<f32> {
    let col_h = channels * kernel_h * kernel_w;
    let col_w = out_h * out_w;
    let mut col = vec![0.0f32; col_h * col_w];
    let hw = height * width;
    let kk = kernel_h * kernel_w;
    let h_signed = height as isize;
    let w_signed = width as isize;
    let pad_h_s = pad_h as isize;
    let pad_w_s = pad_w as isize;

    // Fused single-pass: iterate linearly over output col matrix
    // col_row = c * kH * kW + kh_off * kW + kw_off
    // col_col = oh * out_w + ow
    for col_row in 0..col_h {
        let c = col_row / kk;
        let k_idx = col_row % kk;
        let kh_off = k_idx / kernel_w;
        let kw_off = k_idx % kernel_w;
        let input_c = c * hw;
        let col_base = col_row * col_w;

        for oh in 0..out_h {
            let h_in = (oh * stride_h + kh_off) as isize - pad_h_s;
            if h_in < 0 || h_in >= h_signed {
                continue;
            }
            let input_row = input_c + h_in as usize * width;
            let col_row_base = col_base + oh * out_w;

            for ow in 0..out_w {
                let w_in = (ow * stride_w + kw_off) as isize - pad_w_s;
                if w_in >= 0 && w_in < w_signed {
                    // SAFETY: col_row_base + ow < col_h * col_w since
                    // col_row < col_h, oh < out_h, ow < out_w; and
                    // input_row + w_in < channels * height * width since
                    // c < channels, 0 <= h_in < height, 0 <= w_in < width
                    // (both bounds checked just above).
                    unsafe {
                        *col.get_unchecked_mut(col_row_base + ow) =
                            *input.get_unchecked(input_row + w_in as usize);
                    }
                }
            }
        }
    }

    col
}
484
485/// Conv2d forward using im2col + matmul. Supports groups.
486fn conv2d_im2col(
487    input: &[f32],
488    weight: &[f32],
489    bias: Option<&[f32]>,
490    batch_size: usize,
491    in_channels: usize,
492    in_height: usize,
493    in_width: usize,
494    out_channels: usize,
495    kh: usize,
496    kw: usize,
497    sh: usize,
498    sw: usize,
499    ph: usize,
500    pw: usize,
501    groups: usize,
502) -> Vec<f32> {
503    let out_h = (in_height + 2 * ph - kh) / sh + 1;
504    let out_w = (in_width + 2 * pw - kw) / sw + 1;
505    let in_channels_per_group = in_channels / groups;
506    let out_channels_per_group = out_channels / groups;
507    let col_h = in_channels_per_group * kh * kw;
508    let col_w = out_h * out_w;
509    let spatial = out_h * out_w;
510    let in_spatial = in_height * in_width;
511
512    // Parallel: each batch element produces its own output slice
513    let out_per_batch = out_channels * spatial;
514    let per_batch: Vec<Vec<f32>> = (0..batch_size)
515        .into_par_iter()
516        .map(|b| {
517            let mut batch_out = vec![0.0f32; out_per_batch];
518
519            for g in 0..groups {
520                let ic_start = g * in_channels_per_group;
521                let oc_start = g * out_channels_per_group;
522
523                // Extract input for this batch+group
524                let in_offset = b * in_channels * in_spatial + ic_start * in_spatial;
525                let input_slice = &input[in_offset..in_offset + in_channels_per_group * in_spatial];
526
527                // im2col
528                let col = im2col(
529                    input_slice,
530                    in_channels_per_group,
531                    in_height,
532                    in_width,
533                    kh,
534                    kw,
535                    ph,
536                    pw,
537                    sh,
538                    sw,
539                    out_h,
540                    out_w,
541                );
542
543                // Weight for this group
544                let w_offset = oc_start * in_channels_per_group * kh * kw;
545                let w_size = out_channels_per_group * col_h;
546                let weight_slice = &weight[w_offset..w_offset + w_size];
547
548                // GEMM via Tensor::matmul
549                let w_tensor =
550                    Tensor::from_vec(weight_slice.to_vec(), &[out_channels_per_group, col_h])
551                        .unwrap();
552                let col_tensor = Tensor::from_vec(col, &[col_h, col_w]).unwrap();
553                let result = w_tensor.matmul(&col_tensor).unwrap();
554                let result_vec = result.to_vec();
555
556                // Copy to output with bias
557                let out_offset = oc_start * spatial;
558                for oc_local in 0..out_channels_per_group {
559                    let oc = oc_start + oc_local;
560                    let bias_val = bias.map_or(0.0, |bv| bv[oc]);
561                    let src_start = oc_local * col_w;
562                    let dst_start = out_offset + oc_local * spatial;
563                    if bias_val == 0.0 {
564                        batch_out[dst_start..dst_start + spatial]
565                            .copy_from_slice(&result_vec[src_start..src_start + spatial]);
566                    } else {
567                        for i in 0..spatial {
568                            batch_out[dst_start + i] = result_vec[src_start + i] + bias_val;
569                        }
570                    }
571                }
572            }
573
574            batch_out
575        })
576        .collect();
577
578    // Flatten per-batch results into single output
579    let mut output = Vec::with_capacity(batch_size * out_per_batch);
580    for batch_out in per_batch {
581        output.extend_from_slice(&batch_out);
582    }
583    output
584}
585
586impl Module for Conv2d {
587    fn forward(&self, input: &Variable) -> Variable {
588        let input_shape = input.shape();
589        let batch_size = input_shape[0];
590        let in_height = input_shape[2];
591        let in_width = input_shape[3];
592
593        let (kh, kw) = self.kernel_size;
594        let (sh, sw) = self.stride;
595        let (ph, pw) = self.padding;
596
597        let out_height = (in_height + 2 * ph - kh) / sh + 1;
598        let out_width = (in_width + 2 * pw - kw) / sw + 1;
599
600        let input_data = input.data();
601        let weight_data = self.weight.data();
602
603        // GPU-resident fast path: when input is already on GPU, do everything on GPU
604        // without any CPU↔GPU copies.
605        #[cfg(feature = "cuda")]
606        if input_data.device().is_gpu() {
607            // Auto-migrate weights to GPU if needed (one-time cost, cached via Arc)
608            let input_dev = input_data.device();
609            if !weight_data.device().is_gpu() {
610                self.weight.to_device(input_dev);
611                if let Some(ref b) = self.bias {
612                    b.to_device(input_dev);
613                }
614            }
615            let weight_data = self.weight.data();
616
617            // Try cuDNN first (fastest path), fall back to im2col+GEMM
618            #[cfg(feature = "cudnn")]
619            let cudnn_output = {
620                let bias_tensor = self.bias.as_ref().map(|b| b.data());
621                input_data.conv2d_cudnn(
622                    &weight_data,
623                    bias_tensor.as_ref(),
624                    self.stride,
625                    self.padding,
626                    self.groups,
627                )
628            };
629            #[cfg(not(feature = "cudnn"))]
630            let cudnn_output: Option<axonml_tensor::Tensor<f32>> = None;
631
632            let gpu_output = if cudnn_output.is_some() {
633                cudnn_output
634            } else if self.groups == 1 {
635                // Standard convolution: single im2col + GEMM
636                let bias_tensor = self.bias.as_ref().map(|b| b.data());
637                input_data.conv2d_cuda(
638                    &weight_data,
639                    bias_tensor.as_ref(),
640                    self.stride,
641                    self.padding,
642                )
643            } else {
644                // Grouped convolution: run per-group im2col + GEMM on GPU
645                input_data.conv2d_grouped_cuda(
646                    &weight_data,
647                    self.bias.as_ref().map(|b| b.data()).as_ref(),
648                    self.stride,
649                    self.padding,
650                    self.groups,
651                )
652            };
653
654            if let Some(output_tensor) = gpu_output {
655                let requires_grad =
656                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
657                if requires_grad {
658                    let weight_var = self.weight.variable();
659                    let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
660                    if self.groups == 1 {
661                        let grad_fn = GradFn::new(Conv2dBackward::new(
662                            input.grad_fn().cloned(),
663                            weight_var.grad_fn().cloned(),
664                            bias_grad_fn,
665                            input_data,
666                            weight_data,
667                            input_shape,
668                            self.in_channels,
669                            self.out_channels,
670                            self.kernel_size,
671                            self.stride,
672                            self.padding,
673                            self.bias.is_some(),
674                        ));
675                        return Variable::from_operation(output_tensor, grad_fn, true);
676                    } else {
677                        let grad_fn = GradFn::new(GroupedConv2dBackward::new(
678                            input.grad_fn().cloned(),
679                            weight_var.grad_fn().cloned(),
680                            bias_grad_fn,
681                            input_data,
682                            weight_data,
683                            input_shape,
684                            self.in_channels,
685                            self.out_channels,
686                            self.kernel_size,
687                            self.stride,
688                            self.padding,
689                            self.groups,
690                            self.bias.is_some(),
691                        ));
692                        return Variable::from_operation(output_tensor, grad_fn, true);
693                    }
694                } else {
695                    return Variable::new(output_tensor, false);
696                }
697            }
698            // Fall through to CPU path if GPU conv failed
699        }
700
701        let input_vec = input_data.to_vec();
702        let weight_vec = weight_data.to_vec();
703
704        // Try GPU im2col+GEMM for groups=1 when data is on CPU but GPU is available
705        let conv_flops = self.out_channels * self.in_channels * kh * kw * out_height * out_width;
706        let output_data = if self.groups == 1 && conv_flops >= 500_000 {
707            let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());
708            let gpu_result = axonml_core::backends::cuda::cuda_conv2d_forward(
709                &input_vec,
710                &weight_vec,
711                bias_vec.as_deref(),
712                batch_size,
713                self.in_channels,
714                in_height,
715                in_width,
716                self.out_channels,
717                kh,
718                kw,
719                sh,
720                sw,
721                ph,
722                pw,
723            );
724
725            if let Some(result) = gpu_result {
726                result
727            } else {
728                conv2d_im2col(
729                    &input_vec,
730                    &weight_vec,
731                    self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
732                    batch_size,
733                    self.in_channels,
734                    in_height,
735                    in_width,
736                    self.out_channels,
737                    kh,
738                    kw,
739                    sh,
740                    sw,
741                    ph,
742                    pw,
743                    self.groups,
744                )
745            }
746        } else {
747            conv2d_im2col(
748                &input_vec,
749                &weight_vec,
750                self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
751                batch_size,
752                self.in_channels,
753                in_height,
754                in_width,
755                self.out_channels,
756                kh,
757                kw,
758                sh,
759                sw,
760                ph,
761                pw,
762                self.groups,
763            )
764        };
765
766        let output_tensor = Tensor::from_vec(
767            output_data,
768            &[batch_size, self.out_channels, out_height, out_width],
769        )
770        .unwrap();
771
772        let requires_grad =
773            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
774
775        if requires_grad && self.groups == 1 {
776            // Full backward pass for standard convolution
777            let weight_var = self.weight.variable();
778            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
779
780            let grad_fn = GradFn::new(Conv2dBackward::new(
781                input.grad_fn().cloned(),
782                weight_var.grad_fn().cloned(),
783                bias_grad_fn,
784                input_data,
785                weight_data,
786                input_shape,
787                self.in_channels,
788                self.out_channels,
789                self.kernel_size,
790                self.stride,
791                self.padding,
792                self.bias.is_some(),
793            ));
794            Variable::from_operation(output_tensor, grad_fn, true)
795        } else if requires_grad {
796            // Grouped convolution backward (depthwise separable, etc.)
797            let weight_var = self.weight.variable();
798            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
799
800            let grad_fn = GradFn::new(GroupedConv2dBackward::new(
801                input.grad_fn().cloned(),
802                weight_var.grad_fn().cloned(),
803                bias_grad_fn,
804                input_data,
805                weight_data,
806                input_shape,
807                self.in_channels,
808                self.out_channels,
809                self.kernel_size,
810                self.stride,
811                self.padding,
812                self.groups,
813                self.bias.is_some(),
814            ));
815            Variable::from_operation(output_tensor, grad_fn, true)
816        } else {
817            Variable::new(output_tensor, false)
818        }
819    }
820
821    fn parameters(&self) -> Vec<Parameter> {
822        let mut params = vec![self.weight.clone()];
823        if let Some(ref bias) = self.bias {
824            params.push(bias.clone());
825        }
826        params
827    }
828
829    fn named_parameters(&self) -> HashMap<String, Parameter> {
830        let mut params = HashMap::new();
831        params.insert("weight".to_string(), self.weight.clone());
832        if let Some(ref bias) = self.bias {
833            params.insert("bias".to_string(), bias.clone());
834        }
835        params
836    }
837
838    fn name(&self) -> &'static str {
839        "Conv2d"
840    }
841}
842
843// =============================================================================
844// ConvTranspose2d
845// =============================================================================
846
847/// Applies a 2D transposed convolution (deconvolution) for upsampling.
848///
849/// # Shape
850/// - Input: (N, C_in, H, W)
851/// - Output: (N, C_out, H_out, W_out)
852///
853/// where H_out = (H - 1) * stride - 2*padding + kernel_size + output_padding
854pub struct ConvTranspose2d {
855    /// Weight tensor of shape (in_channels, out_channels, kernel_h, kernel_w).
856    pub weight: Parameter,
857    /// Bias tensor of shape (out_channels).
858    pub bias: Option<Parameter>,
859    in_channels: usize,
860    out_channels: usize,
861    kernel_size: (usize, usize),
862    stride: (usize, usize),
863    padding: (usize, usize),
864    output_padding: (usize, usize),
865}
866
867impl ConvTranspose2d {
868    /// Creates a new ConvTranspose2d layer with square kernel.
869    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
870        Self::with_options(
871            in_channels,
872            out_channels,
873            (kernel_size, kernel_size),
874            (1, 1),
875            (0, 0),
876            (0, 0),
877            true,
878        )
879    }
880
881    /// Creates a ConvTranspose2d layer with all options.
882    pub fn with_options(
883        in_channels: usize,
884        out_channels: usize,
885        kernel_size: (usize, usize),
886        stride: (usize, usize),
887        padding: (usize, usize),
888        output_padding: (usize, usize),
889        bias: bool,
890    ) -> Self {
891        let (kh, kw) = kernel_size;
892        let fan_in = in_channels * kh * kw;
893
894        let weight_data = kaiming_uniform(out_channels, fan_in);
895        let weight_reshaped = weight_data
896            .reshape(&[
897                in_channels as isize,
898                out_channels as isize,
899                kh as isize,
900                kw as isize,
901            ])
902            .unwrap();
903        let weight = Parameter::named("weight", weight_reshaped, true);
904
905        let bias_param = if bias {
906            Some(Parameter::named("bias", zeros(&[out_channels]), true))
907        } else {
908            None
909        };
910
911        Self {
912            weight,
913            bias: bias_param,
914            in_channels,
915            out_channels,
916            kernel_size,
917            stride,
918            padding,
919            output_padding,
920        }
921    }
922}
923
impl Module for ConvTranspose2d {
    /// Runs the transposed convolution.
    ///
    /// Expects `input` of shape (N, C_in, H, W) and returns a Variable of
    /// shape (N, C_out, H_out, W_out), where
    /// H_out = (H - 1) * stride - 2*padding + kernel + output_padding
    /// (and analogously for W_out). Attaches a `ConvTranspose2dBackward`
    /// node when gradients are enabled.
    fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let in_h = input_shape[2];
        let in_w = input_shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;
        let (oph, opw) = self.output_padding;

        // Standard transposed-conv output size formula.
        let out_h = (in_h - 1) * sh - 2 * ph + kh + oph;
        let out_w = (in_w - 1) * sw - 2 * pw + kw + opw;

        // Materialize both tensors as flat Vecs for direct indexing below.
        let input_data = input.data();
        let weight_data = self.weight.data();
        let input_vec = input_data.to_vec();
        let weight_vec = weight_data.to_vec();

        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_h * out_w];

        // Transposed convolution: scatter input values through the kernel
        for b in 0..batch_size {
            for ic in 0..self.in_channels {
                for ih in 0..in_h {
                    for iw in 0..in_w {
                        // Flat index into (N, C_in, H, W).
                        let in_idx =
                            b * self.in_channels * in_h * in_w + ic * in_h * in_w + ih * in_w + iw;
                        let in_val = input_vec[in_idx];

                        for oc in 0..self.out_channels {
                            for ki in 0..kh {
                                for kj in 0..kw {
                                    // Target output position; signed because
                                    // padding can shift it below zero.
                                    let oh_signed = (ih * sh + ki) as isize - ph as isize;
                                    let ow_signed = (iw * sw + kj) as isize - pw as isize;

                                    // Accumulate only in-bounds positions;
                                    // out-of-range scatters are dropped
                                    // (that's what padding removes).
                                    if oh_signed >= 0
                                        && (oh_signed as usize) < out_h
                                        && ow_signed >= 0
                                        && (ow_signed as usize) < out_w
                                    {
                                        let oh = oh_signed as usize;
                                        let ow = ow_signed as usize;
                                        let out_idx = b * self.out_channels * out_h * out_w
                                            + oc * out_h * out_w
                                            + oh * out_w
                                            + ow;
                                        // weight: (in_channels, out_channels, kh, kw)
                                        let w_idx = ic * self.out_channels * kh * kw
                                            + oc * kh * kw
                                            + ki * kw
                                            + kj;
                                        output_data[out_idx] += in_val * weight_vec[w_idx];
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Add bias
        if let Some(ref bias) = self.bias {
            let bias_vec = bias.data().to_vec();
            for b in 0..batch_size {
                for oc in 0..self.out_channels {
                    for oh in 0..out_h {
                        for ow in 0..out_w {
                            let out_idx = b * self.out_channels * out_h * out_w
                                + oc * out_h * out_w
                                + oh * out_w
                                + ow;
                            // One bias value per output channel, broadcast
                            // over batch and spatial dims.
                            output_data[out_idx] += bias_vec[oc];
                        }
                    }
                }
            }
        }

        let output_tensor =
            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_h, out_w]).unwrap();

        // NOTE(review): bias.requires_grad() is not consulted here, so a
        // bias-only-trainable configuration would skip the grad fn — confirm
        // whether that case is intended to be supported.
        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        if requires_grad {
            let weight_var = self.weight.variable();
            // Option<Option<GradFn>>: outer None = no bias, inner = bias's grad fn.
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            let grad_fn = GradFn::new(ConvTranspose2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.output_padding,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }

    /// Returns the trainable parameters: weight, plus bias when present.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    /// Returns the parameters keyed by name ("weight", "bias").
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    /// Layer type name reported in module summaries / debugging output.
    fn name(&self) -> &'static str {
        "ConvTranspose2d"
    }
}
1057
1058// =============================================================================
1059// Tests
1060// =============================================================================
1061
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conv1d_creation() {
        // Constructor should record the channel/kernel configuration verbatim.
        let layer = Conv1d::new(3, 16, 3);
        assert_eq!(layer.in_channels, 3);
        assert_eq!(layer.out_channels, 16);
        assert_eq!(layer.kernel_size, 3);
    }

    #[test]
    fn test_conv1d_forward() {
        // kernel 3, stride 1, padding 1 => "same" length output.
        let layer = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let signal = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]).unwrap();
        let out = layer.forward(&Variable::new(signal, false));
        assert_eq!(out.shape(), vec![1, 1, 5]);
    }

    #[test]
    fn test_conv1d_backward() {
        let layer = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let signal = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]).unwrap();
        let x = Variable::new(signal, true);

        layer.forward(&x).sum().backward();

        // Input should have gradient (not None)
        assert!(
            x.grad().is_some(),
            "Conv1d: input gradient should flow through backward pass"
        );
        assert_eq!(x.grad().unwrap().shape(), &[1, 1, 5]);
    }

    #[test]
    fn test_conv2d_creation() {
        let layer = Conv2d::new(3, 64, 3);
        assert_eq!(layer.in_channels, 3);
        assert_eq!(layer.out_channels, 64);
        assert_eq!(layer.kernel_size, (3, 3));
    }

    #[test]
    fn test_conv2d_forward() {
        // 3x3 kernel with padding 1 preserves the 5x5 spatial size.
        let layer = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let image = Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).unwrap();
        let out = layer.forward(&Variable::new(image, false));
        assert_eq!(out.shape(), vec![1, 1, 5, 5]);
    }

    #[test]
    fn test_conv2d_backward() {
        let layer = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let image = Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).unwrap();
        let x = Variable::new(image, true);

        layer.forward(&x).sum().backward();

        assert!(
            x.grad().is_some(),
            "Conv2d: input gradient should flow through backward pass"
        );
        assert_eq!(x.grad().unwrap().shape(), &[1, 1, 5, 5]);

        // Weight should also have gradient
        assert!(
            layer.weight.grad().is_some(),
            "Conv2d: weight gradient should be computed"
        );
    }

    #[test]
    fn test_conv2d_parameters() {
        let layer = Conv2d::new(3, 64, 3);
        assert_eq!(layer.parameters().len(), 2); // weight + bias
    }

    #[test]
    fn test_conv2d_grouped() {
        // Depthwise: groups = in_channels = out_channels
        let layer = Conv2d::depthwise(4, 3);
        assert_eq!(layer.groups, 4);
        assert_eq!(layer.in_channels, 4);
        assert_eq!(layer.out_channels, 4);

        let image = Tensor::from_vec(vec![1.0; 4 * 5 * 5], &[1, 4, 5, 5]).unwrap();
        let out = layer.forward(&Variable::new(image, false));
        assert_eq!(out.shape(), vec![1, 4, 5, 5]);
    }

    #[test]
    fn test_conv_transpose2d_forward() {
        let layer = ConvTranspose2d::with_options(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false);
        let image = Tensor::from_vec(vec![1.0; 4], &[1, 1, 2, 2]).unwrap();
        let out = layer.forward(&Variable::new(image, false));
        // H_out = (2-1)*2 - 2*1 + 3 + 1 = 4
        assert_eq!(out.shape(), vec![1, 1, 4, 4]);
    }

    #[test]
    fn test_conv_transpose2d_backward() {
        let layer = ConvTranspose2d::new(1, 1, 3);
        let x = Variable::new(Tensor::from_vec(vec![1.0; 9], &[1, 1, 3, 3]).unwrap(), true);

        layer.forward(&x).sum().backward();

        assert!(
            x.grad().is_some(),
            "ConvTranspose2d: input gradient should flow through backward"
        );
    }
}