1//! Convolutional Layers - 1D and 2D Convolutions
2//!
3//! # File
4//! `crates/axonml-nn/src/layers/conv.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use std::collections::HashMap;
18
19use axonml_autograd::Variable;
20use axonml_autograd::functions::{
21    Conv1dBackward, Conv2dBackward, ConvTranspose2dBackward, GroupedConv2dBackward,
22};
23use axonml_autograd::grad_fn::GradFn;
24use axonml_autograd::no_grad::is_grad_enabled;
25use axonml_tensor::Tensor;
26use rayon::prelude::*;
27
28use crate::init::{kaiming_uniform, zeros};
29use crate::module::Module;
30use crate::parameter::Parameter;
31
32// =============================================================================
33// Conv1d
34// =============================================================================
35
/// Applies a 1D convolution over an input signal.
///
/// # Shape
/// - Input: (N, C_in, L)
/// - Output: (N, C_out, L_out)
///
/// where L_out = (L + 2*padding - kernel_size) / stride + 1
///
/// Unlike [`Conv2d`], this layer has no `groups` or dilation support.
pub struct Conv1d {
    /// Weight tensor of shape (out_channels, in_channels, kernel_size).
    pub weight: Parameter,
    /// Bias tensor of shape (out_channels). `None` when constructed with `bias = false`.
    pub bias: Option<Parameter>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Size of the convolving kernel.
    kernel_size: usize,
    /// Stride of the convolution.
    stride: usize,
    /// Zero-padding added to both sides.
    padding: usize,
}
59
60impl Conv1d {
61    /// Creates a new Conv1d layer.
62    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
63        Self::with_options(in_channels, out_channels, kernel_size, 1, 0, true)
64    }
65
66    /// Creates a Conv1d layer with all options.
67    pub fn with_options(
68        in_channels: usize,
69        out_channels: usize,
70        kernel_size: usize,
71        stride: usize,
72        padding: usize,
73        bias: bool,
74    ) -> Self {
75        // Initialize weights
76        let fan_in = in_channels * kernel_size;
77        let weight_data = kaiming_uniform(out_channels, fan_in);
78        let weight_reshaped = weight_data
79            .reshape(&[
80                out_channels as isize,
81                in_channels as isize,
82                kernel_size as isize,
83            ])
84            .unwrap();
85        let weight = Parameter::named("weight", weight_reshaped, true);
86
87        let bias_param = if bias {
88            Some(Parameter::named("bias", zeros(&[out_channels]), true))
89        } else {
90            None
91        };
92
93        Self {
94            weight,
95            bias: bias_param,
96            in_channels,
97            out_channels,
98            kernel_size,
99            stride,
100            padding,
101        }
102    }
103}
104
105impl Module for Conv1d {
106    fn forward(&self, input: &Variable) -> Variable {
107        let input_shape = input.shape();
108        let batch_size = input_shape[0];
109        let in_length = input_shape[2];
110
111        let out_length = (in_length + 2 * self.padding - self.kernel_size) / self.stride + 1;
112
113        let input_data = input.data();
114        let weight_data = self.weight.data();
115
116        // GPU-resident fast path: reshape [B,C,L] → [B,C,L,1], use Conv2d CUDA pipeline,
117        // then reshape output [B,Cout,Lout,1] → [B,Cout,Lout].
118        #[cfg(feature = "cuda")]
119        if input_data.device().is_gpu() {
120            // Auto-migrate weights to GPU if needed
121            let input_dev = input_data.device();
122            if !weight_data.device().is_gpu() {
123                self.weight.to_device(input_dev);
124                if let Some(ref b) = self.bias {
125                    b.to_device(input_dev);
126                }
127            }
128            let weight_data = self.weight.data();
129
130            // Reshape input [B, Cin, L] → [B, Cin, L, 1]
131            let input_4d = input_data
132                .reshape(&[
133                    batch_size as isize,
134                    self.in_channels as isize,
135                    in_length as isize,
136                    1,
137                ])
138                .unwrap();
139
140            // Reshape weight [Cout, Cin, K] → [Cout, Cin, K, 1]
141            let weight_4d = weight_data
142                .reshape(&[
143                    self.out_channels as isize,
144                    self.in_channels as isize,
145                    self.kernel_size as isize,
146                    1,
147                ])
148                .unwrap();
149
150            let bias_tensor = self.bias.as_ref().map(|b| b.data());
151            let gpu_output = input_4d.conv2d_cuda(
152                &weight_4d,
153                bias_tensor.as_ref(),
154                (self.stride, 1),
155                (self.padding, 0),
156            );
157
158            if let Some(output_4d) = gpu_output {
159                // Reshape output [B, Cout, Lout, 1] → [B, Cout, Lout]
160                let output_tensor = output_4d
161                    .reshape(&[
162                        batch_size as isize,
163                        self.out_channels as isize,
164                        out_length as isize,
165                    ])
166                    .unwrap();
167
168                let requires_grad =
169                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
170                if requires_grad {
171                    let weight_var = self.weight.variable();
172                    let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
173
174                    let grad_fn = GradFn::new(Conv1dBackward::new(
175                        input.grad_fn().cloned(),
176                        weight_var.grad_fn().cloned(),
177                        bias_grad_fn,
178                        input_data,
179                        weight_data,
180                        input_shape,
181                        self.in_channels,
182                        self.out_channels,
183                        self.kernel_size,
184                        self.stride,
185                        self.padding,
186                        self.bias.is_some(),
187                    ));
188                    return Variable::from_operation(output_tensor, grad_fn, true);
189                } else {
190                    return Variable::new(output_tensor, false);
191                }
192            }
193            // Fall through to CPU path if GPU conv failed
194        }
195
196        let input_vec = input_data.to_vec();
197        let weight_vec = weight_data.to_vec();
198
199        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_length];
200
201        for b in 0..batch_size {
202            for oc in 0..self.out_channels {
203                for ol in 0..out_length {
204                    let mut sum = 0.0f32;
205                    let in_start = ol * self.stride;
206
207                    for ic in 0..self.in_channels {
208                        for k in 0..self.kernel_size {
209                            let in_idx = in_start + k;
210                            if in_idx < self.padding || in_idx >= in_length + self.padding {
211                                continue;
212                            }
213                            let actual_idx = in_idx - self.padding;
214
215                            let input_idx =
216                                b * self.in_channels * in_length + ic * in_length + actual_idx;
217                            let weight_idx = oc * self.in_channels * self.kernel_size
218                                + ic * self.kernel_size
219                                + k;
220
221                            sum += input_vec[input_idx] * weight_vec[weight_idx];
222                        }
223                    }
224
225                    if let Some(ref bias) = self.bias {
226                        sum += bias.data().to_vec()[oc];
227                    }
228
229                    let output_idx = b * self.out_channels * out_length + oc * out_length + ol;
230                    output_data[output_idx] = sum;
231                }
232            }
233        }
234
235        let output_tensor =
236            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_length]).expect("tensor creation failed");
237
238        let requires_grad =
239            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
240
241        if requires_grad {
242            let weight_var = self.weight.variable();
243            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
244
245            let grad_fn = GradFn::new(Conv1dBackward::new(
246                input.grad_fn().cloned(),
247                weight_var.grad_fn().cloned(),
248                bias_grad_fn,
249                input_data,
250                weight_data,
251                input_shape,
252                self.in_channels,
253                self.out_channels,
254                self.kernel_size,
255                self.stride,
256                self.padding,
257                self.bias.is_some(),
258            ));
259            Variable::from_operation(output_tensor, grad_fn, true)
260        } else {
261            Variable::new(output_tensor, false)
262        }
263    }
264
265    fn parameters(&self) -> Vec<Parameter> {
266        let mut params = vec![self.weight.clone()];
267        if let Some(ref bias) = self.bias {
268            params.push(bias.clone());
269        }
270        params
271    }
272
273    fn named_parameters(&self) -> HashMap<String, Parameter> {
274        let mut params = HashMap::new();
275        params.insert("weight".to_string(), self.weight.clone());
276        if let Some(ref bias) = self.bias {
277            params.insert("bias".to_string(), bias.clone());
278        }
279        params
280    }
281
282    fn name(&self) -> &'static str {
283        "Conv1d"
284    }
285}
286
287// =============================================================================
288// Conv2d
289// =============================================================================
290
/// Applies a 2D convolution over an input image.
///
/// # Shape
/// - Input: (N, C_in, H, W)
/// - Output: (N, C_out, H_out, W_out)
///
/// where H_out = (H + 2*padding - kernel_size) / stride + 1
pub struct Conv2d {
    /// Weight tensor of shape (out_channels, in_channels / groups, kernel_h, kernel_w).
    /// (For `groups == 1` this is the usual (C_out, C_in, kH, kW) layout.)
    pub weight: Parameter,
    /// Bias tensor of shape (out_channels). `None` when constructed with `bias = false`.
    pub bias: Option<Parameter>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Size of the convolving kernel (height, width).
    kernel_size: (usize, usize),
    /// Stride of the convolution (height, width).
    stride: (usize, usize),
    /// Zero-padding added to both sides (height, width).
    padding: (usize, usize),
    /// Number of groups for grouped convolution (1 = standard convolution).
    groups: usize,
}
316
317impl Conv2d {
318    /// Creates a new Conv2d layer with square kernel.
319    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
320        Self::with_options(
321            in_channels,
322            out_channels,
323            (kernel_size, kernel_size),
324            (1, 1),
325            (0, 0),
326            true,
327        )
328    }
329
330    /// Creates a Conv2d layer with all options.
331    pub fn with_options(
332        in_channels: usize,
333        out_channels: usize,
334        kernel_size: (usize, usize),
335        stride: (usize, usize),
336        padding: (usize, usize),
337        bias: bool,
338    ) -> Self {
339        Self::with_groups(
340            in_channels,
341            out_channels,
342            kernel_size,
343            stride,
344            padding,
345            bias,
346            1,
347        )
348    }
349
350    /// Creates a Conv2d layer with grouped convolution support.
351    ///
352    /// When `groups == in_channels` and `out_channels == in_channels`, this is
353    /// a depthwise convolution.
354    pub fn with_groups(
355        in_channels: usize,
356        out_channels: usize,
357        kernel_size: (usize, usize),
358        stride: (usize, usize),
359        padding: (usize, usize),
360        bias: bool,
361        groups: usize,
362    ) -> Self {
363        assert!(
364            in_channels % groups == 0,
365            "in_channels must be divisible by groups"
366        );
367        assert!(
368            out_channels % groups == 0,
369            "out_channels must be divisible by groups"
370        );
371
372        let (kh, kw) = kernel_size;
373        let in_channels_per_group = in_channels / groups;
374        let fan_in = in_channels_per_group * kh * kw;
375
376        let weight_data = kaiming_uniform(out_channels, fan_in);
377        let weight_reshaped = weight_data
378            .reshape(&[
379                out_channels as isize,
380                in_channels_per_group as isize,
381                kh as isize,
382                kw as isize,
383            ])
384            .unwrap();
385        let weight = Parameter::named("weight", weight_reshaped, true);
386
387        let bias_param = if bias {
388            Some(Parameter::named("bias", zeros(&[out_channels]), true))
389        } else {
390            None
391        };
392
393        Self {
394            weight,
395            bias: bias_param,
396            in_channels,
397            out_channels,
398            kernel_size,
399            stride,
400            padding,
401            groups,
402        }
403    }
404
405    /// Creates a depthwise convolution (groups = in_channels).
406    pub fn depthwise(channels: usize, kernel_size: usize) -> Self {
407        Self::with_groups(
408            channels,
409            channels,
410            (kernel_size, kernel_size),
411            (1, 1),
412            (kernel_size / 2, kernel_size / 2),
413            true,
414            channels,
415        )
416    }
417}
418
419// =============================================================================
420// im2col + GEMM Conv2d Implementation
421// =============================================================================
422
/// Unfold input patches into a column matrix (im2col).
///
/// Input: `[C_in, H, W]` (one batch element, one group's channels)
/// Output: `[C_in * kH * kW, out_H * out_W]`
///
/// Positions whose receptive-field tap falls into the zero padding are left
/// as 0.0 in the result.
fn im2col(
    input: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    pad_h: usize,
    pad_w: usize,
    stride_h: usize,
    stride_w: usize,
    out_h: usize,
    out_w: usize,
) -> Vec<f32> {
    let patch = kernel_h * kernel_w;
    let rows = channels * patch;
    let cols = out_h * out_w;
    let mut col = vec![0.0f32; rows * cols];
    if cols == 0 {
        // No output positions → empty column matrix (also keeps chunks_mut valid).
        return col;
    }

    // One destination row per (channel, kernel-offset) pair:
    //   row = c * kH * kW + kh * kW + kw,  column = oh * out_w + ow.
    for (row, dst) in col.chunks_mut(cols).enumerate() {
        let c = row / patch;
        let kh = (row % patch) / kernel_w;
        let kw = (row % patch) % kernel_w;
        let channel_base = c * height * width;

        for oh in 0..out_h {
            // Input row this output row samples; rows in the padding are skipped
            // (their slots keep the 0.0 fill).
            let h_in = (oh * stride_h + kh) as isize - pad_h as isize;
            if !(0..height as isize).contains(&h_in) {
                continue;
            }
            let src_row = channel_base + h_in as usize * width;

            for ow in 0..out_w {
                let w_in = (ow * stride_w + kw) as isize - pad_w as isize;
                if (0..width as isize).contains(&w_in) {
                    dst[oh * out_w + ow] = input[src_row + w_in as usize];
                }
            }
        }
    }

    col
}
488
489/// Conv2d forward using im2col + matmul. Supports groups.
490fn conv2d_im2col(
491    input: &[f32],
492    weight: &[f32],
493    bias: Option<&[f32]>,
494    batch_size: usize,
495    in_channels: usize,
496    in_height: usize,
497    in_width: usize,
498    out_channels: usize,
499    kh: usize,
500    kw: usize,
501    sh: usize,
502    sw: usize,
503    ph: usize,
504    pw: usize,
505    groups: usize,
506) -> Vec<f32> {
507    let out_h = (in_height + 2 * ph - kh) / sh + 1;
508    let out_w = (in_width + 2 * pw - kw) / sw + 1;
509    let in_channels_per_group = in_channels / groups;
510    let out_channels_per_group = out_channels / groups;
511    let col_h = in_channels_per_group * kh * kw;
512    let col_w = out_h * out_w;
513    let spatial = out_h * out_w;
514    let in_spatial = in_height * in_width;
515
516    // Parallel: each batch element produces its own output slice
517    let out_per_batch = out_channels * spatial;
518    let per_batch: Vec<Vec<f32>> = (0..batch_size)
519        .into_par_iter()
520        .map(|b| {
521            let mut batch_out = vec![0.0f32; out_per_batch];
522
523            for g in 0..groups {
524                let ic_start = g * in_channels_per_group;
525                let oc_start = g * out_channels_per_group;
526
527                // Extract input for this batch+group
528                let in_offset = b * in_channels * in_spatial + ic_start * in_spatial;
529                let input_slice = &input[in_offset..in_offset + in_channels_per_group * in_spatial];
530
531                // im2col
532                let col = im2col(
533                    input_slice,
534                    in_channels_per_group,
535                    in_height,
536                    in_width,
537                    kh,
538                    kw,
539                    ph,
540                    pw,
541                    sh,
542                    sw,
543                    out_h,
544                    out_w,
545                );
546
547                // Weight for this group
548                let w_offset = oc_start * in_channels_per_group * kh * kw;
549                let w_size = out_channels_per_group * col_h;
550                let weight_slice = &weight[w_offset..w_offset + w_size];
551
552                // GEMM via Tensor::matmul
553                let w_tensor =
554                    Tensor::from_vec(weight_slice.to_vec(), &[out_channels_per_group, col_h])
555                        .unwrap();
556                let col_tensor = Tensor::from_vec(col, &[col_h, col_w]).expect("tensor creation failed");
557                let result = w_tensor.matmul(&col_tensor).expect("matmul failed");
558                let result_vec = result.to_vec();
559
560                // Copy to output with bias
561                let out_offset = oc_start * spatial;
562                for oc_local in 0..out_channels_per_group {
563                    let oc = oc_start + oc_local;
564                    let bias_val = bias.map_or(0.0, |bv| bv[oc]);
565                    let src_start = oc_local * col_w;
566                    let dst_start = out_offset + oc_local * spatial;
567                    if bias_val == 0.0 {
568                        batch_out[dst_start..dst_start + spatial]
569                            .copy_from_slice(&result_vec[src_start..src_start + spatial]);
570                    } else {
571                        for i in 0..spatial {
572                            batch_out[dst_start + i] = result_vec[src_start + i] + bias_val;
573                        }
574                    }
575                }
576            }
577
578            batch_out
579        })
580        .collect();
581
582    // Flatten per-batch results into single output
583    let mut output = Vec::with_capacity(batch_size * out_per_batch);
584    for batch_out in per_batch {
585        output.extend_from_slice(&batch_out);
586    }
587    output
588}
589
impl Module for Conv2d {
    /// 2D convolution forward pass.
    ///
    /// Expects input of shape `(N, C_in, H, W)` and produces
    /// `(N, C_out, H_out, W_out)` with
    /// `H_out = (H + 2*ph - kh) / sh + 1` (analogously for width).
    ///
    /// Dispatch order: GPU-resident path (cuDNN → CUDA im2col, when compiled
    /// with the corresponding features and the input already lives on GPU),
    /// then an opportunistic GPU offload for large CPU tensors, then the CPU
    /// im2col + GEMM fallback. Every path registers the matching backward
    /// (`Conv2dBackward` or `GroupedConv2dBackward`) when gradients are needed.
    fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let in_height = input_shape[2];
        let in_width = input_shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        let out_height = (in_height + 2 * ph - kh) / sh + 1;
        let out_width = (in_width + 2 * pw - kw) / sw + 1;

