Skip to main content

axonml_nn/layers/
conv.rs

1//! Convolutional layers — `Conv1d`, `Conv2d`, `ConvTranspose2d`.
2//!
3//! 1360 lines. `Conv1d` (1D temporal), `Conv2d` (2D spatial with im2col
4//! forward + optional cuDNN dispatch via `cudnn_conv2d_forward`),
5//! `ConvTranspose2d` (learnable upsampling). All support padding, stride,
6//! dilation, groups (including depthwise separable), and optional bias.
7//! Kaiming-uniform weight init. Full `Module` trait impl with forward,
8//! parameters, train/eval, to_device.
9//!
10//! # File
11//! `crates/axonml-nn/src/layers/conv.rs`
12//!
13//! # Author
14//! Andrew Jewell Sr. — AutomataNexus LLC
15//! ORCID: 0009-0005-2158-7060
16//!
17//! # Updated
18//! April 14, 2026 11:15 PM EST
19//!
20//! # Disclaimer
21//! Use at own risk. This software is provided "as is", without warranty of any
22//! kind, express or implied. The author and AutomataNexus shall not be held
23//! liable for any damages arising from the use of this software.
24
25use std::collections::HashMap;
26
27use axonml_autograd::Variable;
28use axonml_autograd::functions::{
29    Conv1dBackward, Conv2dBackward, ConvTranspose2dBackward, GroupedConv2dBackward,
30};
31use axonml_autograd::grad_fn::GradFn;
32use axonml_autograd::no_grad::is_grad_enabled;
33use axonml_tensor::Tensor;
34use rayon::prelude::*;
35
36use crate::init::{kaiming_uniform, zeros};
37use crate::module::Module;
38use crate::parameter::Parameter;
39
40// =============================================================================
41// Conv1d
42// =============================================================================
43
/// Applies a 1D convolution over an input signal.
///
/// # Shape
/// - Input: (N, C_in, L)
/// - Output: (N, C_out, L_out)
///
/// where L_out = (L + 2*padding - kernel_size) / stride + 1
pub struct Conv1d {
    /// Weight tensor of shape (out_channels, in_channels, kernel_size).
    pub weight: Parameter,
    /// Bias tensor of shape (out_channels); `None` when built with `bias = false`.
    pub bias: Option<Parameter>,
    /// Number of input channels (C_in).
    in_channels: usize,
    /// Number of output channels (C_out).
    out_channels: usize,
    /// Size of the convolving kernel.
    kernel_size: usize,
    /// Stride of the convolution.
    stride: usize,
    /// Zero-padding added to both sides of the input length.
    padding: usize,
}
67
68impl Conv1d {
69    /// Creates a new Conv1d layer.
70    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
71        Self::with_options(in_channels, out_channels, kernel_size, 1, 0, true)
72    }
73
74    /// Creates a Conv1d layer with all options.
75    pub fn with_options(
76        in_channels: usize,
77        out_channels: usize,
78        kernel_size: usize,
79        stride: usize,
80        padding: usize,
81        bias: bool,
82    ) -> Self {
83        // Initialize weights
84        let fan_in = in_channels * kernel_size;
85        let weight_data = kaiming_uniform(out_channels, fan_in);
86        let weight_reshaped = weight_data
87            .reshape(&[
88                out_channels as isize,
89                in_channels as isize,
90                kernel_size as isize,
91            ])
92            .unwrap();
93        let weight = Parameter::named("weight", weight_reshaped, true);
94
95        let bias_param = if bias {
96            Some(Parameter::named("bias", zeros(&[out_channels]), true))
97        } else {
98            None
99        };
100
101        Self {
102            weight,
103            bias: bias_param,
104            in_channels,
105            out_channels,
106            kernel_size,
107            stride,
108            padding,
109        }
110    }
111}
112
113impl Module for Conv1d {
114    fn forward(&self, input: &Variable) -> Variable {
115        let input_shape = input.shape();
116        let batch_size = input_shape[0];
117        let in_length = input_shape[2];
118
119        let out_length = (in_length + 2 * self.padding - self.kernel_size) / self.stride + 1;
120
121        let input_data = input.data();
122        let weight_data = self.weight.data();
123
124        // GPU-resident fast path: reshape [B,C,L] → [B,C,L,1], use Conv2d CUDA pipeline,
125        // then reshape output [B,Cout,Lout,1] → [B,Cout,Lout].
126        #[cfg(feature = "cuda")]
127        if input_data.device().is_gpu() {
128            // Auto-migrate weights to GPU if needed
129            let input_dev = input_data.device();
130            if !weight_data.device().is_gpu() {
131                self.weight.to_device(input_dev);
132                if let Some(ref b) = self.bias {
133                    b.to_device(input_dev);
134                }
135            }
136            let weight_data = self.weight.data();
137
138            // Reshape input [B, Cin, L] → [B, Cin, L, 1]
139            let input_4d = input_data
140                .reshape(&[
141                    batch_size as isize,
142                    self.in_channels as isize,
143                    in_length as isize,
144                    1,
145                ])
146                .unwrap();
147
148            // Reshape weight [Cout, Cin, K] → [Cout, Cin, K, 1]
149            let weight_4d = weight_data
150                .reshape(&[
151                    self.out_channels as isize,
152                    self.in_channels as isize,
153                    self.kernel_size as isize,
154                    1,
155                ])
156                .unwrap();
157
158            let bias_tensor = self.bias.as_ref().map(|b| b.data());
159            let gpu_output = input_4d.conv2d_cuda(
160                &weight_4d,
161                bias_tensor.as_ref(),
162                (self.stride, 1),
163                (self.padding, 0),
164            );
165
166            if let Some(output_4d) = gpu_output {
167                // Reshape output [B, Cout, Lout, 1] → [B, Cout, Lout]
168                let output_tensor = output_4d
169                    .reshape(&[
170                        batch_size as isize,
171                        self.out_channels as isize,
172                        out_length as isize,
173                    ])
174                    .unwrap();
175
176                let requires_grad =
177                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
178                if requires_grad {
179                    let weight_var = self.weight.variable();
180                    let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
181
182                    let grad_fn = GradFn::new(Conv1dBackward::new(
183                        input.grad_fn().cloned(),
184                        weight_var.grad_fn().cloned(),
185                        bias_grad_fn,
186                        input_data,
187                        weight_data,
188                        input_shape,
189                        self.in_channels,
190                        self.out_channels,
191                        self.kernel_size,
192                        self.stride,
193                        self.padding,
194                        self.bias.is_some(),
195                    ));
196                    return Variable::from_operation(output_tensor, grad_fn, true);
197                } else {
198                    return Variable::new(output_tensor, false);
199                }
200            }
201            // Fall through to CPU path if GPU conv failed
202        }
203
204        let input_vec = input_data.to_vec();
205        let weight_vec = weight_data.to_vec();
206
207        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_length];
208
209        for b in 0..batch_size {
210            for oc in 0..self.out_channels {
211                for ol in 0..out_length {
212                    let mut sum = 0.0f32;
213                    let in_start = ol * self.stride;
214
215                    for ic in 0..self.in_channels {
216                        for k in 0..self.kernel_size {
217                            let in_idx = in_start + k;
218                            if in_idx < self.padding || in_idx >= in_length + self.padding {
219                                continue;
220                            }
221                            let actual_idx = in_idx - self.padding;
222
223                            let input_idx =
224                                b * self.in_channels * in_length + ic * in_length + actual_idx;
225                            let weight_idx = oc * self.in_channels * self.kernel_size
226                                + ic * self.kernel_size
227                                + k;
228
229                            sum += input_vec[input_idx] * weight_vec[weight_idx];
230                        }
231                    }
232
233                    if let Some(ref bias) = self.bias {
234                        sum += bias.data().to_vec()[oc];
235                    }
236
237                    let output_idx = b * self.out_channels * out_length + oc * out_length + ol;
238                    output_data[output_idx] = sum;
239                }
240            }
241        }
242
243        let output_tensor =
244            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_length])
245                .expect("tensor creation failed");
246
247        let requires_grad =
248            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
249
250        if requires_grad {
251            let weight_var = self.weight.variable();
252            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
253
254            let grad_fn = GradFn::new(Conv1dBackward::new(
255                input.grad_fn().cloned(),
256                weight_var.grad_fn().cloned(),
257                bias_grad_fn,
258                input_data,
259                weight_data,
260                input_shape,
261                self.in_channels,
262                self.out_channels,
263                self.kernel_size,
264                self.stride,
265                self.padding,
266                self.bias.is_some(),
267            ));
268            Variable::from_operation(output_tensor, grad_fn, true)
269        } else {
270            Variable::new(output_tensor, false)
271        }
272    }
273
274    fn parameters(&self) -> Vec<Parameter> {
275        let mut params = vec![self.weight.clone()];
276        if let Some(ref bias) = self.bias {
277            params.push(bias.clone());
278        }
279        params
280    }
281
282    fn named_parameters(&self) -> HashMap<String, Parameter> {
283        let mut params = HashMap::new();
284        params.insert("weight".to_string(), self.weight.clone());
285        if let Some(ref bias) = self.bias {
286            params.insert("bias".to_string(), bias.clone());
287        }
288        params
289    }
290
291    fn name(&self) -> &'static str {
292        "Conv1d"
293    }
294}
295
296// =============================================================================
297// Conv2d
298// =============================================================================
299
/// Applies a 2D convolution over an input image.
///
/// # Shape
/// - Input: (N, C_in, H, W)
/// - Output: (N, C_out, H_out, W_out)
///
/// where H_out = (H + 2*padding - kernel_size) / stride + 1
pub struct Conv2d {
    /// Weight tensor of shape (out_channels, in_channels / groups, kernel_h, kernel_w).
    pub weight: Parameter,
    /// Bias tensor of shape (out_channels); `None` when built with `bias = false`.
    pub bias: Option<Parameter>,
    /// Number of input channels (C_in).
    in_channels: usize,
    /// Number of output channels (C_out).
    out_channels: usize,
    /// Size of the convolving kernel (height, width).
    kernel_size: (usize, usize),
    /// Stride of the convolution (height, width).
    stride: (usize, usize),
    /// Zero-padding added to both sides (height, width).
    padding: (usize, usize),
    /// Number of groups for grouped convolution (1 = standard conv).
    groups: usize,
}
325
326impl Conv2d {
327    /// Creates a new Conv2d layer with square kernel.
328    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
329        Self::with_options(
330            in_channels,
331            out_channels,
332            (kernel_size, kernel_size),
333            (1, 1),
334            (0, 0),
335            true,
336        )
337    }
338
339    /// Creates a Conv2d layer with all options.
340    pub fn with_options(
341        in_channels: usize,
342        out_channels: usize,
343        kernel_size: (usize, usize),
344        stride: (usize, usize),
345        padding: (usize, usize),
346        bias: bool,
347    ) -> Self {
348        Self::with_groups(
349            in_channels,
350            out_channels,
351            kernel_size,
352            stride,
353            padding,
354            bias,
355            1,
356        )
357    }
358
359    /// Creates a Conv2d layer with grouped convolution support.
360    ///
361    /// When `groups == in_channels` and `out_channels == in_channels`, this is
362    /// a depthwise convolution.
363    pub fn with_groups(
364        in_channels: usize,
365        out_channels: usize,
366        kernel_size: (usize, usize),
367        stride: (usize, usize),
368        padding: (usize, usize),
369        bias: bool,
370        groups: usize,
371    ) -> Self {
372        assert!(
373            in_channels % groups == 0,
374            "in_channels must be divisible by groups"
375        );
376        assert!(
377            out_channels % groups == 0,
378            "out_channels must be divisible by groups"
379        );
380
381        let (kh, kw) = kernel_size;
382        let in_channels_per_group = in_channels / groups;
383        let fan_in = in_channels_per_group * kh * kw;
384
385        let weight_data = kaiming_uniform(out_channels, fan_in);
386        let weight_reshaped = weight_data
387            .reshape(&[
388                out_channels as isize,
389                in_channels_per_group as isize,
390                kh as isize,
391                kw as isize,
392            ])
393            .unwrap();
394        let weight = Parameter::named("weight", weight_reshaped, true);
395
396        let bias_param = if bias {
397            Some(Parameter::named("bias", zeros(&[out_channels]), true))
398        } else {
399            None
400        };
401
402        Self {
403            weight,
404            bias: bias_param,
405            in_channels,
406            out_channels,
407            kernel_size,
408            stride,
409            padding,
410            groups,
411        }
412    }
413
414    /// Creates a depthwise convolution (groups = in_channels).
415    pub fn depthwise(channels: usize, kernel_size: usize) -> Self {
416        Self::with_groups(
417            channels,
418            channels,
419            (kernel_size, kernel_size),
420            (1, 1),
421            (kernel_size / 2, kernel_size / 2),
422            true,
423            channels,
424        )
425    }
426}
427
428// =============================================================================
429// im2col + GEMM Conv2d Implementation
430// =============================================================================
431
/// Unfold input patches into a column matrix (im2col).
///
/// Input: `[C_in, H, W]` (one batch element, one group's channels)
/// Output: `[C_in * kH * kW, out_H * out_W]`
///
/// Positions that fall inside the zero padding are left as 0.0.
fn im2col(
    input: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    pad_h: usize,
    pad_w: usize,
    stride_h: usize,
    stride_w: usize,
    out_h: usize,
    out_w: usize,
) -> Vec<f32> {
    let col_w = out_h * out_w;
    let mut col = vec![0.0f32; channels * kernel_h * kernel_w * col_w];
    let plane = height * width;
    let h_max = height as isize;
    let w_max = width as isize;

    // One column-matrix row per (channel, kernel offset) triple; columns
    // enumerate output spatial positions in row-major (oh, ow) order.
    let mut col_row = 0usize;
    for c in 0..channels {
        let chan_base = c * plane;
        for kh_off in 0..kernel_h {
            for kw_off in 0..kernel_w {
                let row_base = col_row * col_w;

                for oh in 0..out_h {
                    // Input row sampled by this output row; skip padding rows.
                    let h_in = (oh * stride_h + kh_off) as isize - pad_h as isize;
                    if h_in < 0 || h_in >= h_max {
                        continue;
                    }
                    let src_row = chan_base + h_in as usize * width;
                    let dst_row = row_base + oh * out_w;

                    for ow in 0..out_w {
                        // Input column; padding columns keep the 0.0 default.
                        let w_in = (ow * stride_w + kw_off) as isize - pad_w as isize;
                        if w_in >= 0 && w_in < w_max {
                            col[dst_row + ow] = input[src_row + w_in as usize];
                        }
                    }
                }

                col_row += 1;
            }
        }
    }

    col
}
504
/// Conv2d forward using im2col + matmul. Supports groups.
///
/// Layouts: `input` is `[B, C_in, H, W]`, `weight` is
/// `[C_out, C_in/groups, kH, kW]`, `bias` is `[C_out]`; the returned Vec is
/// the flattened output `[B, C_out, out_H, out_W]`. Batch elements are
/// processed in parallel via rayon, each producing its own output slice,
/// which keeps the parallel map free of shared mutable state.
fn conv2d_im2col(
    input: &[f32],
    weight: &[f32],
    bias: Option<&[f32]>,
    batch_size: usize,
    in_channels: usize,
    in_height: usize,
    in_width: usize,
    out_channels: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    groups: usize,
) -> Vec<f32> {
    let out_h = (in_height + 2 * ph - kh) / sh + 1;
    let out_w = (in_width + 2 * pw - kw) / sw + 1;
    let in_channels_per_group = in_channels / groups;
    let out_channels_per_group = out_channels / groups;
    // col matrix dims for one group: [C_in/G * kH * kW, out_H * out_W]
    let col_h = in_channels_per_group * kh * kw;
    let col_w = out_h * out_w;
    let spatial = out_h * out_w;
    let in_spatial = in_height * in_width;

    // Parallel: each batch element produces its own output slice
    let out_per_batch = out_channels * spatial;
    let per_batch: Vec<Vec<f32>> = (0..batch_size)
        .into_par_iter()
        .map(|b| {
            let mut batch_out = vec![0.0f32; out_per_batch];

            for g in 0..groups {
                let ic_start = g * in_channels_per_group;
                let oc_start = g * out_channels_per_group;

                // Extract input for this batch+group
                let in_offset = b * in_channels * in_spatial + ic_start * in_spatial;
                let input_slice = &input[in_offset..in_offset + in_channels_per_group * in_spatial];

                // im2col
                let col = im2col(
                    input_slice,
                    in_channels_per_group,
                    in_height,
                    in_width,
                    kh,
                    kw,
                    ph,
                    pw,
                    sh,
                    sw,
                    out_h,
                    out_w,
                );

                // Weight for this group
                let w_offset = oc_start * in_channels_per_group * kh * kw;
                let w_size = out_channels_per_group * col_h;
                let weight_slice = &weight[w_offset..w_offset + w_size];

                // GEMM via Tensor::matmul
                // NOTE(review): `weight_slice.to_vec()` rebuilds this tensor per
                // batch element even though it is batch-invariant; hoisting it
                // above the par_iter would need `Tensor: Sync` — confirm before
                // changing.
                let w_tensor =
                    Tensor::from_vec(weight_slice.to_vec(), &[out_channels_per_group, col_h])
                        .unwrap();
                let col_tensor =
                    Tensor::from_vec(col, &[col_h, col_w]).expect("tensor creation failed");
                let result = w_tensor.matmul(&col_tensor).expect("matmul failed");
                let result_vec = result.to_vec();

                // Copy to output with bias
                let out_offset = oc_start * spatial;
                for oc_local in 0..out_channels_per_group {
                    let oc = oc_start + oc_local;
                    let bias_val = bias.map_or(0.0, |bv| bv[oc]);
                    let src_start = oc_local * col_w;
                    let dst_start = out_offset + oc_local * spatial;
                    // Exact-zero bias (the common no-bias case) takes the
                    // memcpy fast path; otherwise add bias element-wise.
                    if bias_val == 0.0 {
                        batch_out[dst_start..dst_start + spatial]
                            .copy_from_slice(&result_vec[src_start..src_start + spatial]);
                    } else {
                        for i in 0..spatial {
                            batch_out[dst_start + i] = result_vec[src_start + i] + bias_val;
                        }
                    }
                }
            }

            batch_out
        })
        .collect();

    // Flatten per-batch results into single output
    let mut output = Vec::with_capacity(batch_size * out_per_batch);
    for batch_out in per_batch {
        output.extend_from_slice(&batch_out);
    }
    output
}
606
607impl Module for Conv2d {
608    fn forward(&self, input: &Variable) -> Variable {
609        let input_shape = input.shape();
610        let batch_size = input_shape[0];
611        let in_height = input_shape[2];
612        let in_width = input_shape[3];
613
614        let (kh, kw) = self.kernel_size;
615        let (sh, sw) = self.stride;
616        let (ph, pw) = self.padding;
617
618        let out_height = (in_height + 2 * ph - kh) / sh + 1;
619        let out_width = (in_width + 2 * pw - kw) / sw + 1;
620
621        let input_data = input.data();
622        let weight_data = self.weight.data();
623
624        // GPU-resident fast path: when input is already on GPU, do everything on GPU
625        // without any CPU↔GPU copies.
626        #[cfg(feature = "cuda")]
627        if input_data.device().is_gpu() {
628            // Auto-migrate weights to GPU if needed (one-time cost, cached via Arc)
629            let input_dev = input_data.device();
630            if !weight_data.device().is_gpu() {
631                self.weight.to_device(input_dev);
632                if let Some(ref b) = self.bias {
633                    b.to_device(input_dev);
634                }
635            }
636            let weight_data = self.weight.data();
637
638            // Try cuDNN first (fastest path), fall back to im2col+GEMM
639            #[cfg(feature = "cudnn")]
640            let cudnn_output = {
641                let bias_tensor = self.bias.as_ref().map(|b| b.data());
642                input_data.conv2d_cudnn(
643                    &weight_data,
644                    bias_tensor.as_ref(),
645                    self.stride,
646                    self.padding,
647                    self.groups,
648                )
649            };
650            #[cfg(not(feature = "cudnn"))]
651            let cudnn_output: Option<axonml_tensor::Tensor<f32>> = None;
652
653            let gpu_output = if cudnn_output.is_some() {
654                cudnn_output
655            } else if self.groups == 1 {
656                // Standard convolution: single im2col + GEMM
657                let bias_tensor = self.bias.as_ref().map(|b| b.data());
658                input_data.conv2d_cuda(
659                    &weight_data,
660                    bias_tensor.as_ref(),
661                    self.stride,
662                    self.padding,
663                )
664            } else {
665                // Grouped convolution: run per-group im2col + GEMM on GPU
666                input_data.conv2d_grouped_cuda(
667                    &weight_data,
668                    self.bias.as_ref().map(|b| b.data()).as_ref(),
669                    self.stride,
670                    self.padding,
671                    self.groups,
672                )
673            };
674
675            if let Some(output_tensor) = gpu_output {
676                let requires_grad =
677                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
678                if requires_grad {
679                    let weight_var = self.weight.variable();
680                    let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
681                    if self.groups == 1 {
682                        let grad_fn = GradFn::new(Conv2dBackward::new(
683                            input.grad_fn().cloned(),
684                            weight_var.grad_fn().cloned(),
685                            bias_grad_fn,
686                            input_data,
687                            weight_data,
688                            input_shape,
689                            self.in_channels,
690                            self.out_channels,
691                            self.kernel_size,
692                            self.stride,
693                            self.padding,
694                            self.bias.is_some(),
695                        ));
696                        return Variable::from_operation(output_tensor, grad_fn, true);
697                    } else {
698                        let grad_fn = GradFn::new(GroupedConv2dBackward::new(
699                            input.grad_fn().cloned(),
700                            weight_var.grad_fn().cloned(),
701                            bias_grad_fn,
702                            input_data,
703                            weight_data,
704                            input_shape,
705                            self.in_channels,
706                            self.out_channels,
707                            self.kernel_size,
708                            self.stride,
709                            self.padding,
710                            self.groups,
711                            self.bias.is_some(),
712                        ));
713                        return Variable::from_operation(output_tensor, grad_fn, true);
714                    }
715                } else {
716                    return Variable::new(output_tensor, false);
717                }
718            }
719            // Fall through to CPU path if GPU conv failed
720        }
721
722        let input_vec = input_data.to_vec();
723        let weight_vec = weight_data.to_vec();
724
725        // Try GPU im2col+GEMM for groups=1 when data is on CPU but GPU is available
726        let conv_flops = self.out_channels * self.in_channels * kh * kw * out_height * out_width;
727        let output_data = if self.groups == 1 && conv_flops >= 500_000 {
728            let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());
729            let gpu_result = axonml_core::backends::cuda::cuda_conv2d_forward(
730                &input_vec,
731                &weight_vec,
732                bias_vec.as_deref(),
733                batch_size,
734                self.in_channels,
735                in_height,
736                in_width,
737                self.out_channels,
738                kh,
739                kw,
740                sh,
741                sw,
742                ph,
743                pw,
744            );
745
746            if let Some(result) = gpu_result {
747                result
748            } else {
749                conv2d_im2col(
750                    &input_vec,
751                    &weight_vec,
752                    self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
753                    batch_size,
754                    self.in_channels,
755                    in_height,
756                    in_width,
757                    self.out_channels,
758                    kh,
759                    kw,
760                    sh,
761                    sw,
762                    ph,
763                    pw,
764                    self.groups,
765                )
766            }
767        } else {
768            conv2d_im2col(
769                &input_vec,
770                &weight_vec,
771                self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
772                batch_size,
773                self.in_channels,
774                in_height,
775                in_width,
776                self.out_channels,
777                kh,
778                kw,
779                sh,
780                sw,
781                ph,
782                pw,
783                self.groups,
784            )
785        };
786
787        let output_tensor = Tensor::from_vec(
788            output_data,
789            &[batch_size, self.out_channels, out_height, out_width],
790        )
791        .unwrap();
792
793        let requires_grad =
794            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
795
796        if requires_grad && self.groups == 1 {
797            // Full backward pass for standard convolution
798            let weight_var = self.weight.variable();
799            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
800
801            let grad_fn = GradFn::new(Conv2dBackward::new(
802                input.grad_fn().cloned(),
803                weight_var.grad_fn().cloned(),
804                bias_grad_fn,
805                input_data,
806                weight_data,
807                input_shape,
808                self.in_channels,
809                self.out_channels,
810                self.kernel_size,
811                self.stride,
812                self.padding,
813                self.bias.is_some(),
814            ));
815            Variable::from_operation(output_tensor, grad_fn, true)
816        } else if requires_grad {
817            // Grouped convolution backward (depthwise separable, etc.)
818            let weight_var = self.weight.variable();
819            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
820
821            let grad_fn = GradFn::new(GroupedConv2dBackward::new(
822                input.grad_fn().cloned(),
823                weight_var.grad_fn().cloned(),
824                bias_grad_fn,
825                input_data,
826                weight_data,
827                input_shape,
828                self.in_channels,
829                self.out_channels,
830                self.kernel_size,
831                self.stride,
832                self.padding,
833                self.groups,
834                self.bias.is_some(),
835            ));
836            Variable::from_operation(output_tensor, grad_fn, true)
837        } else {
838            Variable::new(output_tensor, false)
839        }
840    }
841
842    fn parameters(&self) -> Vec<Parameter> {
843        let mut params = vec![self.weight.clone()];
844        if let Some(ref bias) = self.bias {
845            params.push(bias.clone());
846        }
847        params
848    }
849
850    fn named_parameters(&self) -> HashMap<String, Parameter> {
851        let mut params = HashMap::new();
852        params.insert("weight".to_string(), self.weight.clone());
853        if let Some(ref bias) = self.bias {
854            params.insert("bias".to_string(), bias.clone());
855        }
856        params
857    }
858
    /// Module type name, used for debugging and display.
    fn name(&self) -> &'static str {
        "Conv2d"
    }
