numrs/ops/
conv.rs

//! Convolution operations
//!
//! Provides 1D, 2D, and 3D convolutions with multi-backend support.

5use crate::array::Array;
6use anyhow::{anyhow, Result};
7
8/// 1D Convolution
9///
10/// Applies a 1D convolution over an input signal composed of several input planes.
11///
12/// # Arguments
13/// * `input` - Input tensor of shape [Batch, InChannels, Length]
14/// * `weight` - Filters of shape [OutChannels, InChannels, KernelSize]
15/// * `bias` - Optional bias of shape [OutChannels]
16/// * `stride` - Stride of the convolution. Default: 1
17/// * `padding` - Zero-padding added to both sides of the input. Default: 0
18///
19/// # Returns
20/// Output tensor of shape [Batch, OutChannels, OutLength]
21pub fn conv1d(
22    input: &Array,
23    weight: &Array,
24    bias: Option<&Array>,
25    stride: usize,
26    padding: usize,
27) -> Result<Array> {
28    // 1. Validation
29    if input.shape.len() != 3 {
30        return Err(anyhow!(
31            "Conv1D input must be 3D [Batch, InChannels, Length]"
32        ));
33    }
34    if weight.shape.len() != 3 {
35        return Err(anyhow!(
36            "Conv1D weight must be 3D [OutChannels, InChannels, KernelSize]"
37        ));
38    }
39
40    let in_channels = input.shape[1];
41    let kernel_in_channels = weight.shape[1];
42
43    if in_channels != kernel_in_channels {
44        return Err(anyhow!(
45            "Input channels ({}) mismatch kernel channels ({})",
46            in_channels,
47            kernel_in_channels
48        ));
49    }
50
51    if let Some(b) = bias {
52        if b.shape.len() != 1 {
53            return Err(anyhow!("Bias must be 1D"));
54        }
55        if b.shape[0] != weight.shape[0] {
56            return Err(anyhow!(
57                "Bias size ({}) mismatch output channels ({})",
58                b.shape[0],
59                weight.shape[0]
60            ));
61        }
62    }
63
64    // 2. Dispatch
65    // Check for GPU availability
66    #[cfg(numrs_kernel_conv_gpu)]
67    {
68        if crate::backend::webgpu::is_available_cached() {
69            return crate::backend::webgpu::conv::conv1d_webgpu(
70                input, weight, bias, stride, padding,
71            );
72        }
73    }
74
75    // Check for SIMD (AVX)
76    #[cfg(numrs_kernel_conv_simd)]
77    {
78        // TODO: Validate SIMD support at runtime
79        return crate::backend::cpu::simd::conv1d_simd(input, weight, bias, stride, padding);
80    }
81
82    // Fallback: CPU Naive (Unreachable if SIMD enabled, but kept for clarity/structure if feature disabled)
83    #[cfg(not(numrs_kernel_conv_simd))]
84    crate::backend::cpu::conv::conv1d_naive(input, weight, bias, stride, padding)
85}