ghostflow_nn/conv.rs

//! Convolutional layers

use ghostflow_core::Tensor;
use crate::module::Module;
use crate::init;

/// 1D convolution layer
pub struct Conv1d {
    weight: Tensor,
    bias: Option<Tensor>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    training: bool,
}

impl Conv1d {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
    ) -> Self {
        let fan_in = in_channels * kernel_size;
        let weight = init::kaiming_uniform(
            &[out_channels, in_channels, kernel_size],
            fan_in,
        );

        let bound = 1.0 / (fan_in as f32).sqrt();
        let bias = Some(init::uniform(&[out_channels], -bound, bound));

        Conv1d {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            training: true,
        }
    }
}
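
// Worked example of the init arithmetic above (a sketch; the exact weight
// scaling lives inside `init::kaiming_uniform`): for in_channels = 3 and
// kernel_size = 3, fan_in = 3 * 3 = 9, so the bias is drawn uniformly from
// [-1/sqrt(9), 1/sqrt(9)] = [-1/3, 1/3].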

impl Module for Conv1d {
    fn forward(&self, input: &Tensor) -> Tensor {
        // input:  [batch, in_channels, length]
        // weight: [out_channels, in_channels, kernel_size]
        // output: [batch, out_channels, out_length]

        let dims = input.dims();
        let batch = dims[0];
        let in_len = dims[2];

        let out_len = (in_len + 2 * self.padding - self.kernel_size) / self.stride + 1;
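        // e.g. in_len = 32, padding = 1, kernel_size = 3, stride = 1:
        //   out_len = (32 + 2*1 - 3) / 1 + 1 = 32 (a "same"-style convolution)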

        // Simple implementation (not optimized)
        let input_data = input.data_f32();
        let weight_data = self.weight.data_f32();

        let mut output = vec![0.0f32; batch * self.out_channels * out_len];

        for b in 0..batch {
            for oc in 0..self.out_channels {
                for ol in 0..out_len {
                    let mut sum = 0.0f32;

                    for ic in 0..self.in_channels {
                        for k in 0..self.kernel_size {
                            let il = ol * self.stride + k;
                            let il = il as i32 - self.padding as i32;

                            if il >= 0 && (il as usize) < in_len {
                                let input_idx = b * self.in_channels * in_len
                                    + ic * in_len + il as usize;
                                let weight_idx = oc * self.in_channels * self.kernel_size
                                    + ic * self.kernel_size + k;
                                sum += input_data[input_idx] * weight_data[weight_idx];
                            }
                        }
                    }

                    let out_idx = b * self.out_channels * out_len + oc * out_len + ol;
                    output[out_idx] = sum;
                }
            }
        }

        let mut result = Tensor::from_slice(&output, &[batch, self.out_channels, out_len]).unwrap();

        if let Some(ref bias) = self.bias {
            // Add bias (broadcast over batch and length)
            let bias_data = bias.data_f32();
            let mut result_data = result.data_f32();

            #[allow(clippy::needless_range_loop)]
            for b in 0..batch {
                for oc in 0..self.out_channels {
                    for ol in 0..out_len {
                        let idx = b * self.out_channels * out_len + oc * out_len + ol;
                        result_data[idx] += bias_data[oc];
                    }
                }
            }

            result = Tensor::from_slice(&result_data, &[batch, self.out_channels, out_len]).unwrap();
        }

        result
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}
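
// Hypothetical usage sketch (shapes only; `Tensor::randn` is assumed here, as
// used by the tests at the bottom of this file):
//
//     let conv = Conv1d::new(4, 8, 3, 1, 1);   // in=4, out=8, k=3, s=1, p=1
//     let input = Tensor::randn(&[2, 4, 16]);  // [batch, in_channels, length]
//     let output = conv.forward(&input);       // -> [2, 8, 16]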

/// 2D convolution layer
pub struct Conv2d {
    weight: Tensor,
    bias: Option<Tensor>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    training: bool,
}

impl Conv2d {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
    ) -> Self {
        Self::with_params(
            in_channels,
            out_channels,
            (kernel_size, kernel_size),
            (stride, stride),
            (padding, padding),
        )
    }

    pub fn with_params(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Self {
        let fan_in = in_channels * kernel_size.0 * kernel_size.1;
        let weight = init::kaiming_uniform(
            &[out_channels, in_channels, kernel_size.0, kernel_size.1],
            fan_in,
        );

        let bound = 1.0 / (fan_in as f32).sqrt();
        let bias = Some(init::uniform(&[out_channels], -bound, bound));

        Conv2d {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            training: true,
        }
    }
}

impl Module for Conv2d {
    fn forward(&self, input: &Tensor) -> Tensor {
        // Use optimized convolution from ghostflow-core
        #[cfg(feature = "optimized-conv")]
        {
            use ghostflow_core::ops::conv::conv2d_optimized;
            let bias = self.bias.as_ref();
            return conv2d_optimized(input, &self.weight, bias, self.stride, self.padding).unwrap();
        }

        // Fall back to the direct implementation
        #[cfg(not(feature = "optimized-conv"))]
        {
            self.forward_direct(input)
        }
    }
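
    // The optimized path above is assumed to be enabled from the consuming
    // crate's Cargo.toml, along the lines of:
    //
    //     [dependencies]
    //     ghostflow-nn = { version = "*", features = ["optimized-conv"] }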

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

#[allow(dead_code)]
impl Conv2d {
    /// Direct convolution implementation (fallback)
    fn forward_direct(&self, input: &Tensor) -> Tensor {
        // input:  [batch, in_channels, height, width]
        // weight: [out_channels, in_channels, kH, kW]
        // output: [batch, out_channels, out_height, out_width]

        let dims = input.dims();
        let batch = dims[0];
        let in_h = dims[2];
        let in_w = dims[3];

        let out_h = (in_h + 2 * self.padding.0 - self.kernel_size.0) / self.stride.0 + 1;
        let out_w = (in_w + 2 * self.padding.1 - self.kernel_size.1) / self.stride.1 + 1;
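        // e.g. a 32x32 input with a 3x3 kernel, stride (1, 1), padding (1, 1):
        //   out_h = (32 + 2*1 - 3) / 1 + 1 = 32, and likewise out_w = 32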

        let input_data = input.data_f32();
        let weight_data = self.weight.data_f32();

        let mut output = vec![0.0f32; batch * self.out_channels * out_h * out_w];

        // Naive convolution (im2col would be faster)
        for b in 0..batch {
            for oc in 0..self.out_channels {
                for oh in 0..out_h {
                    for ow in 0..out_w {
                        let mut sum = 0.0f32;

                        for ic in 0..self.in_channels {
                            for kh in 0..self.kernel_size.0 {
                                for kw in 0..self.kernel_size.1 {
                                    let ih = (oh * self.stride.0 + kh) as i32 - self.padding.0 as i32;
                                    let iw = (ow * self.stride.1 + kw) as i32 - self.padding.1 as i32;

                                    if ih >= 0 && (ih as usize) < in_h && iw >= 0 && (iw as usize) < in_w {
                                        let input_idx = b * self.in_channels * in_h * in_w
                                            + ic * in_h * in_w
                                            + (ih as usize) * in_w
                                            + iw as usize;
                                        let weight_idx = oc * self.in_channels * self.kernel_size.0 * self.kernel_size.1
                                            + ic * self.kernel_size.0 * self.kernel_size.1
                                            + kh * self.kernel_size.1
                                            + kw;
                                        sum += input_data[input_idx] * weight_data[weight_idx];
                                    }
                                }
                            }
                        }

                        let out_idx = b * self.out_channels * out_h * out_w
                            + oc * out_h * out_w
                            + oh * out_w
                            + ow;
                        output[out_idx] = sum;
                    }
                }
            }
        }

        let mut result = Tensor::from_slice(&output, &[batch, self.out_channels, out_h, out_w]).unwrap();

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data_f32();
            let mut result_data = result.data_f32();

            #[allow(clippy::needless_range_loop)]
            for b in 0..batch {
                for oc in 0..self.out_channels {
                    for oh in 0..out_h {
                        for ow in 0..out_w {
                            let idx = b * self.out_channels * out_h * out_w
                                + oc * out_h * out_w
                                + oh * out_w
                                + ow;
                            result_data[idx] += bias_data[oc];
                        }
                    }
                }
            }

            result = Tensor::from_slice(&result_data, &[batch, self.out_channels, out_h, out_w]).unwrap();
        }

        result
    }
}
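
// On the "im2col would be faster" note above: im2col unrolls each receptive
// field into one row of an [out_h * out_w, in_channels * kH * kW] matrix, so
// the convolution for one batch element becomes a single matrix multiply with
// the flattened [out_channels, in_channels * kH * kW] weights. This fallback
// deliberately skips that transformation for clarity.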

/// 3D convolution layer
pub struct Conv3d {
    weight: Tensor,
    bias: Option<Tensor>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize, usize),
    stride: (usize, usize, usize),
    padding: (usize, usize, usize),
    training: bool,
}

impl Conv3d {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize, usize),
        stride: (usize, usize, usize),
        padding: (usize, usize, usize),
    ) -> Self {
        let (kd, kh, kw) = kernel_size;
        let fan_in = in_channels * kd * kh * kw;
        let weight = init::kaiming_uniform(
            &[out_channels, in_channels, kd, kh, kw],
            fan_in,
        );

        let bound = 1.0 / (fan_in as f32).sqrt();
        let bias = Some(init::uniform(&[out_channels], -bound, bound));

        Conv3d {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            training: true,
        }
    }
}

impl Module for Conv3d {
    fn forward(&self, input: &Tensor) -> Tensor {
        // input:  [batch, in_channels, depth, height, width]
        // weight: [out_channels, in_channels, kd, kh, kw]
        // output: [batch, out_channels, out_depth, out_height, out_width]

        let dims = input.dims();
        let batch = dims[0];
        let in_depth = dims[2];
        let in_height = dims[3];
        let in_width = dims[4];

        let (kd, kh, kw) = self.kernel_size;
        let (sd, sh, sw) = self.stride;
        let (pd, ph, pw) = self.padding;

        let out_depth = (in_depth + 2 * pd - kd) / sd + 1;
        let out_height = (in_height + 2 * ph - kh) / sh + 1;
        let out_width = (in_width + 2 * pw - kw) / sw + 1;
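        // e.g. a [2, C, 8, 32, 32] input with a 3x3x3 kernel, stride (1, 1, 1),
        // padding (1, 1, 1) keeps every spatial size: output [2, OC, 8, 32, 32]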

        let input_data = input.data_f32();
        let weight_data = self.weight.data_f32();

        let mut output = vec![0.0f32; batch * self.out_channels * out_depth * out_height * out_width];

        for b in 0..batch {
            for oc in 0..self.out_channels {
                for od in 0..out_depth {
                    for oh in 0..out_height {
                        for ow in 0..out_width {
                            let mut sum = 0.0f32;

                            for ic in 0..self.in_channels {
                                for kd_i in 0..kd {
                                    for kh_i in 0..kh {
                                        for kw_i in 0..kw {
                                            let id = (od * sd + kd_i) as i32 - pd as i32;
                                            let ih = (oh * sh + kh_i) as i32 - ph as i32;
                                            let iw = (ow * sw + kw_i) as i32 - pw as i32;

                                            if id >= 0 && (id as usize) < in_depth &&
                                               ih >= 0 && (ih as usize) < in_height &&
                                               iw >= 0 && (iw as usize) < in_width {
                                                let input_idx = b * self.in_channels * in_depth * in_height * in_width
                                                    + ic * in_depth * in_height * in_width
                                                    + (id as usize) * in_height * in_width
                                                    + (ih as usize) * in_width
                                                    + (iw as usize);
                                                let weight_idx = oc * self.in_channels * kd * kh * kw
                                                    + ic * kd * kh * kw
                                                    + kd_i * kh * kw
                                                    + kh_i * kw
                                                    + kw_i;
                                                sum += input_data[input_idx] * weight_data[weight_idx];
                                            }
                                        }
                                    }
                                }
                            }

                            let out_idx = b * self.out_channels * out_depth * out_height * out_width
                                + oc * out_depth * out_height * out_width
                                + od * out_height * out_width
                                + oh * out_width
                                + ow;
                            output[out_idx] = sum;
                        }
                    }
                }
            }
        }

        let mut result = Tensor::from_slice(&output, &[batch, self.out_channels, out_depth, out_height, out_width]).unwrap();

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data_f32();
            let mut result_data = result.data_f32();

            for b in 0..batch {
                for oc in 0..self.out_channels {
                    for od in 0..out_depth {
                        for oh in 0..out_height {
                            for ow in 0..out_width {
                                let idx = b * self.out_channels * out_depth * out_height * out_width
                                    + oc * out_depth * out_height * out_width
                                    + od * out_height * out_width
                                    + oh * out_width
                                    + ow;
                                result_data[idx] += bias_data[oc];
                            }
                        }
                    }
                }
            }

            // Write the biased data back; without this the loop above has no effect.
            result = Tensor::from_slice(&result_data, &[batch, self.out_channels, out_depth, out_height, out_width]).unwrap();
        }

        result
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

/// Transposed 2D convolution layer (deconvolution)
pub struct TransposeConv2d {
    weight: Tensor,
    bias: Option<Tensor>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    output_padding: (usize, usize),
    training: bool,
}

impl TransposeConv2d {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
        output_padding: (usize, usize),
    ) -> Self {
        let (kh, kw) = kernel_size;
        let fan_in = in_channels * kh * kw;
        // Note: weight shape is [in_channels, out_channels, kh, kw] for transpose conv
        let weight = init::kaiming_uniform(
            &[in_channels, out_channels, kh, kw],
            fan_in,
        );

        let bound = 1.0 / (fan_in as f32).sqrt();
        let bias = Some(init::uniform(&[out_channels], -bound, bound));

        TransposeConv2d {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            training: true,
        }
    }
}

impl Module for TransposeConv2d {
    fn forward(&self, input: &Tensor) -> Tensor {
        // input:  [batch, in_channels, height, width]
        // weight: [in_channels, out_channels, kh, kw]
        // output: [batch, out_channels, out_height, out_width]

        let dims = input.dims();
        let batch = dims[0];
        let in_height = dims[2];
        let in_width = dims[3];

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;
        let (oph, opw) = self.output_padding;

        // Calculate output dimensions for transpose convolution
        let out_height = (in_height - 1) * sh - 2 * ph + kh + oph;
        let out_width = (in_width - 1) * sw - 2 * pw + kw + opw;
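        // e.g. in = 16x16, kernel 3x3, stride (2, 2), padding (1, 1),
        // output_padding (1, 1):
        //   out = (16 - 1)*2 - 2*1 + 3 + 1 = 32, i.e. exact 2x upsampling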

        let input_data = input.data_f32();
        let weight_data = self.weight.data_f32();

        let mut output = vec![0.0f32; batch * self.out_channels * out_height * out_width];

        // Scatter each input element into every output position its kernel
        // window touches (the transpose of the gather loop in Conv2d)
        for b in 0..batch {
            for ic in 0..self.in_channels {
                for ih in 0..in_height {
                    for iw in 0..in_width {
                        let input_idx = b * self.in_channels * in_height * in_width
                            + ic * in_height * in_width
                            + ih * in_width
                            + iw;
                        let input_val = input_data[input_idx];

                        for oc in 0..self.out_channels {
                            for kh_i in 0..kh {
                                for kw_i in 0..kw {
                                    let oh = (ih * sh + kh_i) as i32 - ph as i32;
                                    let ow = (iw * sw + kw_i) as i32 - pw as i32;

                                    if oh >= 0 && (oh as usize) < out_height &&
                                       ow >= 0 && (ow as usize) < out_width {
                                        let weight_idx = ic * self.out_channels * kh * kw
                                            + oc * kh * kw
                                            + kh_i * kw
                                            + kw_i;
                                        let out_idx = b * self.out_channels * out_height * out_width
                                            + oc * out_height * out_width
                                            + (oh as usize) * out_width
                                            + (ow as usize);
                                        output[out_idx] += input_val * weight_data[weight_idx];
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        let mut result = Tensor::from_slice(&output, &[batch, self.out_channels, out_height, out_width]).unwrap();

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data_f32();
            let mut result_data = result.data_f32();

            for b in 0..batch {
                for oc in 0..self.out_channels {
                    for oh in 0..out_height {
                        for ow in 0..out_width {
                            let idx = b * self.out_channels * out_height * out_width
                                + oc * out_height * out_width
                                + oh * out_width
                                + ow;
                            result_data[idx] += bias_data[oc];
                        }
                    }
                }
            }

            // Write the biased data back; without this the loop above has no effect.
            result = Tensor::from_slice(&result_data, &[batch, self.out_channels, out_height, out_width]).unwrap();
        }

        result
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = vec![self.weight.clone()];
        if let Some(ref bias) = self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conv2d_forward() {
        let conv = Conv2d::new(3, 16, 3, 1, 1);
        let input = Tensor::randn(&[2, 3, 32, 32]);
        let output = conv.forward(&input);

        assert_eq!(output.dims(), &[2, 16, 32, 32]);
    }

    #[test]
    fn test_conv2d_stride() {
        let conv = Conv2d::new(3, 16, 3, 2, 1);
        let input = Tensor::randn(&[2, 3, 32, 32]);
        let output = conv.forward(&input);

        assert_eq!(output.dims(), &[2, 16, 16, 16]);
    }
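
    // The shape checks below are sketches in the style of the Conv2d tests
    // above; they assume `Tensor::randn` also accepts 3-D and 5-D shapes.
    #[test]
    fn test_conv1d_forward() {
        let conv = Conv1d::new(4, 8, 3, 1, 1);
        let input = Tensor::randn(&[2, 4, 16]);
        let output = conv.forward(&input);

        assert_eq!(output.dims(), &[2, 8, 16]);
    }

    #[test]
    fn test_conv3d_forward() {
        let conv = Conv3d::new(3, 8, (3, 3, 3), (1, 1, 1), (1, 1, 1));
        let input = Tensor::randn(&[1, 3, 8, 16, 16]);
        let output = conv.forward(&input);

        assert_eq!(output.dims(), &[1, 8, 8, 16, 16]);
    }

    #[test]
    fn test_transpose_conv2d_upsamples() {
        // out = (16 - 1)*2 - 2*1 + 3 + 1 = 32 in each spatial dimension
        let conv = TransposeConv2d::new(8, 4, (3, 3), (2, 2), (1, 1), (1, 1));
        let input = Tensor::randn(&[2, 8, 16, 16]);
        let output = conv.forward(&input);

        assert_eq!(output.dims(), &[2, 4, 32, 32]);
    }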
}