862}
863
864// =============================================================================
865// ConvTranspose2d
866// =============================================================================
867
/// Applies a 2D transposed convolution (deconvolution) for upsampling.
///
/// # Shape
/// - Input: (N, C_in, H, W)
/// - Output: (N, C_out, H_out, W_out)
///
/// where H_out = (H - 1) * stride - 2*padding + kernel_size + output_padding
pub struct ConvTranspose2d {
    /// Weight tensor of shape (in_channels, out_channels, kernel_h, kernel_w).
    pub weight: Parameter,
    /// Bias tensor of shape (out_channels).
    pub bias: Option<Parameter>,
    // Number of input channels (C_in).
    in_channels: usize,
    // Number of output channels (C_out).
    out_channels: usize,
    // Kernel size as (kernel_h, kernel_w).
    kernel_size: (usize, usize),
    // Stride as (stride_h, stride_w).
    stride: (usize, usize),
    // Zero-padding as (pad_h, pad_w); crops the scattered output edges.
    padding: (usize, usize),
    // Extra size added to one side of each output dim as (out_pad_h, out_pad_w).
    output_padding: (usize, usize),
}
887
888impl ConvTranspose2d {
889    /// Creates a new ConvTranspose2d layer with square kernel.
890    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
891        Self::with_options(
892            in_channels,
893            out_channels,
894            (kernel_size, kernel_size),
895            (1, 1),
896            (0, 0),
897            (0, 0),
898            true,
899        )
900    }
901
902    /// Creates a ConvTranspose2d layer with all options.
903    pub fn with_options(
904        in_channels: usize,
905        out_channels: usize,
906        kernel_size: (usize, usize),
907        stride: (usize, usize),
908        padding: (usize, usize),
909        output_padding: (usize, usize),
910        bias: bool,
911    ) -> Self {
912        let (kh, kw) = kernel_size;
913        let fan_in = in_channels * kh * kw;
914
915        let weight_data = kaiming_uniform(out_channels, fan_in);
916        let weight_reshaped = weight_data
917            .reshape(&[
918                in_channels as isize,
919                out_channels as isize,
920                kh as isize,
921                kw as isize,
922            ])
923            .unwrap();
924        let weight = Parameter::named("weight", weight_reshaped, true);
925
926        let bias_param = if bias {
927            Some(Parameter::named("bias", zeros(&[out_channels]), true))
928        } else {
929            None
930        };
931
932        Self {
933            weight,
934            bias: bias_param,
935            in_channels,
936            out_channels,
937            kernel_size,
938            stride,
939            padding,
940            output_padding,
941        }
942    }
943}
944
945impl Module for ConvTranspose2d {
946    fn forward(&self, input: &Variable) -> Variable {
947        let input_shape = input.shape();
948        let batch_size = input_shape[0];
949        let in_h = input_shape[2];
950        let in_w = input_shape[3];
951
952        let (kh, kw) = self.kernel_size;
953        let (sh, sw) = self.stride;
954        let (ph, pw) = self.padding;
955        let (oph, opw) = self.output_padding;
956
957        let out_h = (in_h - 1) * sh - 2 * ph + kh + oph;
958        let out_w = (in_w - 1) * sw - 2 * pw + kw + opw;
959
960        let input_data = input.data();
961        let weight_data = self.weight.data();
962        let input_vec = input_data.to_vec();
963        let weight_vec = weight_data.to_vec();
964
965        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_h * out_w];
966
967        // Transposed convolution: scatter input values through the kernel
968        for b in 0..batch_size {
969            for ic in 0..self.in_channels {
970                for ih in 0..in_h {
971                    for iw in 0..in_w {
972                        let in_idx =
973                            b * self.in_channels * in_h * in_w + ic * in_h * in_w + ih * in_w + iw;
974                        let in_val = input_vec[in_idx];
975
976                        for oc in 0..self.out_channels {
977                            for ki in 0..kh {
978                                for kj in 0..kw {
979                                    let oh_signed = (ih * sh + ki) as isize - ph as isize;
980                                    let ow_signed = (iw * sw + kj) as isize - pw as isize;
981
982                                    if oh_signed >= 0
983                                        && (oh_signed as usize) < out_h
984                                        && ow_signed >= 0
985                                        && (ow_signed as usize) < out_w
986                                    {
987                                        let oh = oh_signed as usize;
988                                        let ow = ow_signed as usize;
989                                        let out_idx = b * self.out_channels * out_h * out_w
990                                            + oc * out_h * out_w
991                                            + oh * out_w
992                                            + ow;
993                                        // weight: (in_channels, out_channels, kh, kw)
994                                        let w_idx = ic * self.out_channels * kh * kw
995                                            + oc * kh * kw
996                                            + ki * kw
997                                            + kj;
998                                        output_data[out_idx] += in_val * weight_vec[w_idx];
999                                    }
1000                                }
1001                            }
1002                        }
1003                    }
1004                }
1005            }
1006        }
1007
1008        // Add bias
1009        if let Some(ref bias) = self.bias {
1010            let bias_vec = bias.data().to_vec();
1011            for b in 0..batch_size {
1012                for oc in 0..self.out_channels {
1013                    for oh in 0..out_h {
1014                        for ow in 0..out_w {
1015                            let out_idx = b * self.out_channels * out_h * out_w
1016                                + oc * out_h * out_w
1017                                + oh * out_w
1018                                + ow;
1019                            output_data[out_idx] += bias_vec[oc];
1020                        }
1021                    }
1022                }
1023            }
1024        }
1025
1026        let output_tensor =
1027            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_h, out_w])
1028                .expect("tensor creation failed");
1029
1030        let requires_grad =
1031            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
1032
1033        if requires_grad {
1034            let weight_var = self.weight.variable();
1035            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
1036
1037            let grad_fn = GradFn::new(ConvTranspose2dBackward::new(
1038                input.grad_fn().cloned(),
1039                weight_var.grad_fn().cloned(),
1040                bias_grad_fn,
1041                input_data,
1042                weight_data,
1043                input_shape,
1044                self.in_channels,
1045                self.out_channels,
1046                self.kernel_size,
1047                self.stride,
1048                self.padding,
1049                self.output_padding,
1050                self.bias.is_some(),
1051            ));
1052            Variable::from_operation(output_tensor, grad_fn, true)
1053        } else {
1054            Variable::new(output_tensor, false)
1055        }
1056    }
1057
1058    fn parameters(&self) -> Vec<Parameter> {
1059        let mut params = vec![self.weight.clone()];
1060        if let Some(ref bias) = self.bias {
1061            params.push(bias.clone());
1062        }
1063        params
1064    }
1065
1066    fn named_parameters(&self) -> HashMap<String, Parameter> {
1067        let mut params = HashMap::new();
1068        params.insert("weight".to_string(), self.weight.clone());
1069        if let Some(ref bias) = self.bias {
1070            params.insert("bias".to_string(), bias.clone());
1071        }
1072        params
1073    }
1074
1075    fn name(&self) -> &'static str {
1076        "ConvTranspose2d"
1077    }
1078}
1079
1080// =============================================================================
1081// Tests
1082// =============================================================================
1083
#[cfg(test)]
mod tests {
    use super::*;

