avila_math/tensor/
conv4d.rs

//! 4D convolution for neural networks and spatio-temporal processing.
//!
//! Implements 4D convolution operations optimized for:
//! - spatio-temporal data `(x, y, z, t)`
//! - 4D convolutional neural networks
//! - 3D video + time processing
//! - 4D physics simulations
//!
//! Tensor layout: `[batch, channels, dim1, dim2, dim3, dim4]`.
10use crate::tensor::Tensor;
11use rayon::prelude::*;
12
13pub type Tensor4D = Tensor<4>;
14pub type Tensor5D = Tensor<5>;
15pub type Tensor6D = Tensor<6>;
16
/// Hyper-parameters of a 4D convolution.
///
/// Every per-axis array is ordered like the four spatial axes of the input,
/// i.e. `[d1, d2, d3, d4]`.
#[derive(Debug, Clone, Copy)]
pub struct Conv4DConfig {
    /// Step between consecutive kernel applications along each axis.
    pub stride: [usize; 4],
    /// Implicit zero padding added on both sides of each axis.
    pub padding: [usize; 4],
    /// Spacing between kernel taps along each axis (`1` means dense taps).
    pub dilation: [usize; 4],
    /// Number of channel groups for grouped convolution; input and output
    /// channels are split into `groups` independent slices.
    pub groups: usize,
}
29
30impl Default for Conv4DConfig {
31    fn default() -> Self {
32        Self {
33            stride: [1, 1, 1, 1],
34            padding: [0, 0, 0, 0],
35            dilation: [1, 1, 1, 1],
36            groups: 1,
37        }
38    }
39}
40
41impl Conv4DConfig {
42    pub fn with_stride(mut self, stride: [usize; 4]) -> Self {
43        self.stride = stride;
44        self
45    }
46
47    pub fn with_padding(mut self, padding: [usize; 4]) -> Self {
48        self.padding = padding;
49        self
50    }
51
52    pub fn with_dilation(mut self, dilation: [usize; 4]) -> Self {
53        self.dilation = dilation;
54        self
55    }
56
57    pub fn with_groups(mut self, groups: usize) -> Self {
58        self.groups = groups;
59        self
60    }
61
62    /// Calcula as dimensões de saída
63    pub fn output_size(&self, input_size: [usize; 4], kernel_size: [usize; 4]) -> [usize; 4] {
64        [
65            (input_size[0] + 2 * self.padding[0] - self.dilation[0] * (kernel_size[0] - 1) - 1)
66                / self.stride[0]
67                + 1,
68            (input_size[1] + 2 * self.padding[1] - self.dilation[1] * (kernel_size[1] - 1) - 1)
69                / self.stride[1]
70                + 1,
71            (input_size[2] + 2 * self.padding[2] - self.dilation[2] * (kernel_size[2] - 1) - 1)
72                / self.stride[2]
73                + 1,
74            (input_size[3] + 2 * self.padding[3] - self.dilation[3] * (kernel_size[3] - 1) - 1)
75                / self.stride[3]
76                + 1,
77        ]
78    }
79}
80
81/// Layer de Convolução 4D
82///
83/// Formato:
84/// - Input: [batch, in_channels, d1, d2, d3, d4]
85/// - Kernel: [out_channels, in_channels/groups, k1, k2, k3, k4]
86/// - Bias: [out_channels]
87/// - Output: [batch, out_channels, o1, o2, o3, o4]
88pub struct Conv4DLayer {
89    /// Pesos do kernel
90    pub weights: Tensor6D,
91    /// Bias (opcional)
92    pub bias: Option<Tensor<1>>,
93    /// Configuração
94    pub config: Conv4DConfig,
95}
96
97impl Conv4DLayer {
98    /// Cria nova camada de convolução 4D
99    ///
100    /// # Arguments
101    /// * `in_channels` - Número de canais de entrada
102    /// * `out_channels` - Número de canais de saída
103    /// * `kernel_size` - Tamanho do kernel [k1, k2, k3, k4]
104    /// * `config` - Configuração de convolução
105    pub fn new(
106        in_channels: usize,
107        out_channels: usize,
108        kernel_size: [usize; 4],
109        config: Conv4DConfig,
110    ) -> Self {
111        let channels_per_group = in_channels / config.groups;
112        let weights = Tensor6D::zeros([
113            out_channels,
114            channels_per_group,
115            kernel_size[0],
116            kernel_size[1],
117            kernel_size[2],
118            kernel_size[3],
119        ]);
120
121        Self {
122            weights,
123            bias: None,
124            config,
125        }
126    }
127
128    /// Adiciona bias à camada
129    pub fn with_bias(mut self, out_channels: usize) -> Self {
130        self.bias = Some(Tensor::<1>::zeros([out_channels]));
131        self
132    }
133
134    /// Inicializa pesos com Xavier/Glorot
135    pub fn init_xavier(&mut self) {
136        let fan_in = self.weights.shape[1]
137            * self.weights.shape[2]
138            * self.weights.shape[3]
139            * self.weights.shape[4]
140            * self.weights.shape[5];
141        let fan_out = self.weights.shape[0]
142            * self.weights.shape[2]
143            * self.weights.shape[3]
144            * self.weights.shape[4]
145            * self.weights.shape[5];
146        let scale = (2.0 / (fan_in + fan_out) as f64).sqrt();
147
148        // Inicializa com distribuição uniforme [-scale, scale]
149        for i in 0..self.weights.data.len() {
150            self.weights.data[i] = (rand::random::<f64>() * 2.0 - 1.0) * scale;
151        }
152    }
153
154    /// Inicializa pesos com He initialization (para ReLU)
155    pub fn init_he(&mut self) {
156        let fan_in = self.weights.shape[1]
157            * self.weights.shape[2]
158            * self.weights.shape[3]
159            * self.weights.shape[4]
160            * self.weights.shape[5];
161        let scale = (2.0 / fan_in as f64).sqrt();
162
163        for i in 0..self.weights.data.len() {
164            self.weights.data[i] = rand::random::<f64>() * scale;
165        }
166    }
167
168    /// Forward pass - aplica convolução 4D
169    pub fn forward(&self, input: &Tensor6D) -> Result<Tensor6D, String> {
170        conv4d(input, &self.weights, self.bias.as_ref(), &self.config)
171    }
172
173    /// Backward pass - calcula gradientes
174    ///
175    /// # Arguments
176    /// * `input` - Tensor de entrada original [batch, in_channels, d1, d2, d3, d4]
177    /// * `grad_output` - Gradiente da loss em relação à saída [batch, out_channels, o1, o2, o3, o4]
178    ///
179    /// # Returns
180    /// * `grad_input` - Gradiente em relação à entrada
181    /// * `grad_weights` - Gradiente em relação aos pesos
182    /// * `grad_bias` - Gradiente em relação ao bias (se existir)
183    pub fn backward(
184        &self,
185        input: &Tensor6D,
186        grad_output: &Tensor6D,
187    ) -> Result<(Tensor6D, Tensor6D, Option<Tensor<1>>), String> {
188        let batch = input.shape[0];
189        let in_channels = input.shape[1];
190        let input_size = [
191            input.shape[2],
192            input.shape[3],
193            input.shape[4],
194            input.shape[5],
195        ];
196
197        let out_channels = grad_output.shape[1];
198        let output_size = [
199            grad_output.shape[2],
200            grad_output.shape[3],
201            grad_output.shape[4],
202            grad_output.shape[5],
203        ];
204
205        let kernel_size = [
206            self.weights.shape[2],
207            self.weights.shape[3],
208            self.weights.shape[4],
209            self.weights.shape[5],
210        ];
211
212        // Gradiente em relação à entrada
213        let mut grad_input = Tensor6D::zeros(input.shape);
214
215        // Gradiente em relação aos pesos
216        let mut grad_weights = Tensor6D::zeros(self.weights.shape);
217
218        // Gradiente em relação ao bias
219        let grad_bias = if self.bias.is_some() {
220            Some(Tensor::<1>::zeros([out_channels]))
221        } else {
222            None
223        };
224
225        let channels_per_group = in_channels / self.config.groups;
226
227        // Calcula gradientes em paralelo por batch
228        let grad_results: Vec<_> = (0..batch)
229            .into_par_iter()
230            .map(|b| {
231                conv4d_backward_single_batch(
232                    input,
233                    &self.weights,
234                    grad_output,
235                    b,
236                    in_channels,
237                    out_channels,
238                    &input_size,
239                    &kernel_size,
240                    &output_size,
241                    channels_per_group,
242                    &self.config,
243                )
244            })
245            .collect();
246
247        // Acumula gradientes de todos os batches
248        for (b, (grad_in_batch, grad_w_batch)) in grad_results.into_iter().enumerate() {
249            // Acumula grad_input
250            for ic in 0..in_channels {
251                for i1 in 0..input_size[0] {
252                    for i2 in 0..input_size[1] {
253                        for i3 in 0..input_size[2] {
254                            for i4 in 0..input_size[3] {
255                                let idx = ic
256                                    * input_size[0]
257                                    * input_size[1]
258                                    * input_size[2]
259                                    * input_size[3]
260                                    + i1 * input_size[1] * input_size[2] * input_size[3]
261                                    + i2 * input_size[2] * input_size[3]
262                                    + i3 * input_size[3]
263                                    + i4;
264                                let current = grad_input.get([b, ic, i1, i2, i3, i4]).unwrap();
265                                grad_input
266                                    .set([b, ic, i1, i2, i3, i4], current + grad_in_batch[idx])
267                                    .unwrap();
268                            }
269                        }
270                    }
271                }
272            }
273
274            // Acumula grad_weights
275            for oc in 0..out_channels {
276                for ic in 0..channels_per_group {
277                    for k1 in 0..kernel_size[0] {
278                        for k2 in 0..kernel_size[1] {
279                            for k3 in 0..kernel_size[2] {
280                                for k4 in 0..kernel_size[3] {
281                                    let idx = oc
282                                        * channels_per_group
283                                        * kernel_size[0]
284                                        * kernel_size[1]
285                                        * kernel_size[2]
286                                        * kernel_size[3]
287                                        + ic * kernel_size[0]
288                                            * kernel_size[1]
289                                            * kernel_size[2]
290                                            * kernel_size[3]
291                                        + k1 * kernel_size[1] * kernel_size[2] * kernel_size[3]
292                                        + k2 * kernel_size[2] * kernel_size[3]
293                                        + k3 * kernel_size[3]
294                                        + k4;
295                                    let current =
296                                        grad_weights.get([oc, ic, k1, k2, k3, k4]).unwrap();
297                                    grad_weights
298                                        .set([oc, ic, k1, k2, k3, k4], current + grad_w_batch[idx])
299                                        .unwrap();
300                                }
301                            }
302                        }
303                    }
304                }
305            }
306        }
307
308        // Calcula gradiente do bias (soma sobre batch e dimensões espaciais)
309        if let Some(gb) = grad_bias.as_ref() {
310            let mut new_grad_bias = gb.clone();
311            for oc in 0..out_channels {
312                let mut sum = 0.0;
313                for b in 0..batch {
314                    for o1 in 0..output_size[0] {
315                        for o2 in 0..output_size[1] {
316                            for o3 in 0..output_size[2] {
317                                for o4 in 0..output_size[3] {
318                                    sum += grad_output.get([b, oc, o1, o2, o3, o4]).unwrap();
319                                }
320                            }
321                        }
322                    }
323                }
324                new_grad_bias.set([oc], sum).unwrap();
325            }
326            return Ok((grad_input, grad_weights, Some(new_grad_bias)));
327        }
328
329        Ok((grad_input, grad_weights, grad_bias))
330    }
331}
332
333/// Convolução 4D completa com suporte a batch e múltiplos canais
334///
335/// # Arguments
336/// * `input` - Tensor de entrada [batch, in_channels, d1, d2, d3, d4]
337/// * `kernel` - Kernel [out_channels, in_channels/groups, k1, k2, k3, k4]
338/// * `bias` - Bias opcional [out_channels]
339/// * `config` - Configuração de convolução
340pub fn conv4d(
341    input: &Tensor6D,
342    kernel: &Tensor6D,
343    bias: Option<&Tensor<1>>,
344    config: &Conv4DConfig,
345) -> Result<Tensor6D, String> {
346    // Validações
347    let batch = input.shape[0];
348    let in_channels = input.shape[1];
349    let input_size = [
350        input.shape[2],
351        input.shape[3],
352        input.shape[4],
353        input.shape[5],
354    ];
355
356    let out_channels = kernel.shape[0];
357    let kernel_channels = kernel.shape[1];
358    let kernel_size = [
359        kernel.shape[2],
360        kernel.shape[3],
361        kernel.shape[4],
362        kernel.shape[5],
363    ];
364
365    if !in_channels.is_multiple_of(config.groups) {
366        return Err("in_channels deve ser divisível por groups".to_string());
367    }
368
369    if !out_channels.is_multiple_of(config.groups) {
370        return Err("out_channels deve ser divisível por groups".to_string());
371    }
372
373    if kernel_channels != in_channels / config.groups {
374        return Err(format!(
375            "kernel channels ({}) deve ser igual a in_channels/groups ({})",
376            kernel_channels,
377            in_channels / config.groups
378        ));
379    }
380
381    // Calcula dimensões de saída
382    let output_size = config.output_size(input_size, kernel_size);
383
384    // Cria tensor de saída
385    let mut output = Tensor6D::zeros([
386        batch,
387        out_channels,
388        output_size[0],
389        output_size[1],
390        output_size[2],
391        output_size[3],
392    ]);
393
394    // Convolução paralela por batch
395    let results: Vec<_> = (0..batch)
396        .into_par_iter()
397        .map(|b| {
398            conv4d_single_batch(
399                input,
400                kernel,
401                b,
402                in_channels,
403                out_channels,
404                &input_size,
405                &kernel_size,
406                &output_size,
407                config,
408            )
409        })
410        .collect();
411
412    // Copia resultados para o tensor de saída
413    for (b, batch_data) in results.into_iter().enumerate() {
414        for oc in 0..out_channels {
415            for o1 in 0..output_size[0] {
416                for o2 in 0..output_size[1] {
417                    for o3 in 0..output_size[2] {
418                        for o4 in 0..output_size[3] {
419                            let idx = oc
420                                * output_size[0]
421                                * output_size[1]
422                                * output_size[2]
423                                * output_size[3]
424                                + o1 * output_size[1] * output_size[2] * output_size[3]
425                                + o2 * output_size[2] * output_size[3]
426                                + o3 * output_size[3]
427                                + o4;
428                            output
429                                .set(
430                                    [b, oc, o1, o2, o3, o4],
431                                    batch_data[idx] + bias.map_or(0.0, |b| b.get([oc]).unwrap()),
432                                )
433                                .unwrap();
434                        }
435                    }
436                }
437            }
438        }
439    }
440
441    Ok(output)
442}
443
444/// Convolução 4D para um único item do batch (paralelizável)
445#[allow(clippy::too_many_arguments)]
446fn conv4d_single_batch(
447    input: &Tensor6D,
448    kernel: &Tensor6D,
449    batch_idx: usize,
450    in_channels: usize,
451    out_channels: usize,
452    input_size: &[usize; 4],
453    kernel_size: &[usize; 4],
454    output_size: &[usize; 4],
455    config: &Conv4DConfig,
456) -> Vec<f64> {
457    let mut result =
458        vec![0.0; out_channels * output_size[0] * output_size[1] * output_size[2] * output_size[3]];
459
460    let channels_per_group = in_channels / config.groups;
461
462    // Para cada canal de saída
463    for oc in 0..out_channels {
464        let group = oc / (out_channels / config.groups);
465        let group_start = group * channels_per_group;
466        let group_end = group_start + channels_per_group;
467
468        // Para cada posição de saída
469        for o1 in 0..output_size[0] {
470            for o2 in 0..output_size[1] {
471                for o3 in 0..output_size[2] {
472                    for o4 in 0..output_size[3] {
473                        let mut sum = 0.0;
474
475                        // Para cada canal de entrada no grupo
476                        for ic in group_start..group_end {
477                            // Para cada posição do kernel
478                            for k1 in 0..kernel_size[0] {
479                                for k2 in 0..kernel_size[1] {
480                                    for k3 in 0..kernel_size[2] {
481                                        for k4 in 0..kernel_size[3] {
482                                            // Calcula posição na entrada com stride e dilation
483                                            let i1 =
484                                                o1 * config.stride[0] + k1 * config.dilation[0];
485                                            let i2 =
486                                                o2 * config.stride[1] + k2 * config.dilation[1];
487                                            let i3 =
488                                                o3 * config.stride[2] + k3 * config.dilation[2];
489                                            let i4 =
490                                                o4 * config.stride[3] + k4 * config.dilation[3];
491
492                                            // Aplica padding (assume zero-padding)
493                                            if i1 >= config.padding[0]
494                                                && i2 >= config.padding[1]
495                                                && i3 >= config.padding[2]
496                                                && i4 >= config.padding[3]
497                                            {
498                                                let i1 = i1 - config.padding[0];
499                                                let i2 = i2 - config.padding[1];
500                                                let i3 = i3 - config.padding[2];
501                                                let i4 = i4 - config.padding[3];
502
503                                                if i1 < input_size[0]
504                                                    && i2 < input_size[1]
505                                                    && i3 < input_size[2]
506                                                    && i4 < input_size[3]
507                                                {
508                                                    let input_val = input
509                                                        .get([batch_idx, ic, i1, i2, i3, i4])
510                                                        .unwrap();
511                                                    let kernel_val = kernel
512                                                        .get([oc, ic - group_start, k1, k2, k3, k4])
513                                                        .unwrap();
514                                                    sum += input_val * kernel_val;
515                                                }
516                                            }
517                                        }
518                                    }
519                                }
520                            }
521                        }
522
523                        let idx =
524                            oc * output_size[0] * output_size[1] * output_size[2] * output_size[3]
525                                + o1 * output_size[1] * output_size[2] * output_size[3]
526                                + o2 * output_size[2] * output_size[3]
527                                + o3 * output_size[3]
528                                + o4;
529                        result[idx] = sum;
530                    }
531                }
532            }
533        }
534    }
535
536    result
537}
538
539/// Backward pass para um único item do batch (paralelizável)
540#[allow(clippy::too_many_arguments)]
541fn conv4d_backward_single_batch(
542    input: &Tensor6D,
543    weights: &Tensor6D,
544    grad_output: &Tensor6D,
545    batch_idx: usize,
546    in_channels: usize,
547    out_channels: usize,
548    input_size: &[usize; 4],
549    kernel_size: &[usize; 4],
550    output_size: &[usize; 4],
551    channels_per_group: usize,
552    config: &Conv4DConfig,
553) -> (Vec<f64>, Vec<f64>) {
554    let grad_input_size =
555        in_channels * input_size[0] * input_size[1] * input_size[2] * input_size[3];
556    let grad_weights_size = out_channels
557        * channels_per_group
558        * kernel_size[0]
559        * kernel_size[1]
560        * kernel_size[2]
561        * kernel_size[3];
562
563    let mut grad_input = vec![0.0; grad_input_size];
564    let mut grad_weights = vec![0.0; grad_weights_size];
565
566    // Para cada canal de saída
567    for oc in 0..out_channels {
568        let group = oc / (out_channels / config.groups);
569        let group_start = group * channels_per_group;
570        let group_end = group_start + channels_per_group;
571
572        // Para cada posição de saída
573        for o1 in 0..output_size[0] {
574            for o2 in 0..output_size[1] {
575                for o3 in 0..output_size[2] {
576                    for o4 in 0..output_size[3] {
577                        let grad_out_val =
578                            grad_output.get([batch_idx, oc, o1, o2, o3, o4]).unwrap();
579
580                        // Para cada canal de entrada no grupo
581                        for ic in group_start..group_end {
582                            // Para cada posição do kernel
583                            for k1 in 0..kernel_size[0] {
584                                for k2 in 0..kernel_size[1] {
585                                    for k3 in 0..kernel_size[2] {
586                                        for k4 in 0..kernel_size[3] {
587                                            // Calcula posição na entrada
588                                            let i1 =
589                                                o1 * config.stride[0] + k1 * config.dilation[0];
590                                            let i2 =
591                                                o2 * config.stride[1] + k2 * config.dilation[1];
592                                            let i3 =
593                                                o3 * config.stride[2] + k3 * config.dilation[2];
594                                            let i4 =
595                                                o4 * config.stride[3] + k4 * config.dilation[3];
596
597                                            // Verifica bounds com padding
598                                            if i1 >= config.padding[0]
599                                                && i2 >= config.padding[1]
600                                                && i3 >= config.padding[2]
601                                                && i4 >= config.padding[3]
602                                            {
603                                                let i1 = i1 - config.padding[0];
604                                                let i2 = i2 - config.padding[1];
605                                                let i3 = i3 - config.padding[2];
606                                                let i4 = i4 - config.padding[3];
607
608                                                if i1 < input_size[0]
609                                                    && i2 < input_size[1]
610                                                    && i3 < input_size[2]
611                                                    && i4 < input_size[3]
612                                                {
613                                                    // Gradiente em relação à entrada
614                                                    let weight_val = weights
615                                                        .get([oc, ic - group_start, k1, k2, k3, k4])
616                                                        .unwrap();
617                                                    let grad_in_idx = ic
618                                                        * input_size[0]
619                                                        * input_size[1]
620                                                        * input_size[2]
621                                                        * input_size[3]
622                                                        + i1 * input_size[1]
623                                                            * input_size[2]
624                                                            * input_size[3]
625                                                        + i2 * input_size[2] * input_size[3]
626                                                        + i3 * input_size[3]
627                                                        + i4;
628                                                    grad_input[grad_in_idx] +=
629                                                        grad_out_val * weight_val;
630
631                                                    // Gradiente em relação aos pesos
632                                                    let input_val = input
633                                                        .get([batch_idx, ic, i1, i2, i3, i4])
634                                                        .unwrap();
635                                                    let grad_w_idx = oc
636                                                        * channels_per_group
637                                                        * kernel_size[0]
638                                                        * kernel_size[1]
639                                                        * kernel_size[2]
640                                                        * kernel_size[3]
641                                                        + (ic - group_start)
642                                                            * kernel_size[0]
643                                                            * kernel_size[1]
644                                                            * kernel_size[2]
645                                                            * kernel_size[3]
646                                                        + k1 * kernel_size[1]
647                                                            * kernel_size[2]
648                                                            * kernel_size[3]
649                                                        + k2 * kernel_size[2] * kernel_size[3]
650                                                        + k3 * kernel_size[3]
651                                                        + k4;
652                                                    grad_weights[grad_w_idx] +=
653                                                        grad_out_val * input_val;
654                                                }
655                                            }
656                                        }
657                                    }
658                                }
659                            }
660                        }
661                    }
662                }
663            }
664        }
665    }
666
667    (grad_input, grad_weights)
668}
669
670/// Max Pooling 4D
671///
672/// # Arguments
673/// * `input` - Tensor de entrada [batch, channels, d1, d2, d3, d4]
674/// * `kernel_size` - Tamanho do kernel [k1, k2, k3, k4]
675/// * `stride` - Stride [s1, s2, s3, s4] (se None, usa kernel_size)
676pub fn max_pool4d(
677    input: &Tensor6D,
678    kernel_size: [usize; 4],
679    stride: Option<[usize; 4]>,
680) -> Result<Tensor6D, String> {
681    let stride = stride.unwrap_or(kernel_size);
682
683    let batch = input.shape[0];
684    let channels = input.shape[1];
685    let input_size = [
686        input.shape[2],
687        input.shape[3],
688        input.shape[4],
689        input.shape[5],
690    ];
691
692    let output_size = [
693        (input_size[0] - kernel_size[0]) / stride[0] + 1,
694        (input_size[1] - kernel_size[1]) / stride[1] + 1,
695        (input_size[2] - kernel_size[2]) / stride[2] + 1,
696        (input_size[3] - kernel_size[3]) / stride[3] + 1,
697    ];
698
699    let mut output = Tensor6D::zeros([
700        batch,
701        channels,
702        output_size[0],
703        output_size[1],
704        output_size[2],
705        output_size[3],
706    ]);
707
708    for b in 0..batch {
709        for c in 0..channels {
710            for o1 in 0..output_size[0] {
711                for o2 in 0..output_size[1] {
712                    for o3 in 0..output_size[2] {
713                        for o4 in 0..output_size[3] {
714                            let mut max_val = f64::NEG_INFINITY;
715
716                            for k1 in 0..kernel_size[0] {
717                                for k2 in 0..kernel_size[1] {
718                                    for k3 in 0..kernel_size[2] {
719                                        for k4 in 0..kernel_size[3] {
720                                            let i1 = o1 * stride[0] + k1;
721                                            let i2 = o2 * stride[1] + k2;
722                                            let i3 = o3 * stride[2] + k3;
723                                            let i4 = o4 * stride[3] + k4;
724
725                                            let val = input.get([b, c, i1, i2, i3, i4]).unwrap();
726                                            if val > max_val {
727                                                max_val = val;
728                                            }
729                                        }
730                                    }
731                                }
732                            }
733
734                            output.set([b, c, o1, o2, o3, o4], max_val).unwrap();
735                        }
736                    }
737                }
738            }
739        }
740    }
741
742    Ok(output)
743}
744
745/// Average Pooling 4D
746pub fn avg_pool4d(
747    input: &Tensor6D,
748    kernel_size: [usize; 4],
749    stride: Option<[usize; 4]>,
750) -> Result<Tensor6D, String> {
751    let stride = stride.unwrap_or(kernel_size);
752
753    let batch = input.shape[0];
754    let channels = input.shape[1];
755    let input_size = [
756        input.shape[2],
757        input.shape[3],
758        input.shape[4],
759        input.shape[5],
760    ];
761
762    let output_size = [
763        (input_size[0] - kernel_size[0]) / stride[0] + 1,
764        (input_size[1] - kernel_size[1]) / stride[1] + 1,
765        (input_size[2] - kernel_size[2]) / stride[2] + 1,
766        (input_size[3] - kernel_size[3]) / stride[3] + 1,
767    ];
768
769    let mut output = Tensor6D::zeros([
770        batch,
771        channels,
772        output_size[0],
773        output_size[1],
774        output_size[2],
775        output_size[3],
776    ]);
777
778    let kernel_vol = (kernel_size[0] * kernel_size[1] * kernel_size[2] * kernel_size[3]) as f64;
779
780    for b in 0..batch {
781        for c in 0..channels {
782            for o1 in 0..output_size[0] {
783                for o2 in 0..output_size[1] {
784                    for o3 in 0..output_size[2] {
785                        for o4 in 0..output_size[3] {
786                            let mut sum = 0.0;
787
788                            for k1 in 0..kernel_size[0] {
789                                for k2 in 0..kernel_size[1] {
790                                    for k3 in 0..kernel_size[2] {
791                                        for k4 in 0..kernel_size[3] {
792                                            let i1 = o1 * stride[0] + k1;
793                                            let i2 = o2 * stride[1] + k2;
794                                            let i3 = o3 * stride[2] + k3;
795                                            let i4 = o4 * stride[3] + k4;
796
797                                            sum += input.get([b, c, i1, i2, i3, i4]).unwrap();
798                                        }
799                                    }
800                                }
801                            }
802
803                            output
804                                .set([b, c, o1, o2, o3, o4], sum / kernel_vol)
805                                .unwrap();
806                        }
807                    }
808                }
809            }
810        }
811    }
812
813    Ok(output)
814}
815
#[cfg(test)]
mod tests {
    use super::*;

