axonml_nn/layers/conv.rs

//! Convolutional Layers - 1D and 2D Convolutions
//!
//! Applies convolution operations over input signals.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::init::{kaiming_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// Conv1d
// =============================================================================

/// Applies a 1D convolution over an input signal.
///
/// # Shape
/// - Input: (N, C_in, L)
/// - Output: (N, C_out, L_out)
///
/// where L_out = (L + 2*padding - kernel_size) / stride + 1
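///
/// For example, L = 5 with kernel_size = 3, stride = 1, and padding = 1 gives
/// L_out = (5 + 2*1 - 3) / 1 + 1 = 5, i.e. a length-preserving convolution.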
pub struct Conv1d {
    /// Weight tensor of shape (out_channels, in_channels, kernel_size).
    pub weight: Parameter,
    /// Bias tensor of shape (out_channels).
    pub bias: Option<Parameter>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Size of the convolving kernel.
    kernel_size: usize,
    /// Stride of the convolution.
    stride: usize,
    /// Zero-padding added to both sides.
    padding: usize,
}

impl Conv1d {
    /// Creates a new Conv1d layer.
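    ///
    /// Defaults to stride 1, no padding, and a learnable bias. An illustrative
    /// sketch (not compiled as a doc-test):
    ///
    /// ```ignore
    /// // 3 input channels -> 16 output channels, kernel size 3
    /// let conv = Conv1d::new(3, 16, 3);
    /// // (N, C_in, L) = (2, 3, 32) -> (N, C_out, L_out) = (2, 16, 30)
    /// ```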
    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
        Self::with_options(in_channels, out_channels, kernel_size, 1, 0, true)
    }

    /// Creates a Conv1d layer with all options.
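    ///
    /// An illustrative sketch (argument order: in_channels, out_channels,
    /// kernel_size, stride, padding, bias; not compiled as a doc-test):
    ///
    /// ```ignore
    /// // Strided, padded convolution without a bias term:
    /// let conv = Conv1d::with_options(3, 16, 3, 2, 1, false);
    /// // L = 32 -> L_out = (32 + 2*1 - 3) / 2 + 1 = 16
    /// ```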
    pub fn with_options(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        bias: bool,
    ) -> Self {
        // Initialize weights
        let fan_in = in_channels * kernel_size;
        let weight_data = kaiming_uniform(out_channels, fan_in);
        let weight_reshaped = weight_data
            .reshape(&[
                out_channels as isize,
                in_channels as isize,
                kernel_size as isize,
            ])
            .unwrap();
        let weight = Parameter::named("weight", weight_reshaped, true);

        let bias_param = if bias {
            Some(Parameter::named("bias", zeros(&[out_channels]), true))
        } else {
            None
        };

        Self {
            weight,
            bias: bias_param,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        }
    }
}

impl Module for Conv1d {
    fn forward(&self, input: &Variable) -> Variable {
        // Naive direct convolution (not im2col); a production implementation
        // would use im2col plus a GEMM, FFT-based convolution, or otherwise
        // optimized kernels.
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let _in_channels = input_shape[1];
        let in_length = input_shape[2];

        // Calculate output length
        let out_length = (in_length + 2 * self.padding - self.kernel_size) / self.stride + 1;

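        // Layout note: all buffers are flat and row-major. Input element
        // (b, ic, l) lives at b * C_in * L + ic * L + l; the weight and
        // output indexing below follows the same convention.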
        // Materialize flat element buffers once, outside the hot loop.
        let input_vec = input.data().to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());
        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_length];

        for b in 0..batch_size {
            for oc in 0..self.out_channels {
                for ol in 0..out_length {
                    let mut sum = 0.0f32;
                    let in_start = ol * self.stride;

                    for ic in 0..self.in_channels {
                        for k in 0..self.kernel_size {
                            // Index into the conceptually zero-padded input.
                            let in_idx = in_start + k;
                            if in_idx < self.padding || in_idx >= in_length + self.padding {
                                continue; // inside the zero padding: contributes nothing
                            }
                            let actual_idx = in_idx - self.padding;

                            let input_idx =
                                b * self.in_channels * in_length + ic * in_length + actual_idx;
                            let weight_idx = oc * self.in_channels * self.kernel_size
                                + ic * self.kernel_size
                                + k;

                            sum += input_vec[input_idx] * weight_vec[weight_idx];
                        }
                    }

                    // Add bias
                    if let Some(ref bias) = bias_vec {
                        sum += bias[oc];
                    }

                    let output_idx = b * self.out_channels * out_length + oc * out_length + ol;
                    output_data[output_idx] = sum;
                }
            }
        }

        let output_tensor =
            Tensor::from_vec(output_data, &[batch_size, self.out_channels, out_length]).unwrap();

        // Wrapping in a fresh Variable preserves requires_grad but records no
        // backward edge, so this naive forward does not yet flow gradients.
        Variable::new(output_tensor, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "Conv1d"
    }
}

// =============================================================================
// Conv2d
// =============================================================================

/// Applies a 2D convolution over an input image.
///
/// # Shape
/// - Input: (N, C_in, H, W)
/// - Output: (N, C_out, H_out, W_out)
///
/// where H_out = (H + 2*padding.0 - kernel_size.0) / stride.0 + 1, and W_out
/// is computed analogously along the width dimension.
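///
/// For example, H = W = 5 with kernel_size (3, 3), stride (1, 1), and
/// padding (1, 1) gives H_out = W_out = (5 + 2*1 - 3) / 1 + 1 = 5.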
pub struct Conv2d {
    /// Weight tensor of shape (out_channels, in_channels, kernel_h, kernel_w).
    pub weight: Parameter,
    /// Bias tensor of shape (out_channels).
    pub bias: Option<Parameter>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels.
    out_channels: usize,
    /// Size of the convolving kernel (height, width).
    kernel_size: (usize, usize),
    /// Stride of the convolution (height, width).
    stride: (usize, usize),
    /// Zero-padding added to both sides (height, width).
    padding: (usize, usize),
}

impl Conv2d {
    /// Creates a new Conv2d layer with a square kernel.
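    ///
    /// Defaults to stride (1, 1), no padding, and a learnable bias. An
    /// illustrative sketch (not compiled as a doc-test):
    ///
    /// ```ignore
    /// let conv = Conv2d::new(3, 64, 3);
    /// // (N, C_in, H, W) = (1, 3, 32, 32) -> (1, 64, 30, 30)
    /// ```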
    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
        Self::with_options(
            in_channels,
            out_channels,
            (kernel_size, kernel_size),
            (1, 1),
            (0, 0),
            true,
        )
    }

    /// Creates a Conv2d layer with all options.
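    ///
    /// An illustrative sketch (all tuples are (height, width); not compiled as
    /// a doc-test):
    ///
    /// ```ignore
    /// // 3x3 kernel, stride 2, padding 1, no bias:
    /// let conv = Conv2d::with_options(3, 64, (3, 3), (2, 2), (1, 1), false);
    /// // H = W = 32 -> H_out = W_out = (32 + 2*1 - 3) / 2 + 1 = 16
    /// ```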
    pub fn with_options(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
        bias: bool,
    ) -> Self {
        let (kh, kw) = kernel_size;
        let fan_in = in_channels * kh * kw;

        // Initialize weights
        let weight_data = kaiming_uniform(out_channels, fan_in);
        let weight_reshaped = weight_data
            .reshape(&[
                out_channels as isize,
                in_channels as isize,
                kh as isize,
                kw as isize,
            ])
            .unwrap();
        let weight = Parameter::named("weight", weight_reshaped, true);

        let bias_param = if bias {
            Some(Parameter::named("bias", zeros(&[out_channels]), true))
        } else {
            None
        };

        Self {
            weight,
            bias: bias_param,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
        }
    }
}

impl Module for Conv2d {
    fn forward(&self, input: &Variable) -> Variable {
        // Naive direct convolution, mirroring Conv1d::forward.
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let in_height = input_shape[2];
        let in_width = input_shape[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        let out_height = (in_height + 2 * ph - kh) / sh + 1;
        let out_width = (in_width + 2 * pw - kw) / sw + 1;

        let input_vec = input.data().to_vec();
        let weight_vec = self.weight.data().to_vec();
        // Hoist the bias buffer out of the hot loop as well.
        let bias_vec = self.bias.as_ref().map(|b| b.data().to_vec());

        let mut output_data = vec![0.0f32; batch_size * self.out_channels * out_height * out_width];

        for b in 0..batch_size {
            for oc in 0..self.out_channels {
                for oh in 0..out_height {
                    for ow in 0..out_width {
                        let mut sum = 0.0f32;

                        for ic in 0..self.in_channels {
                            for ki in 0..kh {
                                for kj in 0..kw {
                                    let ih = oh * sh + ki;
                                    let iw = ow * sw + kj;

                                    // Handle padding
                                    if ih < ph
                                        || ih >= in_height + ph
                                        || iw < pw
                                        || iw >= in_width + pw
                                    {
                                        continue;
                                    }

                                    let actual_ih = ih - ph;
                                    let actual_iw = iw - pw;

                                    let input_idx = b * self.in_channels * in_height * in_width
                                        + ic * in_height * in_width
                                        + actual_ih * in_width
                                        + actual_iw;

                                    let weight_idx = oc * self.in_channels * kh * kw
                                        + ic * kh * kw
                                        + ki * kw
                                        + kj;

                                    sum += input_vec[input_idx] * weight_vec[weight_idx];
                                }
                            }
                        }

                        // Add bias
                        if let Some(ref bias) = bias_vec {
                            sum += bias[oc];
                        }

                        let output_idx = b * self.out_channels * out_height * out_width
                            + oc * out_height * out_width
                            + oh * out_width
                            + ow;
                        output_data[output_idx] = sum;
                    }
                }
            }
        }

        let output_tensor = Tensor::from_vec(
            output_data,
            &[batch_size, self.out_channels, out_height, out_width],
        )
        .unwrap();

        // As in Conv1d::forward, this wraps the result in a fresh Variable and
        // records no backward edge.
        Variable::new(output_tensor, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        if let Some(ref bias) = self.bias {
            params.insert("bias".to_string(), bias.clone());
        }
        params
    }

    fn name(&self) -> &'static str {
        "Conv2d"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conv1d_creation() {
        let conv = Conv1d::new(3, 16, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 16);
        assert_eq!(conv.kernel_size, 3);
    }

    #[test]
    fn test_conv1d_forward() {
        let conv = Conv1d::with_options(1, 1, 3, 1, 1, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 5]);
    }
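
    // Illustrative extra check (uses only the constructors above): stride 2
    // with no padding gives L_out = (7 - 3) / 2 + 1 = 3.
    #[test]
    fn test_conv1d_forward_strided() {
        let conv = Conv1d::with_options(1, 2, 3, 2, 0, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 7], &[1, 1, 7]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 2, 3]);
    }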

    #[test]
    fn test_conv2d_creation() {
        let conv = Conv2d::new(3, 64, 3);
        assert_eq!(conv.in_channels, 3);
        assert_eq!(conv.out_channels, 64);
        assert_eq!(conv.kernel_size, (3, 3));
    }

    #[test]
    fn test_conv2d_forward() {
        let conv = Conv2d::with_options(1, 1, (3, 3), (1, 1), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 25], &[1, 1, 5, 5]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 5, 5]);
    }
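
    // Illustrative extra check: stride (2, 2) with padding (1, 1) on a 7x7
    // input gives H_out = W_out = (7 + 2*1 - 3) / 2 + 1 = 4.
    #[test]
    fn test_conv2d_forward_strided() {
        let conv = Conv2d::with_options(1, 1, (3, 3), (2, 2), (1, 1), false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 49], &[1, 1, 7, 7]).unwrap(),
            false,
        );
        let output = conv.forward(&input);
        assert_eq!(output.shape(), vec![1, 1, 4, 4]);
    }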

    #[test]
    fn test_conv2d_parameters() {
        let conv = Conv2d::new(3, 64, 3);
        let params = conv.parameters();
        assert_eq!(params.len(), 2); // weight + bias
    }
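
    // Mirror of test_conv2d_parameters for the 1D layer.
    #[test]
    fn test_conv1d_parameters() {
        let conv = Conv1d::new(3, 16, 3);
        let params = conv.parameters();
        assert_eq!(params.len(), 2); // weight + bias
    }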
}