    // Construction stores the hyper-parameters verbatim.
    #[test]
    fn test_conv1d_creation() {
        let conv = Conv1d::new(3, 16, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 16);
        assert_eq!(conv.kernel_size, 3);
    }

    // "Same" padding (pad=1, k=3, s=1) preserves the temporal length.
    #[test]
    fn test_conv1d_forward() {
        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5])
                .expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 5]);
    }

    #[test]
    fn test_conv1d_backward() {
        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5])
                .expect("tensor creation failed"),
            true,
        );
        let output = conv.forward(&input);
        let loss = output.sum();
        loss.backward();

        // Input should have gradient (not None)
        assert!(
            input.grad().is_some(),
            "Conv1d: input gradient should flow through backward pass"
        );
        // Input gradient has the same shape as the input.
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 5]);
    }

    #[test]
    fn test_conv2d_creation() {
        let conv = Conv2d::new(3, 64, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 64);
        // Scalar kernel argument expands to a square (3, 3) kernel.
        assert_eq!(conv.kernel_size, (3, 3));
    }

    // "Same" padding in 2D: 5x5 input stays 5x5.
    #[test]
    fn test_conv2d_forward() {
        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 5, 5]);
    }

    #[test]
    fn test_conv2d_backward() {
        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
            true,
        );
        let output = conv.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "Conv2d: input gradient should flow through backward pass"
        );
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 5, 5]);

        // Weight should also have gradient
        let w_grad = conv.weight.grad();
        assert!(
            w_grad.is_some(),
            "Conv2d: weight gradient should be computed"
        );
    }

    #[test]
    fn test_conv2d_parameters() {
        let conv = Conv2d::new(3, 64, 3);
        let params = conv.parameters();
        assert_eq!(params.len(), 2); // weight + bias
    }

    #[test]
    fn test_conv2d_grouped() {
        // Depthwise: groups = in_channels = out_channels
        let conv = Conv2d::depthwise(4, 3);
        assert_eq!(conv.groups, 4);
        assert_eq!(conv.in_channels, 4);
        assert_eq!(conv.out_channels, 4);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 5 * 5], &[1, 4, 5, 5]).expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 5, 5]);
    }

    #[test]
    fn test_conv_transpose2d_forward() {
        let conv_t = ConvTranspose2d::with_options(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4], &[1, 1, 2, 2]).expect("tensor creation failed"),
            false,
        );
        let output = conv_t.forward(&input);
        // H_out = (2-1)*2 - 2*1 + 3 + 1 = 4
        assert_eq!(output.shape(), vec![1, 1, 4, 4]);
    }

    #[test]
    fn test_conv_transpose2d_backward() {
        let conv_t = ConvTranspose2d::new(1, 1, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 9], &[1, 1, 3, 3]).expect("tensor creation failed"),
            true,
        );
        let output = conv_t.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "ConvTranspose2d: input gradient should flow through backward"
        );
    }

    // =========================================================================
    // Conv1d Comprehensive
    // =========================================================================

    #[test]
    fn test_conv1d_with_padding_and_stride() {
        let conv = Conv1d::with_options(1, 4, 3, 2, 1, true);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 16], &[1, 1, 16]).unwrap(), true);
        let output = conv.forward(&input);
        // L_out = (16 + 2*1 - 3) / 2 + 1 = 8
        assert_eq!(output.shape(), vec![1, 4, 8]);

        output.sum().backward();
        let grad = input.grad().expect("Conv1d should propagate gradients");
        assert_eq!(grad.shape(), &[1, 1, 16]);
        // At least one gradient entry must be non-zero (graph not disconnected).
        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));
    }

    #[test]
    fn test_conv1d_multi_channel() {
        let conv = Conv1d::new(3, 8, 5); // 3 input channels, 8 output, kernel 5
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 3 * 20], &[2, 3, 20]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        // L_out = (20 - 5) / 1 + 1 = 16 (no padding)
        assert_eq!(output.shape(), vec![2, 8, 16]);
    }

    // =========================================================================
    // Conv2d Grouped — Correctness
    // =========================================================================

    #[test]
    fn test_conv2d_grouped_gradient_flow() {
        let conv = Conv2d::depthwise(4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 8 * 8], &[1, 4, 8, 8]).unwrap(),
            true,
        );
        let output = conv.forward(&input);
        output.sum().backward();

        let grad = input
            .grad()
            .expect("Grouped conv should propagate gradients");
        assert_eq!(grad.shape(), &[1, 4, 8, 8]);
        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));

        // Parameters should also get gradients
        for p in conv.parameters() {
            let g = p.grad().expect("Conv params should have gradients");
            assert!(g.to_vec().iter().any(|v| v.abs() > 0.0));
        }
    }

    #[test]
    fn test_conv2d_groups_two() {
        // 2 groups: 4 input channels split into 2 groups of 2
        let conv = Conv2d::with_groups(4, 8, (3, 3), (1, 1), (1, 1), true, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 6 * 6], &[1, 4, 6, 6]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 8, 6, 6]);
    }

    #[test]
    fn test_conv2d_depthwise_separable_pattern() {
        // Depthwise separable: depthwise conv + pointwise conv (standard MobileNet pattern)
        let dw = Conv2d::depthwise(16, 3); // 16 channels, 3x3 kernel
        let pw = Conv2d::with_options(16, 32, (1, 1), (1, 1), (0, 0), true); // pointwise

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 16 * 8 * 8], &[1, 16, 8, 8]).unwrap(),
            true,
        );
        let dw_out = dw.forward(&input);
        assert_eq!(dw_out.shape(), vec![1, 16, 8, 8]);

        let pw_out = pw.forward(&dw_out);
        assert_eq!(pw_out.shape(), vec![1, 32, 8, 8]);

        // Full gradient flow through both
        pw_out.sum().backward();
        let grad = input
            .grad()
            .expect("Should propagate through depthwise separable");
        assert_eq!(grad.shape(), &[1, 16, 8, 8]);
    }

    // =========================================================================
    // ConvTranspose2d Comprehensive
    // =========================================================================

    #[test]
    fn test_conv_transpose2d_upsamples() {
        // ConvTranspose2d with stride=2 should roughly double spatial dims
        let conv_t = ConvTranspose2d::with_options(1, 1, (4, 4), (2, 2), (1, 1), (0, 0), true);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 4], &[1, 1, 4, 4]).unwrap(),
            false,
        );
        let output = conv_t.forward(&input);
        // H_out = (4-1)*2 - 2*1 + 4 + 0 = 8
        assert_eq!(output.shape(), vec![1, 1, 8, 8]);
    }

    #[test]
    fn test_conv_transpose2d_gradient_correctness() {
        let conv_t = ConvTranspose2d::new(2, 4, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![0.5; 2 * 4 * 4], &[1, 2, 4, 4]).unwrap(),
            true,
        );
        let output = conv_t.forward(&input);
        output.sum().backward();

        // Gradients must be finite (no NaN/inf) and not all zero.
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 2, 4, 4]);
        assert!(grad.to_vec().iter().all(|g| g.is_finite()));
        assert!(grad.to_vec().iter().any(|g| g.abs() > 0.0));

        // Weight params should also have gradients
        for p in conv_t.parameters() {
            assert!(p.grad().is_some(), "ConvTranspose2d params need gradients");
        }
    }

    #[test]
    fn test_conv_transpose2d_multi_channel() {
        let conv_t = ConvTranspose2d::new(8, 16, 3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 2 * 8 * 4 * 4], &[2, 8, 4, 4]).unwrap(),
            false,
        );
        let output = conv_t.forward(&input);
        assert_eq!(output.shape()[0], 2); // batch
        assert_eq!(output.shape()[1], 16); // out_channels
    }
}