    /// Stride 2 and padding 1 on a 10^4 input with a 3^4 kernel yields 5 per axis.
    #[test]
    fn test_conv4d_config() {
        let cfg = Conv4DConfig::default()
            .with_stride([2; 4])
            .with_padding([1; 4]);

        // (10 + 2*1 - 1*(3 - 1) - 1) / 2 + 1 = 9 / 2 + 1 = 5
        let out = cfg.output_size([10; 4], [3; 4]);
        assert_eq!(out, [5; 4]);
    }

    /// A new layer stores weights as [out, in, k1, k2, k3, k4] and an optional bias.
    #[test]
    fn test_conv4d_layer_creation() {
        let layer = Conv4DLayer::new(8, 16, [3; 4], Conv4DConfig::default()).with_bias(16);

        assert_eq!(layer.weights.shape, [16, 8, 3, 3, 3, 3]);
        let bias = layer
            .bias
            .as_ref()
            .expect("with_bias should populate the bias tensor");
        assert_eq!(bias.shape, [16]);
    }

    /// Unpadded, unit-stride convolution shrinks each spatial axis by kernel - 1.
    #[test]
    fn test_conv4d_simple() {
        // 1 sample, 2 input channels, 4x4x4x4 volume.
        let input = Tensor6D::zeros([1, 2, 4, 4, 4, 4]);
        // 3 output channels, 2 input channels, 2x2x2x2 kernel.
        let kernel = Tensor6D::zeros([3, 2, 2, 2, 2, 2]);

        let output = conv4d(&input, &kernel, None, &Conv4DConfig::default())
            .expect("conv4d on compatible shapes should succeed");

        // No padding, stride 1: 4 - 2 + 1 = 3 along every spatial axis.
        assert_eq!(output.shape, [1, 3, 3, 3, 3, 3]);
    }

