1use crate::tensor::Tensor;
11use rayon::prelude::*;
12
/// Rank-4 tensor alias.
pub type Tensor4D = Tensor<4>;
/// Rank-5 tensor alias.
pub type Tensor5D = Tensor<5>;
/// Rank-6 tensor alias; used throughout this module as
/// `[batch, channels, d1, d2, d3, d4]`.
pub type Tensor6D = Tensor<6>;
16
/// Hyper-parameters for a 4-D convolution (one entry per spatial dimension).
#[derive(Debug, Clone, Copy)]
pub struct Conv4DConfig {
    /// Step between consecutive output positions, per dimension.
    pub stride: [usize; 4],
    /// Zero-padding added on both sides of each dimension.
    pub padding: [usize; 4],
    /// Spacing between kernel taps, per dimension (1 = dense kernel).
    pub dilation: [usize; 4],
    /// Number of channel groups (grouped convolution); 1 = ordinary convolution.
    pub groups: usize,
}
29
30impl Default for Conv4DConfig {
31 fn default() -> Self {
32 Self {
33 stride: [1, 1, 1, 1],
34 padding: [0, 0, 0, 0],
35 dilation: [1, 1, 1, 1],
36 groups: 1,
37 }
38 }
39}
40
41impl Conv4DConfig {
42 pub fn with_stride(mut self, stride: [usize; 4]) -> Self {
43 self.stride = stride;
44 self
45 }
46
47 pub fn with_padding(mut self, padding: [usize; 4]) -> Self {
48 self.padding = padding;
49 self
50 }
51
52 pub fn with_dilation(mut self, dilation: [usize; 4]) -> Self {
53 self.dilation = dilation;
54 self
55 }
56
57 pub fn with_groups(mut self, groups: usize) -> Self {
58 self.groups = groups;
59 self
60 }
61
62 pub fn output_size(&self, input_size: [usize; 4], kernel_size: [usize; 4]) -> [usize; 4] {
64 [
65 (input_size[0] + 2 * self.padding[0] - self.dilation[0] * (kernel_size[0] - 1) - 1)
66 / self.stride[0]
67 + 1,
68 (input_size[1] + 2 * self.padding[1] - self.dilation[1] * (kernel_size[1] - 1) - 1)
69 / self.stride[1]
70 + 1,
71 (input_size[2] + 2 * self.padding[2] - self.dilation[2] * (kernel_size[2] - 1) - 1)
72 / self.stride[2]
73 + 1,
74 (input_size[3] + 2 * self.padding[3] - self.dilation[3] * (kernel_size[3] - 1) - 1)
75 / self.stride[3]
76 + 1,
77 ]
78 }
79}
80
/// A 4-D convolutional layer: learnable filter bank, optional bias, and the
/// convolution hyper-parameters.
pub struct Conv4DLayer {
    /// Filter bank; `new` allocates it as
    /// `[out_channels, in_channels / groups, k1, k2, k3, k4]`.
    pub weights: Tensor6D,
    /// Optional per-output-channel bias, shape `[out_channels]`.
    pub bias: Option<Tensor<1>>,
    /// Stride / padding / dilation / groups configuration.
    pub config: Conv4DConfig,
}
96
97impl Conv4DLayer {
98 pub fn new(
106 in_channels: usize,
107 out_channels: usize,
108 kernel_size: [usize; 4],
109 config: Conv4DConfig,
110 ) -> Self {
111 let channels_per_group = in_channels / config.groups;
112 let weights = Tensor6D::zeros([
113 out_channels,
114 channels_per_group,
115 kernel_size[0],
116 kernel_size[1],
117 kernel_size[2],
118 kernel_size[3],
119 ]);
120
121 Self {
122 weights,
123 bias: None,
124 config,
125 }
126 }
127
128 pub fn with_bias(mut self, out_channels: usize) -> Self {
130 self.bias = Some(Tensor::<1>::zeros([out_channels]));
131 self
132 }
133
134 pub fn init_xavier(&mut self) {
136 let fan_in = self.weights.shape[1]
137 * self.weights.shape[2]
138 * self.weights.shape[3]
139 * self.weights.shape[4]
140 * self.weights.shape[5];
141 let fan_out = self.weights.shape[0]
142 * self.weights.shape[2]
143 * self.weights.shape[3]
144 * self.weights.shape[4]
145 * self.weights.shape[5];
146 let scale = (2.0 / (fan_in + fan_out) as f64).sqrt();
147
148 for i in 0..self.weights.data.len() {
150 self.weights.data[i] = (rand::random::<f64>() * 2.0 - 1.0) * scale;
151 }
152 }
153
154 pub fn init_he(&mut self) {
156 let fan_in = self.weights.shape[1]
157 * self.weights.shape[2]
158 * self.weights.shape[3]
159 * self.weights.shape[4]
160 * self.weights.shape[5];
161 let scale = (2.0 / fan_in as f64).sqrt();
162
163 for i in 0..self.weights.data.len() {
164 self.weights.data[i] = rand::random::<f64>() * scale;
165 }
166 }
167
168 pub fn forward(&self, input: &Tensor6D) -> Result<Tensor6D, String> {
170 conv4d(input, &self.weights, self.bias.as_ref(), &self.config)
171 }
172
173 pub fn backward(
184 &self,
185 input: &Tensor6D,
186 grad_output: &Tensor6D,
187 ) -> Result<(Tensor6D, Tensor6D, Option<Tensor<1>>), String> {
188 let batch = input.shape[0];
189 let in_channels = input.shape[1];
190 let input_size = [
191 input.shape[2],
192 input.shape[3],
193 input.shape[4],
194 input.shape[5],
195 ];
196
197 let out_channels = grad_output.shape[1];
198 let output_size = [
199 grad_output.shape[2],
200 grad_output.shape[3],
201 grad_output.shape[4],
202 grad_output.shape[5],
203 ];
204
205 let kernel_size = [
206 self.weights.shape[2],
207 self.weights.shape[3],
208 self.weights.shape[4],
209 self.weights.shape[5],
210 ];
211
212 let mut grad_input = Tensor6D::zeros(input.shape);
214
215 let mut grad_weights = Tensor6D::zeros(self.weights.shape);
217
218 let grad_bias = if self.bias.is_some() {
220 Some(Tensor::<1>::zeros([out_channels]))
221 } else {
222 None
223 };
224
225 let channels_per_group = in_channels / self.config.groups;
226
227 let grad_results: Vec<_> = (0..batch)
229 .into_par_iter()
230 .map(|b| {
231 conv4d_backward_single_batch(
232 input,
233 &self.weights,
234 grad_output,
235 b,
236 in_channels,
237 out_channels,
238 &input_size,
239 &kernel_size,
240 &output_size,
241 channels_per_group,
242 &self.config,
243 )
244 })
245 .collect();
246
247 for (b, (grad_in_batch, grad_w_batch)) in grad_results.into_iter().enumerate() {
249 for ic in 0..in_channels {
251 for i1 in 0..input_size[0] {
252 for i2 in 0..input_size[1] {
253 for i3 in 0..input_size[2] {
254 for i4 in 0..input_size[3] {
255 let idx = ic
256 * input_size[0]
257 * input_size[1]
258 * input_size[2]
259 * input_size[3]
260 + i1 * input_size[1] * input_size[2] * input_size[3]
261 + i2 * input_size[2] * input_size[3]
262 + i3 * input_size[3]
263 + i4;
264 let current = grad_input.get([b, ic, i1, i2, i3, i4]).unwrap();
265 grad_input
266 .set([b, ic, i1, i2, i3, i4], current + grad_in_batch[idx])
267 .unwrap();
268 }
269 }
270 }
271 }
272 }
273
274 for oc in 0..out_channels {
276 for ic in 0..channels_per_group {
277 for k1 in 0..kernel_size[0] {
278 for k2 in 0..kernel_size[1] {
279 for k3 in 0..kernel_size[2] {
280 for k4 in 0..kernel_size[3] {
281 let idx = oc
282 * channels_per_group
283 * kernel_size[0]
284 * kernel_size[1]
285 * kernel_size[2]
286 * kernel_size[3]
287 + ic * kernel_size[0]
288 * kernel_size[1]
289 * kernel_size[2]
290 * kernel_size[3]
291 + k1 * kernel_size[1] * kernel_size[2] * kernel_size[3]
292 + k2 * kernel_size[2] * kernel_size[3]
293 + k3 * kernel_size[3]
294 + k4;
295 let current =
296 grad_weights.get([oc, ic, k1, k2, k3, k4]).unwrap();
297 grad_weights
298 .set([oc, ic, k1, k2, k3, k4], current + grad_w_batch[idx])
299 .unwrap();
300 }
301 }
302 }
303 }
304 }
305 }
306 }
307
308 if let Some(gb) = grad_bias.as_ref() {
310 let mut new_grad_bias = gb.clone();
311 for oc in 0..out_channels {
312 let mut sum = 0.0;
313 for b in 0..batch {
314 for o1 in 0..output_size[0] {
315 for o2 in 0..output_size[1] {
316 for o3 in 0..output_size[2] {
317 for o4 in 0..output_size[3] {
318 sum += grad_output.get([b, oc, o1, o2, o3, o4]).unwrap();
319 }
320 }
321 }
322 }
323 }
324 new_grad_bias.set([oc], sum).unwrap();
325 }
326 return Ok((grad_input, grad_weights, Some(new_grad_bias)));
327 }
328
329 Ok((grad_input, grad_weights, grad_bias))
330 }
331}
332
333pub fn conv4d(
341 input: &Tensor6D,
342 kernel: &Tensor6D,
343 bias: Option<&Tensor<1>>,
344 config: &Conv4DConfig,
345) -> Result<Tensor6D, String> {
346 let batch = input.shape[0];
348 let in_channels = input.shape[1];
349 let input_size = [
350 input.shape[2],
351 input.shape[3],
352 input.shape[4],
353 input.shape[5],
354 ];
355
356 let out_channels = kernel.shape[0];
357 let kernel_channels = kernel.shape[1];
358 let kernel_size = [
359 kernel.shape[2],
360 kernel.shape[3],
361 kernel.shape[4],
362 kernel.shape[5],
363 ];
364
365 if !in_channels.is_multiple_of(config.groups) {
366 return Err("in_channels deve ser divisível por groups".to_string());
367 }
368
369 if !out_channels.is_multiple_of(config.groups) {
370 return Err("out_channels deve ser divisível por groups".to_string());
371 }
372
373 if kernel_channels != in_channels / config.groups {
374 return Err(format!(
375 "kernel channels ({}) deve ser igual a in_channels/groups ({})",
376 kernel_channels,
377 in_channels / config.groups
378 ));
379 }
380
381 let output_size = config.output_size(input_size, kernel_size);
383
384 let mut output = Tensor6D::zeros([
386 batch,
387 out_channels,
388 output_size[0],
389 output_size[1],
390 output_size[2],
391 output_size[3],
392 ]);
393
394 let results: Vec<_> = (0..batch)
396 .into_par_iter()
397 .map(|b| {
398 conv4d_single_batch(
399 input,
400 kernel,
401 b,
402 in_channels,
403 out_channels,
404 &input_size,
405 &kernel_size,
406 &output_size,
407 config,
408 )
409 })
410 .collect();
411
412 for (b, batch_data) in results.into_iter().enumerate() {
414 for oc in 0..out_channels {
415 for o1 in 0..output_size[0] {
416 for o2 in 0..output_size[1] {
417 for o3 in 0..output_size[2] {
418 for o4 in 0..output_size[3] {
419 let idx = oc
420 * output_size[0]
421 * output_size[1]
422 * output_size[2]
423 * output_size[3]
424 + o1 * output_size[1] * output_size[2] * output_size[3]
425 + o2 * output_size[2] * output_size[3]
426 + o3 * output_size[3]
427 + o4;
428 output
429 .set(
430 [b, oc, o1, o2, o3, o4],
431 batch_data[idx] + bias.map_or(0.0, |b| b.get([oc]).unwrap()),
432 )
433 .unwrap();
434 }
435 }
436 }
437 }
438 }
439 }
440
441 Ok(output)
442}
443
/// Computes the forward convolution for a single batch element.
///
/// Returns a flat buffer laid out row-major as `[oc][o1][o2][o3][o4]`, which
/// the caller scatters into the output tensor.
#[allow(clippy::too_many_arguments)]
fn conv4d_single_batch(
    input: &Tensor6D,
    kernel: &Tensor6D,
    batch_idx: usize,
    in_channels: usize,
    out_channels: usize,
    input_size: &[usize; 4],
    kernel_size: &[usize; 4],
    output_size: &[usize; 4],
    config: &Conv4DConfig,
) -> Vec<f64> {
    let mut result =
        vec![0.0; out_channels * output_size[0] * output_size[1] * output_size[2] * output_size[3]];

    let channels_per_group = in_channels / config.groups;

    for oc in 0..out_channels {
        // Grouped convolution: output channel `oc` only reads the input
        // channels belonging to its group.
        let group = oc / (out_channels / config.groups);
        let group_start = group * channels_per_group;
        let group_end = group_start + channels_per_group;

        for o1 in 0..output_size[0] {
            for o2 in 0..output_size[1] {
                for o3 in 0..output_size[2] {
                    for o4 in 0..output_size[3] {
                        let mut sum = 0.0;

                        for ic in group_start..group_end {
                            for k1 in 0..kernel_size[0] {
                                for k2 in 0..kernel_size[1] {
                                    for k3 in 0..kernel_size[2] {
                                        for k4 in 0..kernel_size[3] {
                                            // Tap position in the padded
                                            // input grid.
                                            let i1 =
                                                o1 * config.stride[0] + k1 * config.dilation[0];
                                            let i2 =
                                                o2 * config.stride[1] + k2 * config.dilation[1];
                                            let i3 =
                                                o3 * config.stride[2] + k3 * config.dilation[2];
                                            let i4 =
                                                o4 * config.stride[3] + k4 * config.dilation[3];

                                            // Taps that land in the zero
                                            // padding contribute nothing and
                                            // are skipped entirely.
                                            if i1 >= config.padding[0]
                                                && i2 >= config.padding[1]
                                                && i3 >= config.padding[2]
                                                && i4 >= config.padding[3]
                                            {
                                                // Shift back to unpadded
                                                // input coordinates.
                                                let i1 = i1 - config.padding[0];
                                                let i2 = i2 - config.padding[1];
                                                let i3 = i3 - config.padding[2];
                                                let i4 = i4 - config.padding[3];

                                                if i1 < input_size[0]
                                                    && i2 < input_size[1]
                                                    && i3 < input_size[2]
                                                    && i4 < input_size[3]
                                                {
                                                    let input_val = input
                                                        .get([batch_idx, ic, i1, i2, i3, i4])
                                                        .unwrap();
                                                    // Kernel is indexed by the
                                                    // channel offset within
                                                    // the group.
                                                    let kernel_val = kernel
                                                        .get([oc, ic - group_start, k1, k2, k3, k4])
                                                        .unwrap();
                                                    sum += input_val * kernel_val;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }

                        // Flat row-major index into `result`.
                        let idx =
                            oc * output_size[0] * output_size[1] * output_size[2] * output_size[3]
                                + o1 * output_size[1] * output_size[2] * output_size[3]
                                + o2 * output_size[2] * output_size[3]
                                + o3 * output_size[3]
                                + o4;
                        result[idx] = sum;
                    }
                }
            }
        }
    }

    result
}
538
/// Computes the backward pass for a single batch element.
///
/// Returns `(grad_input, grad_weights)` as flat row-major buffers laid out as
/// `[ic][i1][i2][i3][i4]` and `[oc][ic_in_group][k1][k2][k3][k4]`
/// respectively; the caller accumulates them into the gradient tensors.
#[allow(clippy::too_many_arguments)]
fn conv4d_backward_single_batch(
    input: &Tensor6D,
    weights: &Tensor6D,
    grad_output: &Tensor6D,
    batch_idx: usize,
    in_channels: usize,
    out_channels: usize,
    input_size: &[usize; 4],
    kernel_size: &[usize; 4],
    output_size: &[usize; 4],
    channels_per_group: usize,
    config: &Conv4DConfig,
) -> (Vec<f64>, Vec<f64>) {
    let grad_input_size =
        in_channels * input_size[0] * input_size[1] * input_size[2] * input_size[3];
    let grad_weights_size = out_channels
        * channels_per_group
        * kernel_size[0]
        * kernel_size[1]
        * kernel_size[2]
        * kernel_size[3];

    let mut grad_input = vec![0.0; grad_input_size];
    let mut grad_weights = vec![0.0; grad_weights_size];

    for oc in 0..out_channels {
        // Grouped convolution: gradients only flow between an output channel
        // and the input channels of its own group.
        let group = oc / (out_channels / config.groups);
        let group_start = group * channels_per_group;
        let group_end = group_start + channels_per_group;

        for o1 in 0..output_size[0] {
            for o2 in 0..output_size[1] {
                for o3 in 0..output_size[2] {
                    for o4 in 0..output_size[3] {
                        let grad_out_val =
                            grad_output.get([batch_idx, oc, o1, o2, o3, o4]).unwrap();

                        for ic in group_start..group_end {
                            for k1 in 0..kernel_size[0] {
                                for k2 in 0..kernel_size[1] {
                                    for k3 in 0..kernel_size[2] {
                                        for k4 in 0..kernel_size[3] {
                                            // Same padded-grid coordinates as
                                            // the forward pass.
                                            let i1 =
                                                o1 * config.stride[0] + k1 * config.dilation[0];
                                            let i2 =
                                                o2 * config.stride[1] + k2 * config.dilation[1];
                                            let i3 =
                                                o3 * config.stride[2] + k3 * config.dilation[2];
                                            let i4 =
                                                o4 * config.stride[3] + k4 * config.dilation[3];

                                            // Taps inside the zero padding
                                            // carry no gradient.
                                            if i1 >= config.padding[0]
                                                && i2 >= config.padding[1]
                                                && i3 >= config.padding[2]
                                                && i4 >= config.padding[3]
                                            {
                                                let i1 = i1 - config.padding[0];
                                                let i2 = i2 - config.padding[1];
                                                let i3 = i3 - config.padding[2];
                                                let i4 = i4 - config.padding[3];

                                                if i1 < input_size[0]
                                                    && i2 < input_size[1]
                                                    && i3 < input_size[2]
                                                    && i4 < input_size[3]
                                                {
                                                    // dL/dInput += dL/dOut * W
                                                    let weight_val = weights
                                                        .get([oc, ic - group_start, k1, k2, k3, k4])
                                                        .unwrap();
                                                    let grad_in_idx = ic
                                                        * input_size[0]
                                                        * input_size[1]
                                                        * input_size[2]
                                                        * input_size[3]
                                                        + i1 * input_size[1]
                                                            * input_size[2]
                                                            * input_size[3]
                                                        + i2 * input_size[2] * input_size[3]
                                                        + i3 * input_size[3]
                                                        + i4;
                                                    grad_input[grad_in_idx] +=
                                                        grad_out_val * weight_val;

                                                    // dL/dW += dL/dOut * input
                                                    let input_val = input
                                                        .get([batch_idx, ic, i1, i2, i3, i4])
                                                        .unwrap();
                                                    let grad_w_idx = oc
                                                        * channels_per_group
                                                        * kernel_size[0]
                                                        * kernel_size[1]
                                                        * kernel_size[2]
                                                        * kernel_size[3]
                                                        + (ic - group_start)
                                                            * kernel_size[0]
                                                            * kernel_size[1]
                                                            * kernel_size[2]
                                                            * kernel_size[3]
                                                        + k1 * kernel_size[1]
                                                            * kernel_size[2]
                                                            * kernel_size[3]
                                                        + k2 * kernel_size[2] * kernel_size[3]
                                                        + k3 * kernel_size[3]
                                                        + k4;
                                                    grad_weights[grad_w_idx] +=
                                                        grad_out_val * input_val;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    (grad_input, grad_weights)
}
669
670pub fn max_pool4d(
677 input: &Tensor6D,
678 kernel_size: [usize; 4],
679 stride: Option<[usize; 4]>,
680) -> Result<Tensor6D, String> {
681 let stride = stride.unwrap_or(kernel_size);
682
683 let batch = input.shape[0];
684 let channels = input.shape[1];
685 let input_size = [
686 input.shape[2],
687 input.shape[3],
688 input.shape[4],
689 input.shape[5],
690 ];
691
692 let output_size = [
693 (input_size[0] - kernel_size[0]) / stride[0] + 1,
694 (input_size[1] - kernel_size[1]) / stride[1] + 1,
695 (input_size[2] - kernel_size[2]) / stride[2] + 1,
696 (input_size[3] - kernel_size[3]) / stride[3] + 1,
697 ];
698
699 let mut output = Tensor6D::zeros([
700 batch,
701 channels,
702 output_size[0],
703 output_size[1],
704 output_size[2],
705 output_size[3],
706 ]);
707
708 for b in 0..batch {
709 for c in 0..channels {
710 for o1 in 0..output_size[0] {
711 for o2 in 0..output_size[1] {
712 for o3 in 0..output_size[2] {
713 for o4 in 0..output_size[3] {
714 let mut max_val = f64::NEG_INFINITY;
715
716 for k1 in 0..kernel_size[0] {
717 for k2 in 0..kernel_size[1] {
718 for k3 in 0..kernel_size[2] {
719 for k4 in 0..kernel_size[3] {
720 let i1 = o1 * stride[0] + k1;
721 let i2 = o2 * stride[1] + k2;
722 let i3 = o3 * stride[2] + k3;
723 let i4 = o4 * stride[3] + k4;
724
725 let val = input.get([b, c, i1, i2, i3, i4]).unwrap();
726 if val > max_val {
727 max_val = val;
728 }
729 }
730 }
731 }
732 }
733
734 output.set([b, c, o1, o2, o3, o4], max_val).unwrap();
735 }
736 }
737 }
738 }
739 }
740 }
741
742 Ok(output)
743}
744
745pub fn avg_pool4d(
747 input: &Tensor6D,
748 kernel_size: [usize; 4],
749 stride: Option<[usize; 4]>,
750) -> Result<Tensor6D, String> {
751 let stride = stride.unwrap_or(kernel_size);
752
753 let batch = input.shape[0];
754 let channels = input.shape[1];
755 let input_size = [
756 input.shape[2],
757 input.shape[3],
758 input.shape[4],
759 input.shape[5],
760 ];
761
762 let output_size = [
763 (input_size[0] - kernel_size[0]) / stride[0] + 1,
764 (input_size[1] - kernel_size[1]) / stride[1] + 1,
765 (input_size[2] - kernel_size[2]) / stride[2] + 1,
766 (input_size[3] - kernel_size[3]) / stride[3] + 1,
767 ];
768
769 let mut output = Tensor6D::zeros([
770 batch,
771 channels,
772 output_size[0],
773 output_size[1],
774 output_size[2],
775 output_size[3],
776 ]);
777
778 let kernel_vol = (kernel_size[0] * kernel_size[1] * kernel_size[2] * kernel_size[3]) as f64;
779
780 for b in 0..batch {
781 for c in 0..channels {
782 for o1 in 0..output_size[0] {
783 for o2 in 0..output_size[1] {
784 for o3 in 0..output_size[2] {
785 for o4 in 0..output_size[3] {
786 let mut sum = 0.0;
787
788 for k1 in 0..kernel_size[0] {
789 for k2 in 0..kernel_size[1] {
790 for k3 in 0..kernel_size[2] {
791 for k4 in 0..kernel_size[3] {
792 let i1 = o1 * stride[0] + k1;
793 let i2 = o2 * stride[1] + k2;
794 let i3 = o3 * stride[2] + k3;
795 let i4 = o4 * stride[3] + k4;
796
797 sum += input.get([b, c, i1, i2, i3, i4]).unwrap();
798 }
799 }
800 }
801 }
802
803 output
804 .set([b, c, o1, o2, o3, o4], sum / kernel_vol)
805 .unwrap();
806 }
807 }
808 }
809 }
810 }
811 }
812
813 Ok(output)
814}
815
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conv4d_config() {
        let config = Conv4DConfig::default()
            .with_stride([2, 2, 2, 2])
            .with_padding([1, 1, 1, 1]);

        let input_size = [10, 10, 10, 10];
        let kernel_size = [3, 3, 3, 3];
        let output_size = config.output_size(input_size, kernel_size);

        // (10 + 2*1 - 1*(3-1) - 1) / 2 + 1 = 5 in every dimension.
        assert_eq!(output_size, [5, 5, 5, 5]);
    }

    #[test]
    fn test_conv4d_layer_creation() {
        let layer = Conv4DLayer::new(8, 16, [3, 3, 3, 3], Conv4DConfig::default()).with_bias(16);

        // 8 in-channels, groups = 1 -> kernel channel dim is 8.
        assert_eq!(layer.weights.shape, [16, 8, 3, 3, 3, 3]);
        assert!(layer.bias.is_some());
        assert_eq!(layer.bias.as_ref().unwrap().shape, [16]);
    }

    #[test]
    fn test_conv4d_simple() {
        let input = Tensor6D::zeros([1, 2, 4, 4, 4, 4]);
        let kernel = Tensor6D::zeros([3, 2, 2, 2, 2, 2]);

        let config = Conv4DConfig::default();
        let result = conv4d(&input, &kernel, None, &config);

        assert!(result.is_ok());
        let output = result.unwrap();

        // 4 - 2 + 1 = 3 per spatial dimension.
        assert_eq!(output.shape, [1, 3, 3, 3, 3, 3]);
    }

    #[test]
    fn test_max_pool4d() {
        let mut input = Tensor6D::zeros([1, 1, 4, 4, 4, 4]);
        input.set([0, 0, 1, 1, 1, 1], 10.0).unwrap();

        let result = max_pool4d(&input, [2, 2, 2, 2], None);
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output.shape, [1, 1, 2, 2, 2, 2]);

        // The spike at (1,1,1,1) falls in the first pooling window.
        let pooled_val = output.get([0, 0, 0, 0, 0, 0]).unwrap();
        assert_eq!(pooled_val, 10.0);
    }

    #[test]
    fn test_avg_pool4d() {
        let mut input = Tensor6D::zeros([1, 1, 4, 4, 4, 4]);

        // Fill exactly the first 2x2x2x2 window with 2.0 so its mean is 2.0.
        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    for l in 0..2 {
                        input.set([0, 0, i, j, k, l], 2.0).unwrap();
                    }
                }
            }
        }

        let result = avg_pool4d(&input, [2, 2, 2, 2], None);
        assert!(result.is_ok());

        let output = result.unwrap();
        let avg_val = output.get([0, 0, 0, 0, 0, 0]).unwrap();

        assert!((avg_val - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_grouped_convolution() {
        let input = Tensor6D::zeros([1, 4, 4, 4, 4, 4]);
        // 4 in-channels over 2 groups -> kernel sees 2 channels per group.
        // (Previously this and the config binding were crammed on one line.)
        let kernel = Tensor6D::zeros([4, 2, 2, 2, 2, 2]);
        let config = Conv4DConfig::default().with_groups(2);
        let result = conv4d(&input, &kernel, None, &config);

        assert!(result.is_ok());
    }
}
916
/// End-to-end forward + backward sanity check on a small layer.
#[test]
fn test_conv4d_backward_pass() {
    // Deterministic input ramp so the gradients are non-trivial.
    let mut input = Tensor6D::zeros([1, 2, 4, 4, 4, 4]);
    for (i, v) in input.data.iter_mut().enumerate() {
        *v = i as f64 * 0.01;
    }

    let mut layer = Conv4DLayer::new(2, 3, [2, 2, 2, 2], Conv4DConfig::default());
    layer.init_xavier();

    let output = layer.forward(&input).unwrap();
    assert_eq!(output.shape, [1, 3, 3, 3, 3, 3]);

    // Uniform upstream gradient.
    let mut grad_output = Tensor6D::zeros(output.shape);
    grad_output.data.iter_mut().for_each(|g| *g = 0.1);

    let result = layer.backward(&input, &grad_output);
    assert!(result.is_ok());

    let (grad_input, grad_weights, _) = result.unwrap();
    assert_eq!(grad_input.shape, input.shape);
    assert_eq!(grad_weights.shape, layer.weights.shape);

    // Gradient must actually flow to both the input and the weights.
    let grad_input_sum: f64 = grad_input.data.iter().sum();
    let grad_weights_sum: f64 = grad_weights.data.iter().sum();
    assert!(grad_input_sum.abs() > 1e-10);
    assert!(grad_weights_sum.abs() > 1e-10);
}
947
/// Checks that `backward` reports a bias gradient when a bias is attached.
///
/// NOTE(review): this test and the one above sit outside the
/// `#[cfg(test)] mod tests` block — consider moving them in for consistency.
#[test]
fn test_conv4d_backward_with_bias() {
    let input = Tensor6D::zeros([1, 1, 3, 3, 3, 3]);
    let mut layer = Conv4DLayer::new(1, 2, [2, 2, 2, 2], Conv4DConfig::default()).with_bias(2);
    layer.init_he();

    let output = layer.forward(&input).unwrap();
    // Uniform upstream gradient of 0.5 at every output position.
    let grad_output = Tensor6D::filled(output.shape, 0.5);

    let result = layer.backward(&input, &grad_output);
    assert!(result.is_ok());

    // One bias-gradient entry per output channel.
    let (_, _, grad_bias) = result.unwrap();
    assert!(grad_bias.is_some());
    assert_eq!(grad_bias.unwrap().shape, [2]);
}