// ghostflow_nn/src/conv.rs

1//! Convolutional layers
2
3use ghostflow_core::Tensor;
4use crate::module::Module;
5use crate::init;
6
/// 1D Convolution layer.
///
/// Applies a 1D convolution over an input of shape
/// `[batch, in_channels, length]`, producing an output of shape
/// `[batch, out_channels, out_length]`.
pub struct Conv1d {
    // Filter weights, shape [out_channels, in_channels, kernel_size].
    weight: Tensor,
    // Optional per-output-channel bias, shape [out_channels].
    bias: Option<Tensor>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    stride: usize,
    // Implicit zero-padding applied to both ends of the length dimension.
    padding: usize,
    // Training-mode flag, toggled via Module::train / Module::eval.
    training: bool,
}
18
19impl Conv1d {
20    pub fn new(
21        in_channels: usize,
22        out_channels: usize,
23        kernel_size: usize,
24        stride: usize,
25        padding: usize,
26    ) -> Self {
27        let fan_in = in_channels * kernel_size;
28        let weight = init::kaiming_uniform(
29            &[out_channels, in_channels, kernel_size],
30            fan_in,
31        );
32        
33        let bound = 1.0 / (fan_in as f32).sqrt();
34        let bias = Some(init::uniform(&[out_channels], -bound, bound));
35
36        Conv1d {
37            weight,
38            bias,
39            in_channels,
40            out_channels,
41            kernel_size,
42            stride,
43            padding,
44            training: true,
45        }
46    }
47}
48
49impl Module for Conv1d {
50    fn forward(&self, input: &Tensor) -> Tensor {
51        // input: [batch, in_channels, length]
52        // weight: [out_channels, in_channels, kernel_size]
53        // output: [batch, out_channels, out_length]
54        
55        let dims = input.dims();
56        let batch = dims[0];
57        let in_len = dims[2];
58        
59        let out_len = (in_len + 2 * self.padding - self.kernel_size) / self.stride + 1;
60        
61        // Simple implementation (not optimized)
62        let input_data = input.data_f32();
63        let weight_data = self.weight.data_f32();
64        
65        let mut output = vec![0.0f32; batch * self.out_channels * out_len];
66        
67        for b in 0..batch {
68            for oc in 0..self.out_channels {
69                for ol in 0..out_len {
70                    let mut sum = 0.0f32;
71                    
72                    for ic in 0..self.in_channels {
73                        for k in 0..self.kernel_size {
74                            let il = ol * self.stride + k;
75                            let il = il as i32 - self.padding as i32;
76                            
77                            if il >= 0 && (il as usize) < in_len {
78                                let input_idx = b * self.in_channels * in_len 
79                                    + ic * in_len + il as usize;
80                                let weight_idx = oc * self.in_channels * self.kernel_size 
81                                    + ic * self.kernel_size + k;
82                                sum += input_data[input_idx] * weight_data[weight_idx];
83                            }
84                        }
85                    }
86                    
87                    let out_idx = b * self.out_channels * out_len + oc * out_len + ol;
88                    output[out_idx] = sum;
89                }
90            }
91        }
92        
93        let mut result = Tensor::from_slice(&output, &[batch, self.out_channels, out_len]).unwrap();
94        
95        if let Some(ref bias) = self.bias {
96            // Add bias (broadcast over batch and length)
97            let bias_data = bias.data_f32();
98            let mut result_data = result.data_f32();
99            
100            for b in 0..batch {
101                for oc in 0..self.out_channels {
102                    for ol in 0..out_len {
103                        let idx = b * self.out_channels * out_len + oc * out_len + ol;
104                        result_data[idx] += bias_data[oc];
105                    }
106                }
107            }
108            
109            result = Tensor::from_slice(&result_data, &[batch, self.out_channels, out_len]).unwrap();
110        }
111        
112        result
113    }
114
115    fn parameters(&self) -> Vec<Tensor> {
116        let mut params = vec![self.weight.clone()];
117        if let Some(ref bias) = self.bias {
118            params.push(bias.clone());
119        }
120        params
121    }
122
123    fn train(&mut self) { self.training = true; }
124    fn eval(&mut self) { self.training = false; }
125    fn is_training(&self) -> bool { self.training }
126}
127
/// 2D Convolution layer.
///
/// Applies a 2D convolution over an input of shape
/// `[batch, in_channels, height, width]`, producing an output of shape
/// `[batch, out_channels, out_height, out_width]`.
pub struct Conv2d {
    // Filter weights, shape [out_channels, in_channels, kH, kW].
    weight: Tensor,
    // Optional per-output-channel bias, shape [out_channels].
    bias: Option<Tensor>,
    in_channels: usize,
    out_channels: usize,
    // (kH, kW), (strideH, strideW), (padH, padW) — height first, width second.
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    // Training-mode flag, toggled via Module::train / Module::eval.
    training: bool,
}
139
140impl Conv2d {
141    pub fn new(
142        in_channels: usize,
143        out_channels: usize,
144        kernel_size: usize,
145        stride: usize,
146        padding: usize,
147    ) -> Self {
148        Self::with_params(
149            in_channels,
150            out_channels,
151            (kernel_size, kernel_size),
152            (stride, stride),
153            (padding, padding),
154        )
155    }
156
157    pub fn with_params(
158        in_channels: usize,
159        out_channels: usize,
160        kernel_size: (usize, usize),
161        stride: (usize, usize),
162        padding: (usize, usize),
163    ) -> Self {
164        let fan_in = in_channels * kernel_size.0 * kernel_size.1;
165        let weight = init::kaiming_uniform(
166            &[out_channels, in_channels, kernel_size.0, kernel_size.1],
167            fan_in,
168        );
169        
170        let bound = 1.0 / (fan_in as f32).sqrt();
171        let bias = Some(init::uniform(&[out_channels], -bound, bound));
172
173        Conv2d {
174            weight,
175            bias,
176            in_channels,
177            out_channels,
178            kernel_size,
179            stride,
180            padding,
181            training: true,
182        }
183    }
184}
185
impl Module for Conv2d {
    /// Runs the 2D convolution on `input` (`[batch, in_channels, H, W]`).
    ///
    /// Dispatches at compile time on the `optimized-conv` cargo feature:
    /// exactly one of the two cfg blocks below is compiled in.
    fn forward(&self, input: &Tensor) -> Tensor {
        // Use optimized convolution from ghostflow-core
        #[cfg(feature = "optimized-conv")]
        {
            use ghostflow_core::ops::conv::conv2d_optimized;
            let bias = self.bias.as_ref();
            return conv2d_optimized(input, &self.weight, bias, self.stride, self.padding).unwrap();
        }
        
        // Fallback to direct implementation
        #[cfg(not(feature = "optimized-conv"))]
        {
            self.forward_direct(input)
        }
    }