    /// A single planted peak must show up in the pooled cell that covers it.
    #[test]
    fn test_max_pool4d() {
        let mut input = Tensor6D::zeros([1, 1, 4, 4, 4, 4]);
        // Place the peak inside the first 2x2x2x2 window.
        input.set([0, 0, 1, 1, 1, 1], 10.0).unwrap();

        let pooled = max_pool4d(&input, [2; 4], None).expect("max_pool4d should succeed");
        assert_eq!(pooled.shape, [1, 1, 2, 2, 2, 2]);

        assert_eq!(pooled.get([0, 0, 0, 0, 0, 0]).unwrap(), 10.0);
    }

    /// Averaging a window filled with a constant returns that constant.
    #[test]
    fn test_avg_pool4d() {
        let mut input = Tensor6D::zeros([1, 1, 4, 4, 4, 4]);

        // Visit all 16 corners of the first 2x2x2x2 window by decoding the bits of `idx`.
        for idx in 0..16 {
            let (i, j, k, l) = ((idx >> 3) & 1, (idx >> 2) & 1, (idx >> 1) & 1, idx & 1);
            input.set([0, 0, i, j, k, l], 2.0).unwrap();
        }

        let avg = avg_pool4d(&input, [2; 4], None)
            .expect("avg_pool4d should succeed")
            .get([0, 0, 0, 0, 0, 0])
            .unwrap();

        // The mean of 16 copies of 2.0 is 2.0.
        assert!((avg - 2.0).abs() < 1e-10);
    }