        let input_data = input.data();
        let weight_data = self.weight.data();

        // GPU-resident fast path: when input is already on GPU, do everything on GPU
        // without any CPU↔GPU copies.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            // Auto-migrate weights to GPU if needed (one-time cost, cached via Arc)
            let input_dev = input_data.device();
            if !weight_data.device().is_gpu() {
                self.weight.to_device(input_dev);
                if let Some(ref b) = self.bias {
                    b.to_device(input_dev);
                }
            }
            // Re-read after the possible migration so we use the GPU copy.
            let weight_data = self.weight.data();

            // Try cuDNN first (fastest path), fall back to im2col+GEMM
            #[cfg(feature = "cudnn")]
            let cudnn_output = {
                let bias_tensor = self.bias.as_ref().map(|b| b.data());
                input_data.conv2d_cudnn(
                    &weight_data,
                    bias_tensor.as_ref(),
                    self.stride,
                    self.padding,
                    self.groups,
                )
            };
            #[cfg(not(feature = "cudnn"))]
            let cudnn_output: Option<axonml_tensor::Tensor<f32>> = None;

            let gpu_output = if cudnn_output.is_some() {
                cudnn_output
            } else if self.groups == 1 {
                // Standard convolution: single im2col + GEMM
                let bias_tensor = self.bias.as_ref().map(|b| b.data());
                input_data.conv2d_cuda(
                    &weight_data,
                    bias_tensor.as_ref(),
                    self.stride,
                    self.padding,
                )
            } else {
                // Grouped convolution: run per-group im2col + GEMM on GPU
                input_data.conv2d_grouped_cuda(
                    &weight_data,
                    self.bias.as_ref().map(|b| b.data()).as_ref(),
                    self.stride,
                    self.padding,
                    self.groups,
                )
            };