    /// Returns the learnable tensors: weight first, then bias (if present).
    fn parameters(&self) -> Vec<Tensor> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}
215
216#[allow(dead_code)]
217impl Conv2d {
218    /// Direct convolution implementation (fallback)
219    fn forward_direct(&self, input: &Tensor) -> Tensor {
220        // input: [batch, in_channels, height, width]
221        // weight: [out_channels, in_channels, kH, kW]
222        // output: [batch, out_channels, out_height, out_width]
223        
224        let dims = input.dims();
225        let batch = dims[0];
226        let in_h = dims[2];
227        let in_w = dims[3];
228        
229        let out_h = (in_h + 2 * self.padding.0 - self.kernel_size.0) / self.stride.0 + 1;
230        let out_w = (in_w + 2 * self.padding.1 - self.kernel_size.1) / self.stride.1 + 1;
231        
232        let input_data = input.data_f32();
233        let weight_data = self.weight.data_f32();
234        
235        let mut output = vec![0.0f32; batch * self.out_channels * out_h * out_w];
236        
237        // Naive convolution (im2col would be faster)
238        for b in 0..batch {
239            for oc in 0..self.out_channels {
240                for oh in 0..out_h {
241                    for ow in 0..out_w {
242                        let mut sum = 0.0f32;
243                        
244                        for ic in 0..self.in_channels {
245                            for kh in 0..self.kernel_size.0 {
246                                for kw in 0..self.kernel_size.1 {
247                                    let ih = (oh * self.stride.0 + kh) as i32 - self.padding.0 as i32;
248                                    let iw = (ow * self.stride.1 + kw) as i32 - self.padding.1 as i32;
249                                    
250                                    if ih >= 0 && (ih as usize) < in_h && iw >= 0 && (iw as usize) < in_w {
251                                        let input_idx = b * self.in_channels * in_h * in_w
252                                            + ic * in_h * in_w
253                                            + (ih as usize) * in_w
254                                            + iw as usize;
255                                        let weight_idx = oc * self.in_channels * self.kernel_size.0 * self.kernel_size.1
256                                            + ic * self.kernel_size.0 * self.kernel_size.1
257                                            + kh * self.kernel_size.1
258                                            + kw;
259                                        sum += input_data[input_idx] * weight_data[weight_idx];
260                                    }
261                                }
262                            }
263                        }
264                        
265                        let out_idx = b * self.out_channels * out_h * out_w
266                            + oc * out_h * out_w
267                            + oh * out_w
268                            + ow;
269                        output[out_idx] = sum;
270                    }
271                }
272            }
273        }
274        
275        let mut result = Tensor::from_slice(&output, &[batch, self.out_channels, out_h, out_w]).unwrap();
276        
277        if let Some(ref bias) = self.bias {
278            let bias_data = bias.data_f32();
279            let mut result_data = result.data_f32();
280            
281            for b in 0..batch {
282                for oc in 0..self.out_channels {
283                    for oh in 0..out_h {
284                        for ow in 0..out_w {
285                            let idx = b * self.out_channels * out_h * out_w
286                                + oc * out_h * out_w
287                                + oh * out_w
288                                + ow;
289                            result_data[idx] += bias_data[oc];
290                        }
291                    }
292                }
293            }
294            
295            result = Tensor::from_slice(&result_data, &[batch, self.out_channels, out_h, out_w]).unwrap();
296        }
297        
298        result
299    }
300
301    fn parameters(&self) -> Vec<Tensor> {
302        let mut params = vec![self.weight.clone()];
303        if let Some(ref bias) = self.bias {
304            params.push(bias.clone());
305        }
306        params
307    }
308
309    fn train(&mut self) { self.training = true; }
310    fn eval(&mut self) { self.training = false; }
311    fn is_training(&self) -> bool { self.training }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// A 3x3 kernel with stride 1 and padding 1 ("same" padding)
    /// must preserve the spatial dimensions.
    #[test]
    fn test_conv2d_forward() {
        let layer = Conv2d::new(3, 16, 3, 1, 1);
        let x = Tensor::randn(&[2, 3, 32, 32]);

        let y = layer.forward(&x);

        assert_eq!(y.dims(), &[2, 16, 32, 32]);
    }

    /// Stride 2 with padding 1 and a 3x3 kernel halves each spatial dimension.
    #[test]
    fn test_conv2d_stride() {
        let layer = Conv2d::new(3, 16, 3, 2, 1);
        let x = Tensor::randn(&[2, 3, 32, 32]);

        let y = layer.forward(&x);

        assert_eq!(y.dims(), &[2, 16, 16, 16]);
    }
}