    /// Two groups: 4 input channels split 2 per group into 4 output channels.
    #[test]
    fn test_grouped_convolution() {
        let input = Tensor6D::zeros([1, 4, 4, 4, 4, 4]);
        let kernel = Tensor6D::zeros([4, 2, 2, 2, 2, 2]);

        let cfg = Conv4DConfig::default().with_groups(2);
        assert!(conv4d(&input, &kernel, None, &cfg).is_ok());
    }
}
916
/// Backward pass returns correctly-shaped, non-trivial input and weight gradients.
#[test]
fn test_conv4d_backward_pass() {
    // Ramp input so every position contributes a distinct value.
    let mut input = Tensor6D::zeros([1, 2, 4, 4, 4, 4]);
    for (i, slot) in input.data.iter_mut().enumerate() {
        *slot = (i as f64) * 0.01;
    }

    let mut layer = Conv4DLayer::new(2, 3, [2; 4], Conv4DConfig::default());
    layer.init_xavier();

    let output = layer.forward(&input).unwrap();
    assert_eq!(output.shape, [1, 3, 3, 3, 3, 3]);

    // Uniform upstream gradient.
    let grad_output = Tensor6D::filled(output.shape, 0.1);

    let (grad_input, grad_weights, _) = layer
        .backward(&input, &grad_output)
        .expect("backward pass should succeed");

    assert_eq!(grad_input.shape, input.shape);
    assert_eq!(grad_weights.shape, layer.weights.shape);

    let total = |t: &Tensor6D| -> f64 { t.data.iter().sum() };
    assert!(total(&grad_input).abs() > 1e-10);
    assert!(total(&grad_weights).abs() > 1e-10);
}
947
/// A layer built with a bias yields a bias gradient with one entry per output channel.
#[test]
fn test_conv4d_backward_with_bias() {
    let mut layer = Conv4DLayer::new(1, 2, [2; 4], Conv4DConfig::default()).with_bias(2);
    layer.init_he();

    let input = Tensor6D::zeros([1, 1, 3, 3, 3, 3]);
    let output = layer.forward(&input).unwrap();
    let grad_output = Tensor6D::filled(output.shape, 0.5);

    let (_, _, grad_bias) = layer
        .backward(&input, &grad_output)
        .expect("backward with bias should succeed");

    let grad_bias = grad_bias.expect("bias gradient must be present when the layer has a bias");
    assert_eq!(grad_bias.shape, [2]);
}