            if let Some(output_tensor) = gpu_output {
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());
                    // Pick the backward matching the convolution flavor.
                    if self.groups == 1 {
                        let grad_fn = GradFn::new(Conv2dBackward::new(
                            input.grad_fn().cloned(),
                            weight_var.grad_fn().cloned(),
                            bias_grad_fn,
                            input_data,
                            weight_data,
                            input_shape,
                            self.in_channels,
                            self.out_channels,
                            self.kernel_size,
                            self.stride,
                            self.padding,
                            self.bias.is_some(),
                        ));
                        return Variable::from_operation(output_tensor, grad_fn, true);
                    } else {
                        let grad_fn = GradFn::new(GroupedConv2dBackward::new(
                            input.grad_fn().cloned(),
                            weight_var.grad_fn().cloned(),
                            bias_grad_fn,
                            input_data,
                            weight_data,
                            input_shape,
                            self.in_channels,
                            self.out_channels,
                            self.kernel_size,
                            self.stride,
                            self.padding,
                            self.groups,
                            self.bias.is_some(),
                        ));
                        return Variable::from_operation(output_tensor, grad_fn, true);
                    }
                } else {
                    return Variable::new(output_tensor, false);
                }
            }
            // Fall through to CPU path if GPU conv failed
        }

        let input_vec = input_data.to_vec();
        let weight_vec = weight_data.to_vec();

        // Try GPU im2col+GEMM for groups=1 when data is on CPU but GPU is available
        // NOTE(review): 500_000 looks like an empirically chosen threshold below
        // which the host↔device transfer would dominate — confirm with benchmarks.
        let conv_flops = self.out_channels * self.in_channels * kh * kw * out_height * out_width;
        let output_data = if self.groups == 1 && conv_flops >= 500_000 {
            let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());
            let gpu_result = axonml_core::backends::cuda::cuda_conv2d_forward(
                &input_vec,
                &weight_vec,
                bias_vec.as_deref(),
                batch_size,
                self.in_channels,
                in_height,
                in_width,
                self.out_channels,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
            );

            if let Some(result) = gpu_result {
                result
            } else {
                // GPU offload declined/failed → CPU im2col + GEMM.
                conv2d_im2col(
                    &input_vec,
                    &weight_vec,
                    self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
                    batch_size,
                    self.in_channels,
                    in_height,
                    in_width,
                    self.out_channels,
                    kh,
                    kw,
                    sh,
                    sw,
                    ph,
                    pw,
                    self.groups,
                )
            }
        } else {
            // Small tensors or grouped convolution: CPU im2col + GEMM.
            conv2d_im2col(
                &input_vec,
                &weight_vec,
                self.bias.as_ref().map(|b| b.data().to_vec()).as_deref(),
                batch_size,
                self.in_channels,
                in_height,
                in_width,
                self.out_channels,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
                self.groups,
            )
        };

        let output_tensor = Tensor::from_vec(
            output_data,
            &[batch_size, self.out_channels, out_height, out_width],
        )
        .unwrap();

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        if requires_grad && self.groups == 1 {
            // Full backward pass for standard convolution
            let weight_var = self.weight.variable();
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            let grad_fn = GradFn::new(Conv2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else if requires_grad {
            // Grouped convolution backward (depthwise separable, etc.)
            let weight_var = self.weight.variable();
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            let grad_fn = GradFn::new(GroupedConv2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.groups,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }

    /// Returns `[weight]` or `[weight, bias]`.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    /// Returns parameters keyed by their canonical names ("weight", "bias").
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "Conv2d"
    }
}
846
847// =============================================================================
848// ConvTranspose2d
849// =============================================================================
850
/// Applies a 2D transposed convolution (deconvolution) for upsampling.
///
/// # Shape
/// - Input: (N, C_in, H, W)
/// - Output: (N, C_out, H_out, W_out)
///
/// where H_out = (H - 1) * stride - 2*padding + kernel_size + output_padding
pub struct ConvTranspose2d {
    /// Weight tensor of shape (in_channels, out_channels, kernel_h, kernel_w).
    /// Note the transposed layout: (C_in, C_out, ...) — the reverse of `Conv2d`.
    pub weight: Parameter,
    /// Bias tensor of shape (out_channels).
    pub bias: Option<Parameter>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Size of the convolving kernel (height, width).
    kernel_size: (usize, usize),
    /// Stride of the transposed convolution (height, width).
    stride: (usize, usize),
    /// Padding that was (conceptually) applied by the forward conv (height, width).
    padding: (usize, usize),
    /// Extra size added to one side of the output shape (height, width).
    output_padding: (usize, usize),
}
870
871impl ConvTranspose2d {
872    /// Creates a new ConvTranspose2d layer with square kernel.
873    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
874        Self::with_options(
875            in_channels,
876            out_channels,
877            (kernel_size, kernel_size),
878            (1, 1),
879            (0, 0),
880            (0, 0),
881            true,
882        )
883    }
884
885    /// Creates a ConvTranspose2d layer with all options.
886    pub fn with_options(
887        in_channels: usize,
888        out_channels: usize,
889        kernel_size: (usize, usize),
890        stride: (usize, usize),
891        padding: (usize, usize),
892        output_padding: (usize, usize),
893        bias: bool,
894    ) -> Self {
895        let (kh, kw) = kernel_size;
896        let fan_in = in_channels * kh * kw;
897
898        let weight_data = kaiming_uniform(out_channels, fan_in);
899        let weight_reshaped = weight_data
900            .reshape(&[
901                in_channels as isize,
902                out_channels as isize,
903                kh as isize,
904                kw as isize,
905            ])
906            .unwrap();
907        let weight = Parameter::named("weight", weight_reshaped, true);
908
909        let bias_param = if bias {
910            Some(Parameter::named("bias", zeros(&[out_channels]), true))
911        } else {
912            None
913        };
914
915        Self {
916            weight,
917            bias: bias_param,
918            in_channels,
919            out_channels,
920            kernel_size,
921            stride,
922            padding,
923            output_padding,
924        }
925    }
926}
927
impl Module for ConvTranspose2d {
    /// Forward pass of the transposed convolution.
    ///
    /// Takes an (N, C_in, H, W) input and produces (N, C_out, H_out, W_out)
    /// with `out = (in - 1) * stride - 2 * padding + kernel + output_padding`
    /// per spatial dimension. Each input element is scattered through the
    /// kernel into the output, then the per-channel bias (if any) is added.
    /// When gradient tracking is active, a `ConvTranspose2dBackward` node is
    /// attached so `backward()` reaches the input, weight, and bias.
    fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        // NOTE(review): indexing [2]/[3] assumes a 4-D input; anything else panics here.
        let in_h = input_shape[2];
        let in_w = input_shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;
        let (oph, opw) = self.output_padding;

        // Output spatial size. NOTE(review): unsigned arithmetic — a padding
        // large enough that 2*ph exceeds (in_h-1)*sh + kh + oph would
        // underflow (panics in debug builds); confirm callers validate options.
        let out_h = (in_h - 1) * sh - 2 * ph + kh + oph;
        let out_w = (in_w - 1) * sw - 2 * pw + kw + opw;

        let input_data = input.data();
        let weight_data = self.weight.data();
        let input_vec = input_data.to_vec();
        let weight_vec = weight_data.to_vec();

        // Zero-filled accumulator; contributions are scattered into it below.
        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_h * out_w];

        // Transposed convolution: scatter input values through the kernel
        for b in 0..batch_size {
            for ic in 0..self.in_channels {
                for ih in 0..in_h {
                    for iw in 0..in_w {
                        // Flat index into the (N, C_in, H, W) input.
                        let in_idx =
                            b * self.in_channels * in_h * in_w + ic * in_h * in_w + ih * in_w + iw;
                        let in_val = input_vec[in_idx];

                        for oc in 0..self.out_channels {
                            for ki in 0..kh {
                                for kj in 0..kw {
                                    // Target output coordinate; padding can shift it
                                    // negative, so compute in signed space first.
                                    let oh_signed = (ih * sh + ki) as isize - ph as isize;
                                    let ow_signed = (iw * sw + kj) as isize - pw as isize;

                                    // Drop contributions that land outside the output.
                                    if oh_signed >= 0
                                        && (oh_signed as usize) < out_h
                                        && ow_signed >= 0
                                        && (ow_signed as usize) < out_w
                                    {
                                        let oh = oh_signed as usize;
                                        let ow = ow_signed as usize;
                                        let out_idx = b * self.out_channels * out_h * out_w
                                            + oc * out_h * out_w
                                            + oh * out_w
                                            + ow;
                                        // weight: (in_channels, out_channels, kh, kw)
                                        let w_idx = ic * self.out_channels * kh * kw
                                            + oc * kh * kw
                                            + ki * kw
                                            + kj;
                                        output_data[out_idx] += in_val * weight_vec[w_idx];
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Add bias
        if let Some(ref bias) = self.bias {
            let bias_vec = bias.data().to_vec();
            // One bias value per output channel, broadcast over N, H, W.
            for b in 0..batch_size {
                for oc in 0..self.out_channels {
                    for oh in 0..out_h {
                        for ow in 0..out_w {
                            let out_idx = b * self.out_channels * out_h * out_w
                                + oc * out_h * out_w
                                + oh * out_w
                                + ow;
                            output_data[out_idx] += bias_vec[oc];
                        }
                    }
                }
            }
        }

        let output_tensor =
            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_h, out_w]).expect("tensor creation failed");

        // Track gradients only when input or weight wants them AND autograd
        // is globally enabled (e.g. not inside a no_grad scope).
        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        if requires_grad {
            let weight_var = self.weight.variable();
            // Option<Option<GradFn>>: outer None means the layer has no bias at all.
            let bias_grad_fn = self.bias.as_ref().map(|b| b.variable().grad_fn().cloned());

            // input_data / weight_data are captured so backward can recompute
            // the gradients for input, weight, and bias.
            let grad_fn = GradFn::new(ConvTranspose2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_grad_fn,
                input_data,
                weight_data,
                input_shape,
                self.in_channels,
                self.out_channels,
                self.kernel_size,
                self.stride,
                self.padding,
                self.output_padding,
                self.bias.is_some(),
            ));
            Variable::from_operation(output_tensor, grad_fn, true)
        } else {
            Variable::new(output_tensor, false)
        }
    }

    /// Returns the trainable parameters: weight first, then bias when present.
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    /// Returns parameters keyed by their conventional names ("weight"/"bias").
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    /// Module type name.
    fn name(&self) -> &'static str {
        "ConvTranspose2d"
    }
}
1061
1062// =============================================================================
1063// Tests
1064// =============================================================================
1065
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructor stores the channel and kernel configuration as given.
    #[test]
    fn test_conv1d_creation() {
        let conv = Conv1d::new(3, 16, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 16);
        assert_eq!(conv.kernel_size, 3);
    }

    /// With kernel 3, stride 1, padding 1, the output length equals the input length.
    #[test]
    fn test_conv1d_forward() {
        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]).expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 5]);
    }

    /// Gradients propagate back to the input with the input's shape.
    #[test]
    fn test_conv1d_backward() {
        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]).expect("tensor creation failed"),
            true,
        );
        let output = conv.forward(&input);
        let loss = output.sum();
        loss.backward();

        // Input should have gradient (not None)
        assert!(
            input.grad().is_some(),
            "Conv1d: input gradient should flow through backward pass"
        );
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 5]);
    }

    /// Square-kernel constructor expands the kernel to (k, k).
    #[test]
    fn test_conv2d_creation() {
        let conv = Conv2d::new(3, 64, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 64);
        assert_eq!(conv.kernel_size, (3, 3));
    }

    /// 3x3 kernel with stride 1 and padding 1 preserves spatial dims ("same" conv).
    #[test]
    fn test_conv2d_forward() {
        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 5, 5]);
    }

    /// Backward produces gradients for both the input and the weight.
    #[test]
    fn test_conv2d_backward() {
        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).expect("tensor creation failed"),
            true,
        );
        let output = conv.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "Conv2d: input gradient should flow through backward pass"
        );
        let grad = input.grad().unwrap();
        assert_eq!(grad.shape(), &[1, 1, 5, 5]);

        // Weight should also have gradient
        let w_grad = conv.weight.grad();
        assert!(
            w_grad.is_some(),
            "Conv2d: weight gradient should be computed"
        );
    }

    /// Default construction includes a bias, so two parameters are exposed.
    #[test]
    fn test_conv2d_parameters() {
        let conv = Conv2d::new(3, 64, 3);
        let params = conv.parameters();
        assert_eq!(params.len(), 2); // weight + bias
    }

    /// Depthwise convolution keeps channel count and spatial dims.
    #[test]
    fn test_conv2d_grouped() {
        // Depthwise: groups = in_channels = out_channels
        let conv = Conv2d::depthwise(4, 3);
        assert_eq!(conv.groups, 4);
        assert_eq!(conv.in_channels, 4);
        assert_eq!(conv.out_channels, 4);

        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4 * 5 * 5], &[1, 4, 5, 5]).expect("tensor creation failed"),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 5, 5]);
    }

    /// Transposed conv upsamples per the documented output-size formula.
    #[test]
    fn test_conv_transpose2d_forward() {
        let conv_t = ConvTranspose2d::with_options(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 4], &[1, 1, 2, 2]).expect("tensor creation failed"),
            false,
        );
        let output = conv_t.forward(&input);
        // H_out = (2-1)*2 - 2*1 + 3 + 1 = 4
        assert_eq!(output.shape(), vec![1, 1, 4, 4]);
    }

    /// Gradients propagate back through the transposed convolution.
    #[test]
    fn test_conv_transpose2d_backward() {
        let conv_t = ConvTranspose2d::new(1, 1, 3);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 9], &[1, 1, 3, 3]).expect("tensor creation failed"), true);
        let output = conv_t.forward(&input);
        let loss = output.sum();
        loss.backward();

        assert!(
            input.grad().is_some(),
            "ConvTranspose2d: input gradient should flow through backward"
        );
    }
}