Skip to main content

ferrotorch_nn/
pooling.rs

1//! Pooling layers: MaxPool1d/2d/3d, AvgPool1d/2d/3d, AdaptiveAvgPool1d/2d/3d,
2//! AdaptiveMaxPool1d/2d/3d, FractionalMaxPool2d, LPPool1d/2d, MaxUnpool2d.
3//!
4//! All are zero-parameter modules operating on `[B, C, *spatial]` tensors.
5//! Each forward pass attaches a `GradFn<T>` for reverse-mode autodiff
6//! when gradient tracking is enabled.
7//!
8//! ## REQ status (per `.design/ferrotorch-nn/pooling.md`)
9//!
10//! | REQ | Status | Evidence |
11//! |---|---|---|
12//! | REQ-1 | SHIPPED | the `MaxPool2d` struct + `impl<T: Float> Module<T> for MaxPool2d` here; non-test consumer: re-export at `ferrotorch-nn/src/lib.rs:237` + `ferrotorch-vision/src/models/resnet.rs:23` + many other vision models |
13//! | REQ-2 | SHIPPED | the `AvgPool2d` struct + `impl<T: Float> Module<T> for AvgPool2d` here; non-test consumer: `ferrotorch-vision/src/models/densenet.rs:43` + `inception.rs:61` + re-export at `lib.rs:237` |
14//! | REQ-3 | SHIPPED | the `AdaptiveAvgPool2d` struct + `impl<T: Float> Module<T> for AdaptiveAvgPool2d` here; non-test consumer: `ferrotorch-vision/src/models/resnet.rs:23` + `convnext.rs:35` + `efficientnet.rs:38` + `mobilenet.rs:55` + `segmentation/aspp.rs:38` + re-export at `lib.rs:237` + the prelude re-export at `lib.rs:286` + `ferrotorch-nn/src/se.rs` (SqueezeExcitation squeeze stage) |
15//! | REQ-4 | SHIPPED | the `MaxPool1d` / `MaxPool3d` / `AvgPool1d` / `AvgPool3d` structs + their `impl<T: Float> Module<T>` blocks here; non-test consumer: re-export at `lib.rs:237` |
16//! | REQ-5 | SHIPPED | the `AdaptiveAvgPool1d` / `AdaptiveAvgPool3d` / `AdaptiveMaxPool1d` / `AdaptiveMaxPool2d` / `AdaptiveMaxPool3d` structs + their Module impls here; non-test consumer: re-export at `lib.rs:237` |
17//! | REQ-6 | SHIPPED | the `FractionalMaxPool2d` struct + `impl<T: Float> Module<T>` here; non-test consumer: re-export at `lib.rs:237` |
18//! | REQ-7 | SHIPPED | the `LPPool1d` / `LPPool2d` structs + their `impl Module<T>` blocks here; non-test consumer: re-export at `lib.rs:237` |
19//! | REQ-8 | SHIPPED | the `MaxUnpool2d` struct + `max_unpool2d` functional entry here; non-test consumer: re-export at `lib.rs:237` |
20//! | REQ-9 | SHIPPED | the 14 free `*_pool*<T: Float>` functional entries here; non-test consumer: re-export at `lib.rs:237` |
21//! | REQ-10 | SHIPPED | the `validate_4d`, `validate_pool_params` helpers here; non-test consumer: invoked from every pool forward (re-exported at `lib.rs:237`) |
22//! | REQ-11 | SHIPPED | per-pool `GradFn<T>` types + `Tensor::from_operation` calls here; non-test consumer: re-export at `lib.rs:237` |
23//! | REQ-12 | NOT-STARTED | parity-sweep runner arms for the 10 declared pooling ops not wired — blocker #1458 |
24
25use std::sync::Arc;
26
27use ferrotorch_core::autograd::no_grad::is_grad_enabled;
28use ferrotorch_core::tensor::GradFn;
29use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
30
31use crate::module::Module;
32use crate::parameter::Parameter;
33
34// ===========================================================================
35// Helpers
36// ===========================================================================
37
/// Number of window positions along one spatial axis of a pooling sweep.
///
/// Evaluates `(input + 2 * padding - kernel_size) / stride + 1` using integer
/// (floor) division. Neither the subtraction nor the division is guarded
/// here: the padded extent must be at least `kernel_size` and `stride` must
/// be non-zero.
#[inline]
fn pool_output_size(input: usize, kernel_size: usize, stride: usize, padding: usize) -> usize {
    let padded_extent = input + 2 * padding;
    let slack = padded_extent - kernel_size;
    slack / stride + 1
}
45
46/// Validate that the input tensor has shape `[B, C, H, W]`.
47fn validate_4d<T: Float>(input: &Tensor<T>) -> FerrotorchResult<(usize, usize, usize, usize)> {
48    let shape = input.shape();
49    if shape.len() != 4 {
50        return Err(FerrotorchError::InvalidArgument {
51            message: format!(
52                "pooling expects 4D input [B, C, H, W], got shape {:?}",
53                shape
54            ),
55        });
56    }
57    Ok((shape[0], shape[1], shape[2], shape[3]))
58}
59
60/// Validate pooling parameters and compute output spatial dimensions.
61fn validate_pool_params(
62    h: usize,
63    w: usize,
64    kernel_size: [usize; 2],
65    stride: [usize; 2],
66    padding: [usize; 2],
67) -> FerrotorchResult<(usize, usize)> {
68    if kernel_size[0] == 0 || kernel_size[1] == 0 {
69        return Err(FerrotorchError::InvalidArgument {
70            message: "kernel_size must be > 0".into(),
71        });
72    }
73    if stride[0] == 0 || stride[1] == 0 {
74        return Err(FerrotorchError::InvalidArgument {
75            message: "stride must be > 0".into(),
76        });
77    }
78    let padded_h = h + 2 * padding[0];
79    let padded_w = w + 2 * padding[1];
80    if padded_h < kernel_size[0] || padded_w < kernel_size[1] {
81        return Err(FerrotorchError::InvalidArgument {
82            message: format!(
83                "padded input ({padded_h}, {padded_w}) smaller than kernel ({}, {})",
84                kernel_size[0], kernel_size[1]
85            ),
86        });
87    }
88    let out_h = pool_output_size(h, kernel_size[0], stride[0], padding[0]);
89    let out_w = pool_output_size(w, kernel_size[1], stride[1], padding[1]);
90    Ok((out_h, out_w))
91}
92
93// ===========================================================================
94// MaxPool2d
95// ===========================================================================
96
/// Max pooling over the two trailing spatial axes of a 4-D tensor.
///
/// Each `[H, W]` plane is scanned with a `kernel_size` window advanced by
/// `stride`; every output element is the largest value found in its window.
/// The layer holds no parameters.
///
/// Input shape: `[B, C, H, W]`
/// Output shape: `[B, C, H_out, W_out]`
#[derive(Debug, Clone)]
pub struct MaxPool2d {
    /// Window extent as `[height, width]`.
    pub kernel_size: [usize; 2],
    /// Step between consecutive windows as `[height, width]`.
    pub stride: [usize; 2],
    /// Implicit border added on both sides of each axis, `[height, width]`.
    pub padding: [usize; 2],
}

impl MaxPool2d {
    /// Build a `MaxPool2d` layer.
    ///
    /// Passing `[0, 0]` for `stride` selects `kernel_size` as the stride
    /// (PyTorch convention); any other value is stored unchanged.
    pub fn new(kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2]) -> Self {
        Self {
            kernel_size,
            stride: match stride {
                [0, 0] => kernel_size,
                explicit => explicit,
            },
            padding,
        }
    }
}
128
129impl<T: Float> Module<T> for MaxPool2d {
130    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
131        max_pool2d_forward(input, self.kernel_size, self.stride, self.padding)
132    }
133
134    fn parameters(&self) -> Vec<&Parameter<T>> {
135        vec![]
136    }
137
138    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
139        vec![]
140    }
141
142    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
143        vec![]
144    }
145
146    fn train(&mut self) {}
147    fn eval(&mut self) {}
148
149    fn is_training(&self) -> bool {
150        false
151    }
152}
153
154/// Forward computation for max pooling, returns the output tensor with
155/// gradient tracking when enabled.
156fn max_pool2d_forward<T: Float>(
157    input: &Tensor<T>,
158    kernel_size: [usize; 2],
159    stride: [usize; 2],
160    padding: [usize; 2],
161) -> FerrotorchResult<Tensor<T>> {
162    let (batch, channels, h, w) = validate_4d(input)?;
163    let (out_h, out_w) = validate_pool_params(h, w, kernel_size, stride, padding)?;
164
165    // Save device for restoring on output.
166    let input_device = input.device();
167
168    let data = input.data_vec()?;
169    let total = batch * channels * out_h * out_w;
170    let mut output = vec![<T as num_traits::Zero>::zero(); total];
171    // Store the flat index of the max element within the input for each output element.
172    let mut indices = vec![0usize; total];
173
174    let neg_inf = T::from(-1e38).unwrap();
175
176    for b in 0..batch {
177        for c in 0..channels {
178            for oh in 0..out_h {
179                for ow in 0..out_w {
180                    let out_idx = ((b * channels + c) * out_h + oh) * out_w + ow;
181                    let mut max_val = neg_inf;
182                    let mut max_idx = 0usize;
183
184                    for kh in 0..kernel_size[0] {
185                        for kw in 0..kernel_size[1] {
186                            let ih = oh * stride[0] + kh;
187                            let iw = ow * stride[1] + kw;
188
189                            // Account for padding.
190                            let ih = ih as isize - padding[0] as isize;
191                            let iw = iw as isize - padding[1] as isize;
192
193                            if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
194                                let ih = ih as usize;
195                                let iw = iw as usize;
196                                let in_idx = ((b * channels + c) * h + ih) * w + iw;
197                                let val = data[in_idx];
198                                if val > max_val {
199                                    max_val = val;
200                                    max_idx = in_idx;
201                                }
202                            }
203                            // Padded positions have -inf, so they never win.
204                        }
205                    }
206
207                    output[out_idx] = max_val;
208                    indices[out_idx] = max_idx;
209                }
210            }
211        }
212    }
213
214    let out_shape = vec![batch, channels, out_h, out_w];
215    let storage = TensorStorage::cpu(output);
216
217    if is_grad_enabled() && input.requires_grad() {
218        Tensor::from_operation(
219            storage,
220            out_shape,
221            Arc::new(MaxPool2dBackward {
222                input: input.clone(),
223                indices,
224            }),
225        )?
226        .to(input_device) // restore device
227    } else {
228        Tensor::from_storage(storage, out_shape, false)?.to(input_device) // restore device
229    }
230}
231
232/// Backward for `MaxPool2d`.
233///
234/// Routes the upstream gradient to the position of the max element in each
235/// pooling window. All other positions receive zero gradient.
236#[derive(Debug)]
237struct MaxPool2dBackward<T: Float> {
238    input: Tensor<T>,
239    /// For each output element, the flat index into the input where the max lives.
240    indices: Vec<usize>,
241}
242
243impl<T: Float> GradFn<T> for MaxPool2dBackward<T> {
244    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
245        if !self.input.requires_grad() {
246            return Ok(vec![None]);
247        }
248
249        let go_data = grad_output.data_vec()?;
250        let input_numel = self.input.numel();
251        let mut grad_input = vec![<T as num_traits::Zero>::zero(); input_numel];
252
253        for (out_idx, &in_idx) in self.indices.iter().enumerate() {
254            grad_input[in_idx] += go_data[out_idx];
255        }
256
257        let grad_tensor = Tensor::from_storage(
258            TensorStorage::cpu(grad_input),
259            self.input.shape().to_vec(),
260            false,
261        )?;
262        Ok(vec![Some(grad_tensor)])
263    }
264
265    fn inputs(&self) -> Vec<&Tensor<T>> {
266        vec![&self.input]
267    }
268
269    fn name(&self) -> &'static str {
270        "MaxPool2dBackward"
271    }
272}
273
274// ===========================================================================
275// AvgPool2d
276// ===========================================================================
277
/// Average pooling over the two trailing spatial axes of a 4-D tensor.
///
/// Each `[H, W]` plane is scanned with a `kernel_size` window advanced by
/// `stride`; every output element is the arithmetic mean of its window.
/// The layer holds no parameters.
///
/// Input shape: `[B, C, H, W]`
/// Output shape: `[B, C, H_out, W_out]`
#[derive(Debug, Clone)]
pub struct AvgPool2d {
    /// Window extent as `[height, width]`.
    pub kernel_size: [usize; 2],
    /// Step between consecutive windows as `[height, width]`.
    pub stride: [usize; 2],
    /// Implicit zero border added on both sides of each axis, `[height, width]`.
    pub padding: [usize; 2],
}

impl AvgPool2d {
    /// Build an `AvgPool2d` layer.
    ///
    /// Passing `[0, 0]` for `stride` selects `kernel_size` as the stride;
    /// any other value is stored unchanged.
    pub fn new(kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2]) -> Self {
        Self {
            kernel_size,
            stride: match stride {
                [0, 0] => kernel_size,
                explicit => explicit,
            },
            padding,
        }
    }
}
309
310impl<T: Float> Module<T> for AvgPool2d {
311    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
312        avg_pool2d_forward(input, self.kernel_size, self.stride, self.padding)
313    }
314
315    fn parameters(&self) -> Vec<&Parameter<T>> {
316        vec![]
317    }
318
319    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
320        vec![]
321    }
322
323    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
324        vec![]
325    }
326
327    fn train(&mut self) {}
328    fn eval(&mut self) {}
329
330    fn is_training(&self) -> bool {
331        false
332    }
333}
334
335/// Forward computation for average pooling.
336fn avg_pool2d_forward<T: Float>(
337    input: &Tensor<T>,
338    kernel_size: [usize; 2],
339    stride: [usize; 2],
340    padding: [usize; 2],
341) -> FerrotorchResult<Tensor<T>> {
342    let (batch, channels, h, w) = validate_4d(input)?;
343    let (out_h, out_w) = validate_pool_params(h, w, kernel_size, stride, padding)?;
344
345    // Save device for restoring on output.
346    let input_device = input.device();
347
348    let data = input.data_vec()?;
349    let total = batch * channels * out_h * out_w;
350    let mut output = vec![<T as num_traits::Zero>::zero(); total];
351
352    let kernel_area = T::from(kernel_size[0] * kernel_size[1]).unwrap();
353
354    for b in 0..batch {
355        for c in 0..channels {
356            for oh in 0..out_h {
357                for ow in 0..out_w {
358                    let out_idx = ((b * channels + c) * out_h + oh) * out_w + ow;
359                    let mut sum = <T as num_traits::Zero>::zero();
360
361                    for kh in 0..kernel_size[0] {
362                        for kw in 0..kernel_size[1] {
363                            let ih = oh * stride[0] + kh;
364                            let iw = ow * stride[1] + kw;
365                            let ih = ih as isize - padding[0] as isize;
366                            let iw = iw as isize - padding[1] as isize;
367
368                            if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
369                                let ih = ih as usize;
370                                let iw = iw as usize;
371                                let in_idx = ((b * channels + c) * h + ih) * w + iw;
372                                sum += data[in_idx];
373                            }
374                            // Padded positions contribute 0, but we still divide
375                            // by the full kernel area (count_include_pad = true).
376                        }
377                    }
378
379                    output[out_idx] = sum / kernel_area;
380                }
381            }
382        }
383    }
384
385    let out_shape = vec![batch, channels, out_h, out_w];
386    let storage = TensorStorage::cpu(output);
387
388    if is_grad_enabled() && input.requires_grad() {
389        Tensor::from_operation(
390            storage,
391            out_shape,
392            Arc::new(AvgPool2dBackward {
393                input: input.clone(),
394                kernel_size,
395                stride,
396                padding,
397            }),
398        )?
399        .to(input_device) // restore device
400    } else {
401        Tensor::from_storage(storage, out_shape, false)?.to(input_device) // restore device
402    }
403}
404
405/// Backward for `AvgPool2d`.
406///
407/// Distributes the upstream gradient evenly to all input positions that
408/// contributed to each output window, dividing by the kernel area.
409#[derive(Debug)]
410struct AvgPool2dBackward<T: Float> {
411    input: Tensor<T>,
412    kernel_size: [usize; 2],
413    stride: [usize; 2],
414    padding: [usize; 2],
415}
416
417impl<T: Float> GradFn<T> for AvgPool2dBackward<T> {
418    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
419        if !self.input.requires_grad() {
420            return Ok(vec![None]);
421        }
422
423        let go_data = grad_output.data_vec()?;
424        let in_shape = self.input.shape();
425        let (batch, channels, h, w) = (in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
426        let out_h = pool_output_size(h, self.kernel_size[0], self.stride[0], self.padding[0]);
427        let out_w = pool_output_size(w, self.kernel_size[1], self.stride[1], self.padding[1]);
428
429        let mut grad_input = vec![<T as num_traits::Zero>::zero(); batch * channels * h * w];
430        let kernel_area = T::from(self.kernel_size[0] * self.kernel_size[1]).unwrap();
431
432        for b in 0..batch {
433            for c in 0..channels {
434                for oh in 0..out_h {
435                    for ow in 0..out_w {
436                        let out_idx = ((b * channels + c) * out_h + oh) * out_w + ow;
437                        let grad_val = go_data[out_idx] / kernel_area;
438
439                        for kh in 0..self.kernel_size[0] {
440                            for kw in 0..self.kernel_size[1] {
441                                let ih =
442                                    (oh * self.stride[0] + kh) as isize - self.padding[0] as isize;
443                                let iw =
444                                    (ow * self.stride[1] + kw) as isize - self.padding[1] as isize;
445
446                                if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
447                                    let ih = ih as usize;
448                                    let iw = iw as usize;
449                                    let in_idx = ((b * channels + c) * h + ih) * w + iw;
450                                    grad_input[in_idx] += grad_val;
451                                }
452                            }
453                        }
454                    }
455                }
456            }
457        }
458
459        // The forward emits its result on the input's device (line ~382
460        // and ~384 of this file: `to(input_device)` after the host
461        // computation). The backward must match — otherwise a CUDA
462        // forward followed by a CPU gradient breaks every downstream
463        // op that branches on device-residency (e.g. relu_backward
464        // hits its `NotImplementedOnCuda` arm because grad_output
465        // disagrees with the saved CUDA `input`). Surfaced by
466        // `gpu_cnn_training_smoke` in
467        // `ferrotorch/tests/gpu_training.rs` (#749 Section B).
468        let grad_tensor = Tensor::from_storage(
469            TensorStorage::cpu(grad_input),
470            self.input.shape().to_vec(),
471            false,
472        )?
473        .to(self.input.device())?;
474        Ok(vec![Some(grad_tensor)])
475    }
476
477    fn inputs(&self) -> Vec<&Tensor<T>> {
478        vec![&self.input]
479    }
480
481    fn name(&self) -> &'static str {
482        "AvgPool2dBackward"
483    }
484}
485
486// ===========================================================================
487// AdaptiveAvgPool2d
488// ===========================================================================
489
/// Adaptive average pooling over the two trailing spatial axes.
///
/// Window boundaries are derived per output cell from the input size, so
/// the result always has the requested `output_size` whatever the input's
/// spatial dimensions are. The layer holds no parameters.
///
/// Input shape: `[B, C, H, W]`
/// Output shape: `[B, C, output_size.0, output_size.1]`
#[derive(Debug, Clone)]
pub struct AdaptiveAvgPool2d {
    /// Requested output plane size as `(height, width)`.
    pub output_size: (usize, usize),
}

impl AdaptiveAvgPool2d {
    /// Build an `AdaptiveAvgPool2d` that produces `output_size` planes.
    ///
    /// The value is stored as given; it is checked when the layer runs.
    pub fn new(output_size: (usize, usize)) -> Self {
        let (out_h, out_w) = output_size;
        Self {
            output_size: (out_h, out_w),
        }
    }
}
508
509impl<T: Float> Module<T> for AdaptiveAvgPool2d {
510    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
511        adaptive_avg_pool2d_forward(input, self.output_size)
512    }
513
514    fn parameters(&self) -> Vec<&Parameter<T>> {
515        vec![]
516    }
517
518    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
519        vec![]
520    }
521
522    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
523        vec![]
524    }
525
526    fn train(&mut self) {}
527    fn eval(&mut self) {}
528
529    fn is_training(&self) -> bool {
530        false
531    }
532}
533
/// First input index (inclusive) covered by adaptive output cell `idx`.
///
/// Same formula as PyTorch: `start_i = floor(i * input_size / output_size)`.
/// `output_size` must be non-zero.
#[inline]
fn adaptive_start(idx: usize, input_size: usize, output_size: usize) -> usize {
    let scaled = idx * input_size;
    scaled / output_size
}
541
/// One-past-the-last input index covered by adaptive output cell `idx`.
///
/// `end_i = ceil((i + 1) * input_size / output_size)`; `output_size` must be
/// non-zero.
#[inline]
fn adaptive_end(idx: usize, input_size: usize, output_size: usize) -> usize {
    let next_cell = idx + 1;
    let scaled = next_cell * input_size;
    scaled.div_ceil(output_size)
}
549
550/// Forward computation for adaptive average pooling.
551fn adaptive_avg_pool2d_forward<T: Float>(
552    input: &Tensor<T>,
553    output_size: (usize, usize),
554) -> FerrotorchResult<Tensor<T>> {
555    let (batch, channels, h, w) = validate_4d(input)?;
556    let (out_h, out_w) = output_size;
557
558    if out_h == 0 || out_w == 0 {
559        return Err(FerrotorchError::InvalidArgument {
560            message: "adaptive output_size must be > 0".into(),
561        });
562    }
563
564    // Save device for restoring on output.
565    let input_device = input.device();
566
567    let data = input.data_vec()?;
568    let total = batch * channels * out_h * out_w;
569    let mut output = vec![<T as num_traits::Zero>::zero(); total];
570
571    for b in 0..batch {
572        for c in 0..channels {
573            for oh in 0..out_h {
574                let h_start = adaptive_start(oh, h, out_h);
575                let h_end = adaptive_end(oh, h, out_h);
576
577                for ow in 0..out_w {
578                    let w_start = adaptive_start(ow, w, out_w);
579                    let w_end = adaptive_end(ow, w, out_w);
580
581                    let window_area = (h_end - h_start) * (w_end - w_start);
582                    let mut sum = <T as num_traits::Zero>::zero();
583
584                    for ih in h_start..h_end {
585                        for iw in w_start..w_end {
586                            let in_idx = ((b * channels + c) * h + ih) * w + iw;
587                            sum += data[in_idx];
588                        }
589                    }
590
591                    let out_idx = ((b * channels + c) * out_h + oh) * out_w + ow;
592                    output[out_idx] = sum / T::from(window_area).unwrap();
593                }
594            }
595        }
596    }
597
598    let out_shape = vec![batch, channels, out_h, out_w];
599    let storage = TensorStorage::cpu(output);
600
601    if is_grad_enabled() && input.requires_grad() {
602        Tensor::from_operation(
603            storage,
604            out_shape,
605            Arc::new(AdaptiveAvgPool2dBackward {
606                input: input.clone(),
607                output_size,
608            }),
609        )?
610        .to(input_device) // restore device
611    } else {
612        Tensor::from_storage(storage, out_shape, false)?.to(input_device) // restore device
613    }
614}
615
616/// Backward for `AdaptiveAvgPool2d`.
617///
618/// For each output element, distributes the upstream gradient evenly across
619/// the input positions in its adaptive window.
620#[derive(Debug)]
621struct AdaptiveAvgPool2dBackward<T: Float> {
622    input: Tensor<T>,
623    output_size: (usize, usize),
624}
625
626impl<T: Float> GradFn<T> for AdaptiveAvgPool2dBackward<T> {
627    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
628        if !self.input.requires_grad() {
629            return Ok(vec![None]);
630        }
631
632        let go_data = grad_output.data_vec()?;
633        let in_shape = self.input.shape();
634        let (batch, channels, h, w) = (in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
635        let (out_h, out_w) = self.output_size;
636
637        let mut grad_input = vec![<T as num_traits::Zero>::zero(); batch * channels * h * w];
638
639        for b in 0..batch {
640            for c in 0..channels {
641                for oh in 0..out_h {
642                    let h_start = adaptive_start(oh, h, out_h);
643                    let h_end = adaptive_end(oh, h, out_h);
644
645                    for ow in 0..out_w {
646                        let w_start = adaptive_start(ow, w, out_w);
647                        let w_end = adaptive_end(ow, w, out_w);
648
649                        let window_area = (h_end - h_start) * (w_end - w_start);
650                        let out_idx = ((b * channels + c) * out_h + oh) * out_w + ow;
651                        let grad_val = go_data[out_idx] / T::from(window_area).unwrap();
652
653                        for ih in h_start..h_end {
654                            for iw in w_start..w_end {
655                                let in_idx = ((b * channels + c) * h + ih) * w + iw;
656                                grad_input[in_idx] += grad_val;
657                            }
658                        }
659                    }
660                }
661            }
662        }
663
664        let grad_tensor = Tensor::from_storage(
665            TensorStorage::cpu(grad_input),
666            self.input.shape().to_vec(),
667            false,
668        )?;
669        Ok(vec![Some(grad_tensor)])
670    }
671
672    fn inputs(&self) -> Vec<&Tensor<T>> {
673        vec![&self.input]
674    }
675
676    fn name(&self) -> &'static str {
677        "AdaptiveAvgPool2dBackward"
678    }
679}
680
681// ===========================================================================
682// 1-D helpers — CL-315
683// ===========================================================================
684
685/// Validate that the input tensor has shape `[B, C, L]`.
686fn validate_3d<T: Float>(input: &Tensor<T>) -> FerrotorchResult<(usize, usize, usize)> {
687    let shape = input.shape();
688    if shape.len() != 3 {
689        return Err(FerrotorchError::InvalidArgument {
690            message: format!(
691                "1D pooling expects 3D input [B, C, L], got shape {:?}",
692                shape
693            ),
694        });
695    }
696    Ok((shape[0], shape[1], shape[2]))
697}
698
699/// Validate 1D pooling parameters and compute output length.
700fn validate_pool_params_1d(
701    l: usize,
702    kernel_size: usize,
703    stride: usize,
704    padding: usize,
705) -> FerrotorchResult<usize> {
706    if kernel_size == 0 {
707        return Err(FerrotorchError::InvalidArgument {
708            message: "kernel_size must be > 0".into(),
709        });
710    }
711    if stride == 0 {
712        return Err(FerrotorchError::InvalidArgument {
713            message: "stride must be > 0".into(),
714        });
715    }
716    let padded = l + 2 * padding;
717    if padded < kernel_size {
718        return Err(FerrotorchError::InvalidArgument {
719            message: format!("padded input ({padded}) smaller than kernel ({kernel_size})"),
720        });
721    }
722    Ok(pool_output_size(l, kernel_size, stride, padding))
723}
724
725// ===========================================================================
726// 3-D helpers — CL-315
727// ===========================================================================
728
729/// Validate that the input tensor has shape `[B, C, D, H, W]`.
730fn validate_5d<T: Float>(
731    input: &Tensor<T>,
732) -> FerrotorchResult<(usize, usize, usize, usize, usize)> {
733    let shape = input.shape();
734    if shape.len() != 5 {
735        return Err(FerrotorchError::InvalidArgument {
736            message: format!(
737                "3D pooling expects 5D input [B, C, D, H, W], got shape {:?}",
738                shape
739            ),
740        });
741    }
742    Ok((shape[0], shape[1], shape[2], shape[3], shape[4]))
743}
744
745/// Validate 3D pooling parameters and compute output spatial dimensions.
746fn validate_pool_params_3d(
747    d: usize,
748    h: usize,
749    w: usize,
750    kernel_size: [usize; 3],
751    stride: [usize; 3],
752    padding: [usize; 3],
753) -> FerrotorchResult<(usize, usize, usize)> {
754    for i in 0..3 {
755        if kernel_size[i] == 0 {
756            return Err(FerrotorchError::InvalidArgument {
757                message: "kernel_size must be > 0".into(),
758            });
759        }
760        if stride[i] == 0 {
761            return Err(FerrotorchError::InvalidArgument {
762                message: "stride must be > 0".into(),
763            });
764        }
765    }
766    let sizes = [d, h, w];
767    for i in 0..3 {
768        let padded = sizes[i] + 2 * padding[i];
769        if padded < kernel_size[i] {
770            return Err(FerrotorchError::InvalidArgument {
771                message: format!(
772                    "padded input dim {i} ({padded}) smaller than kernel ({})",
773                    kernel_size[i]
774                ),
775            });
776        }
777    }
778    let out_d = pool_output_size(d, kernel_size[0], stride[0], padding[0]);
779    let out_h = pool_output_size(h, kernel_size[1], stride[1], padding[1]);
780    let out_w = pool_output_size(w, kernel_size[2], stride[2], padding[2]);
781    Ok((out_d, out_h, out_w))
782}
783
784// ===========================================================================
785// MaxPool1d — CL-315
786// ===========================================================================
787
/// Max pooling over the trailing length axis of a 3-D tensor.
///
/// Each `[L]` row is scanned with a `kernel_size` window advanced by
/// `stride`; every output element is the largest value found in its window.
/// The layer holds no parameters.
///
/// Input shape: `[B, C, L]`
/// Output shape: `[B, C, L_out]`
#[derive(Debug, Clone)]
pub struct MaxPool1d {
    /// Window length.
    pub kernel_size: usize,
    /// Step between consecutive windows.
    pub stride: usize,
    /// Implicit border added on both ends of the length axis.
    pub padding: usize,
}

impl MaxPool1d {
    /// Build a `MaxPool1d` layer.
    ///
    /// Passing `0` for `stride` selects `kernel_size` as the stride
    /// (PyTorch convention); any other value is stored unchanged.
    pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {
        Self {
            kernel_size,
            stride: match stride {
                0 => kernel_size,
                explicit => explicit,
            },
            padding,
        }
    }
}
815
816impl<T: Float> Module<T> for MaxPool1d {
817    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
818        max_pool1d_forward(input, self.kernel_size, self.stride, self.padding)
819    }
820
821    fn parameters(&self) -> Vec<&Parameter<T>> {
822        vec![]
823    }
824
825    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
826        vec![]
827    }
828
829    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
830        vec![]
831    }
832
833    fn train(&mut self) {}
834    fn eval(&mut self) {}
835
836    fn is_training(&self) -> bool {
837        false
838    }
839}
840
841/// Forward computation for 1D max pooling.
842fn max_pool1d_forward<T: Float>(
843    input: &Tensor<T>,
844    kernel_size: usize,
845    stride: usize,
846    padding: usize,
847) -> FerrotorchResult<Tensor<T>> {
848    let (batch, channels, l) = validate_3d(input)?;
849    let out_l = validate_pool_params_1d(l, kernel_size, stride, padding)?;
850
851    let input_device = input.device();
852    let data = input.data_vec()?;
853    let total = batch * channels * out_l;
854    let mut output = vec![<T as num_traits::Zero>::zero(); total];
855    let mut indices = vec![0usize; total];
856    let neg_inf = T::from(-1e38).unwrap();
857
858    for b in 0..batch {
859        for c in 0..channels {
860            for ol in 0..out_l {
861                let out_idx = (b * channels + c) * out_l + ol;
862                let mut max_val = neg_inf;
863                let mut max_idx = 0usize;
864
865                for k in 0..kernel_size {
866                    let il = ol * stride + k;
867                    let il = il as isize - padding as isize;
868                    if il >= 0 && il < l as isize {
869                        let il = il as usize;
870                        let in_idx = (b * channels + c) * l + il;
871                        let val = data[in_idx];
872                        if val > max_val {
873                            max_val = val;
874                            max_idx = in_idx;
875                        }
876                    }
877                }
878
879                output[out_idx] = max_val;
880                indices[out_idx] = max_idx;
881            }
882        }
883    }
884
885    let out_shape = vec![batch, channels, out_l];
886    let storage = TensorStorage::cpu(output);
887
888    if is_grad_enabled() && input.requires_grad() {
889        Tensor::from_operation(
890            storage,
891            out_shape,
892            Arc::new(MaxPool1dBackward {
893                input: input.clone(),
894                indices,
895            }),
896        )?
897        .to(input_device)
898    } else {
899        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
900    }
901}
902
903/// Backward for `MaxPool1d`.
904#[derive(Debug)]
905struct MaxPool1dBackward<T: Float> {
906    input: Tensor<T>,
907    indices: Vec<usize>,
908}
909
910impl<T: Float> GradFn<T> for MaxPool1dBackward<T> {
911    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
912        if !self.input.requires_grad() {
913            return Ok(vec![None]);
914        }
915
916        let go_data = grad_output.data_vec()?;
917        let input_numel = self.input.numel();
918        let mut grad_input = vec![<T as num_traits::Zero>::zero(); input_numel];
919
920        for (out_idx, &in_idx) in self.indices.iter().enumerate() {
921            grad_input[in_idx] += go_data[out_idx];
922        }
923
924        let grad_tensor = Tensor::from_storage(
925            TensorStorage::cpu(grad_input),
926            self.input.shape().to_vec(),
927            false,
928        )?;
929        Ok(vec![Some(grad_tensor)])
930    }
931
932    fn inputs(&self) -> Vec<&Tensor<T>> {
933        vec![&self.input]
934    }
935
936    fn name(&self) -> &'static str {
937        "MaxPool1dBackward"
938    }
939}
940
941// ===========================================================================
942// MaxPool3d — CL-315
943// ===========================================================================
944
/// Max pooling over a three-dimensional volume.
///
/// A `kernel_size` window moves across the `[D, H, W]` axes in steps of
/// `stride`; the largest element under each window position becomes one
/// output element. The layer owns no learnable parameters.
///
/// Input shape: `[B, C, D, H, W]`
/// Output shape: `[B, C, D_out, H_out, W_out]`
#[derive(Debug, Clone)]
pub struct MaxPool3d {
    /// Window extent along `[D, H, W]`.
    pub kernel_size: [usize; 3],
    /// Step between windows along `[D, H, W]`.
    pub stride: [usize; 3],
    /// Implicit padding along `[D, H, W]`; padded positions are skipped by
    /// the forward pass.
    pub padding: [usize; 3],
}

impl MaxPool3d {
    /// Build a `MaxPool3d` layer.
    ///
    /// A `stride` of exactly `[0, 0, 0]` is shorthand for "use `kernel_size`
    /// as the stride". Any other value, including one with only some zero
    /// components, is stored unchanged.
    pub fn new(kernel_size: [usize; 3], stride: [usize; 3], padding: [usize; 3]) -> Self {
        let stride = match stride {
            [0, 0, 0] => kernel_size,
            explicit => explicit,
        };
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}
976
977impl<T: Float> Module<T> for MaxPool3d {
978    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
979        max_pool3d_forward(input, self.kernel_size, self.stride, self.padding)
980    }
981
982    fn parameters(&self) -> Vec<&Parameter<T>> {
983        vec![]
984    }
985
986    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
987        vec![]
988    }
989
990    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
991        vec![]
992    }
993
994    fn train(&mut self) {}
995    fn eval(&mut self) {}
996
997    fn is_training(&self) -> bool {
998        false
999    }
1000}
1001
1002/// Forward computation for 3D max pooling.
1003fn max_pool3d_forward<T: Float>(
1004    input: &Tensor<T>,
1005    kernel_size: [usize; 3],
1006    stride: [usize; 3],
1007    padding: [usize; 3],
1008) -> FerrotorchResult<Tensor<T>> {
1009    let (batch, channels, d, h, w) = validate_5d(input)?;
1010    let (out_d, out_h, out_w) = validate_pool_params_3d(d, h, w, kernel_size, stride, padding)?;
1011
1012    let input_device = input.device();
1013    let data = input.data_vec()?;
1014    let total = batch * channels * out_d * out_h * out_w;
1015    let mut output = vec![<T as num_traits::Zero>::zero(); total];
1016    let mut indices = vec![0usize; total];
1017    let neg_inf = T::from(-1e38).unwrap();
1018
1019    for b in 0..batch {
1020        for c in 0..channels {
1021            for od in 0..out_d {
1022                for oh in 0..out_h {
1023                    for ow in 0..out_w {
1024                        let out_idx = (((b * channels + c) * out_d + od) * out_h + oh) * out_w + ow;
1025                        let mut max_val = neg_inf;
1026                        let mut max_idx = 0usize;
1027
1028                        for kd in 0..kernel_size[0] {
1029                            let id = (od * stride[0] + kd) as isize - padding[0] as isize;
1030                            if id < 0 || id >= d as isize {
1031                                continue;
1032                            }
1033                            let id = id as usize;
1034                            for kh in 0..kernel_size[1] {
1035                                let ih = (oh * stride[1] + kh) as isize - padding[1] as isize;
1036                                if ih < 0 || ih >= h as isize {
1037                                    continue;
1038                                }
1039                                let ih = ih as usize;
1040                                for kw in 0..kernel_size[2] {
1041                                    let iw = (ow * stride[2] + kw) as isize - padding[2] as isize;
1042                                    if iw < 0 || iw >= w as isize {
1043                                        continue;
1044                                    }
1045                                    let iw = iw as usize;
1046                                    let in_idx = (((b * channels + c) * d + id) * h + ih) * w + iw;
1047                                    let val = data[in_idx];
1048                                    if val > max_val {
1049                                        max_val = val;
1050                                        max_idx = in_idx;
1051                                    }
1052                                }
1053                            }
1054                        }
1055
1056                        output[out_idx] = max_val;
1057                        indices[out_idx] = max_idx;
1058                    }
1059                }
1060            }
1061        }
1062    }
1063
1064    let out_shape = vec![batch, channels, out_d, out_h, out_w];
1065    let storage = TensorStorage::cpu(output);
1066
1067    if is_grad_enabled() && input.requires_grad() {
1068        Tensor::from_operation(
1069            storage,
1070            out_shape,
1071            Arc::new(MaxPool3dBackward {
1072                input: input.clone(),
1073                indices,
1074            }),
1075        )?
1076        .to(input_device)
1077    } else {
1078        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1079    }
1080}
1081
1082/// Backward for `MaxPool3d`.
1083#[derive(Debug)]
1084struct MaxPool3dBackward<T: Float> {
1085    input: Tensor<T>,
1086    indices: Vec<usize>,
1087}
1088
1089impl<T: Float> GradFn<T> for MaxPool3dBackward<T> {
1090    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1091        if !self.input.requires_grad() {
1092            return Ok(vec![None]);
1093        }
1094
1095        let go_data = grad_output.data_vec()?;
1096        let input_numel = self.input.numel();
1097        let mut grad_input = vec![<T as num_traits::Zero>::zero(); input_numel];
1098
1099        for (out_idx, &in_idx) in self.indices.iter().enumerate() {
1100            grad_input[in_idx] += go_data[out_idx];
1101        }
1102
1103        let grad_tensor = Tensor::from_storage(
1104            TensorStorage::cpu(grad_input),
1105            self.input.shape().to_vec(),
1106            false,
1107        )?;
1108        Ok(vec![Some(grad_tensor)])
1109    }
1110
1111    fn inputs(&self) -> Vec<&Tensor<T>> {
1112        vec![&self.input]
1113    }
1114
1115    fn name(&self) -> &'static str {
1116        "MaxPool3dBackward"
1117    }
1118}
1119
1120// ===========================================================================
1121// AvgPool1d — CL-315
1122// ===========================================================================
1123
/// Average pooling over a one-dimensional signal.
///
/// Each output element is the sum of the in-bounds elements under a
/// `kernel_size` window divided by `kernel_size`. The layer owns no
/// learnable parameters.
///
/// Input shape: `[B, C, L]`
/// Output shape: `[B, C, L_out]`
#[derive(Debug, Clone)]
pub struct AvgPool1d {
    /// Number of elements covered by each pooling window.
    pub kernel_size: usize,
    /// Distance between the starts of consecutive windows.
    pub stride: usize,
    /// Implicit padding on the length axis; padded positions add nothing to
    /// the window sum.
    pub padding: usize,
}

impl AvgPool1d {
    /// Build an `AvgPool1d` layer.
    ///
    /// A `stride` of `0` is shorthand for "use `kernel_size` as the stride";
    /// any other value is stored unchanged.
    pub fn new(kernel_size: usize, stride: usize, padding: usize) -> Self {
        let stride = match stride {
            0 => kernel_size,
            explicit => explicit,
        };
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}
1148
1149impl<T: Float> Module<T> for AvgPool1d {
1150    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1151        avg_pool1d_forward(input, self.kernel_size, self.stride, self.padding)
1152    }
1153
1154    fn parameters(&self) -> Vec<&Parameter<T>> {
1155        vec![]
1156    }
1157
1158    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1159        vec![]
1160    }
1161
1162    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1163        vec![]
1164    }
1165
1166    fn train(&mut self) {}
1167    fn eval(&mut self) {}
1168
1169    fn is_training(&self) -> bool {
1170        false
1171    }
1172}
1173
1174/// Forward computation for 1D average pooling.
1175fn avg_pool1d_forward<T: Float>(
1176    input: &Tensor<T>,
1177    kernel_size: usize,
1178    stride: usize,
1179    padding: usize,
1180) -> FerrotorchResult<Tensor<T>> {
1181    let (batch, channels, l) = validate_3d(input)?;
1182    let out_l = validate_pool_params_1d(l, kernel_size, stride, padding)?;
1183
1184    let input_device = input.device();
1185    let data = input.data_vec()?;
1186    let total = batch * channels * out_l;
1187    let mut output = vec![<T as num_traits::Zero>::zero(); total];
1188    let kernel_area = T::from(kernel_size).unwrap();
1189
1190    for b in 0..batch {
1191        for c in 0..channels {
1192            for ol in 0..out_l {
1193                let out_idx = (b * channels + c) * out_l + ol;
1194                let mut sum = <T as num_traits::Zero>::zero();
1195
1196                for k in 0..kernel_size {
1197                    let il = (ol * stride + k) as isize - padding as isize;
1198                    if il >= 0 && il < l as isize {
1199                        let il = il as usize;
1200                        let in_idx = (b * channels + c) * l + il;
1201                        sum += data[in_idx];
1202                    }
1203                }
1204
1205                output[out_idx] = sum / kernel_area;
1206            }
1207        }
1208    }
1209
1210    let out_shape = vec![batch, channels, out_l];
1211    let storage = TensorStorage::cpu(output);
1212
1213    if is_grad_enabled() && input.requires_grad() {
1214        Tensor::from_operation(
1215            storage,
1216            out_shape,
1217            Arc::new(AvgPool1dBackward {
1218                input: input.clone(),
1219                kernel_size,
1220                stride,
1221                padding,
1222            }),
1223        )?
1224        .to(input_device)
1225    } else {
1226        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1227    }
1228}
1229
1230/// Backward for `AvgPool1d`.
1231#[derive(Debug)]
1232struct AvgPool1dBackward<T: Float> {
1233    input: Tensor<T>,
1234    kernel_size: usize,
1235    stride: usize,
1236    padding: usize,
1237}
1238
1239impl<T: Float> GradFn<T> for AvgPool1dBackward<T> {
1240    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1241        if !self.input.requires_grad() {
1242            return Ok(vec![None]);
1243        }
1244
1245        let go_data = grad_output.data_vec()?;
1246        let in_shape = self.input.shape();
1247        let (batch, channels, l) = (in_shape[0], in_shape[1], in_shape[2]);
1248        let out_l = pool_output_size(l, self.kernel_size, self.stride, self.padding);
1249
1250        let mut grad_input = vec![<T as num_traits::Zero>::zero(); batch * channels * l];
1251        let kernel_area = T::from(self.kernel_size).unwrap();
1252
1253        for b in 0..batch {
1254            for c in 0..channels {
1255                for ol in 0..out_l {
1256                    let out_idx = (b * channels + c) * out_l + ol;
1257                    let grad_val = go_data[out_idx] / kernel_area;
1258
1259                    for k in 0..self.kernel_size {
1260                        let il = (ol * self.stride + k) as isize - self.padding as isize;
1261                        if il >= 0 && il < l as isize {
1262                            let il = il as usize;
1263                            let in_idx = (b * channels + c) * l + il;
1264                            grad_input[in_idx] += grad_val;
1265                        }
1266                    }
1267                }
1268            }
1269        }
1270
1271        let grad_tensor = Tensor::from_storage(
1272            TensorStorage::cpu(grad_input),
1273            self.input.shape().to_vec(),
1274            false,
1275        )?;
1276        Ok(vec![Some(grad_tensor)])
1277    }
1278
1279    fn inputs(&self) -> Vec<&Tensor<T>> {
1280        vec![&self.input]
1281    }
1282
1283    fn name(&self) -> &'static str {
1284        "AvgPool1dBackward"
1285    }
1286}
1287
1288// ===========================================================================
1289// AvgPool3d — CL-315
1290// ===========================================================================
1291
/// Average pooling over a three-dimensional volume.
///
/// Each output element is the sum of the in-bounds elements under a
/// `kernel_size` window divided by the window volume. The layer owns no
/// learnable parameters.
///
/// Input shape: `[B, C, D, H, W]`
/// Output shape: `[B, C, D_out, H_out, W_out]`
#[derive(Debug, Clone)]
pub struct AvgPool3d {
    /// Window extent along `[D, H, W]`.
    pub kernel_size: [usize; 3],
    /// Step between windows along `[D, H, W]`.
    pub stride: [usize; 3],
    /// Implicit padding along `[D, H, W]`; padded positions add nothing to
    /// the window sum.
    pub padding: [usize; 3],
}

impl AvgPool3d {
    /// Build an `AvgPool3d` layer.
    ///
    /// A `stride` of exactly `[0, 0, 0]` is shorthand for "use `kernel_size`
    /// as the stride". Any other value, including one with only some zero
    /// components, is stored unchanged.
    pub fn new(kernel_size: [usize; 3], stride: [usize; 3], padding: [usize; 3]) -> Self {
        let stride = match stride {
            [0, 0, 0] => kernel_size,
            explicit => explicit,
        };
        Self {
            kernel_size,
            stride,
            padding,
        }
    }
}
1320
1321impl<T: Float> Module<T> for AvgPool3d {
1322    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1323        avg_pool3d_forward(input, self.kernel_size, self.stride, self.padding)
1324    }
1325
1326    fn parameters(&self) -> Vec<&Parameter<T>> {
1327        vec![]
1328    }
1329
1330    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1331        vec![]
1332    }
1333
1334    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1335        vec![]
1336    }
1337
1338    fn train(&mut self) {}
1339    fn eval(&mut self) {}
1340
1341    fn is_training(&self) -> bool {
1342        false
1343    }
1344}
1345
1346/// Forward computation for 3D average pooling.
1347fn avg_pool3d_forward<T: Float>(
1348    input: &Tensor<T>,
1349    kernel_size: [usize; 3],
1350    stride: [usize; 3],
1351    padding: [usize; 3],
1352) -> FerrotorchResult<Tensor<T>> {
1353    let (batch, channels, d, h, w) = validate_5d(input)?;
1354    let (out_d, out_h, out_w) = validate_pool_params_3d(d, h, w, kernel_size, stride, padding)?;
1355
1356    let input_device = input.device();
1357    let data = input.data_vec()?;
1358    let total = batch * channels * out_d * out_h * out_w;
1359    let mut output = vec![<T as num_traits::Zero>::zero(); total];
1360    let kernel_vol = T::from(kernel_size[0] * kernel_size[1] * kernel_size[2]).unwrap();
1361
1362    for b in 0..batch {
1363        for c in 0..channels {
1364            for od in 0..out_d {
1365                for oh in 0..out_h {
1366                    for ow in 0..out_w {
1367                        let out_idx = (((b * channels + c) * out_d + od) * out_h + oh) * out_w + ow;
1368                        let mut sum = <T as num_traits::Zero>::zero();
1369
1370                        for kd in 0..kernel_size[0] {
1371                            let id = (od * stride[0] + kd) as isize - padding[0] as isize;
1372                            if id < 0 || id >= d as isize {
1373                                continue;
1374                            }
1375                            let id = id as usize;
1376                            for kh in 0..kernel_size[1] {
1377                                let ih = (oh * stride[1] + kh) as isize - padding[1] as isize;
1378                                if ih < 0 || ih >= h as isize {
1379                                    continue;
1380                                }
1381                                let ih = ih as usize;
1382                                for kw in 0..kernel_size[2] {
1383                                    let iw = (ow * stride[2] + kw) as isize - padding[2] as isize;
1384                                    if iw < 0 || iw >= w as isize {
1385                                        continue;
1386                                    }
1387                                    let iw = iw as usize;
1388                                    let in_idx = (((b * channels + c) * d + id) * h + ih) * w + iw;
1389                                    sum += data[in_idx];
1390                                }
1391                            }
1392                        }
1393
1394                        output[out_idx] = sum / kernel_vol;
1395                    }
1396                }
1397            }
1398        }
1399    }
1400
1401    let out_shape = vec![batch, channels, out_d, out_h, out_w];
1402    let storage = TensorStorage::cpu(output);
1403
1404    if is_grad_enabled() && input.requires_grad() {
1405        Tensor::from_operation(
1406            storage,
1407            out_shape,
1408            Arc::new(AvgPool3dBackward {
1409                input: input.clone(),
1410                kernel_size,
1411                stride,
1412                padding,
1413            }),
1414        )?
1415        .to(input_device)
1416    } else {
1417        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1418    }
1419}
1420
1421/// Backward for `AvgPool3d`.
1422#[derive(Debug)]
1423struct AvgPool3dBackward<T: Float> {
1424    input: Tensor<T>,
1425    kernel_size: [usize; 3],
1426    stride: [usize; 3],
1427    padding: [usize; 3],
1428}
1429
1430impl<T: Float> GradFn<T> for AvgPool3dBackward<T> {
1431    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1432        if !self.input.requires_grad() {
1433            return Ok(vec![None]);
1434        }
1435
1436        let go_data = grad_output.data_vec()?;
1437        let in_shape = self.input.shape();
1438        let (batch, channels, d, h, w) = (
1439            in_shape[0],
1440            in_shape[1],
1441            in_shape[2],
1442            in_shape[3],
1443            in_shape[4],
1444        );
1445        let out_d = pool_output_size(d, self.kernel_size[0], self.stride[0], self.padding[0]);
1446        let out_h = pool_output_size(h, self.kernel_size[1], self.stride[1], self.padding[1]);
1447        let out_w = pool_output_size(w, self.kernel_size[2], self.stride[2], self.padding[2]);
1448
1449        let mut grad_input = vec![<T as num_traits::Zero>::zero(); batch * channels * d * h * w];
1450        let kernel_vol =
1451            T::from(self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]).unwrap();
1452
1453        for b in 0..batch {
1454            for c in 0..channels {
1455                for od in 0..out_d {
1456                    for oh in 0..out_h {
1457                        for ow in 0..out_w {
1458                            let out_idx =
1459                                (((b * channels + c) * out_d + od) * out_h + oh) * out_w + ow;
1460                            let grad_val = go_data[out_idx] / kernel_vol;
1461
1462                            for kd in 0..self.kernel_size[0] {
1463                                let id =
1464                                    (od * self.stride[0] + kd) as isize - self.padding[0] as isize;
1465                                if id < 0 || id >= d as isize {
1466                                    continue;
1467                                }
1468                                let id = id as usize;
1469                                for kh in 0..self.kernel_size[1] {
1470                                    let ih = (oh * self.stride[1] + kh) as isize
1471                                        - self.padding[1] as isize;
1472                                    if ih < 0 || ih >= h as isize {
1473                                        continue;
1474                                    }
1475                                    let ih = ih as usize;
1476                                    for kw in 0..self.kernel_size[2] {
1477                                        let iw = (ow * self.stride[2] + kw) as isize
1478                                            - self.padding[2] as isize;
1479                                        if iw < 0 || iw >= w as isize {
1480                                            continue;
1481                                        }
1482                                        let iw = iw as usize;
1483                                        let in_idx =
1484                                            (((b * channels + c) * d + id) * h + ih) * w + iw;
1485                                        grad_input[in_idx] += grad_val;
1486                                    }
1487                                }
1488                            }
1489                        }
1490                    }
1491                }
1492            }
1493        }
1494
1495        let grad_tensor = Tensor::from_storage(
1496            TensorStorage::cpu(grad_input),
1497            self.input.shape().to_vec(),
1498            false,
1499        )?;
1500        Ok(vec![Some(grad_tensor)])
1501    }
1502
1503    fn inputs(&self) -> Vec<&Tensor<T>> {
1504        vec![&self.input]
1505    }
1506
1507    fn name(&self) -> &'static str {
1508        "AvgPool3dBackward"
1509    }
1510}
1511
1512// ===========================================================================
1513// AdaptiveMaxPool2d — CL-315
1514// ===========================================================================
1515
/// Adaptive max pooling over a two-dimensional feature map.
///
/// Window boundaries are derived from the input's spatial size so that the
/// result always has the requested `output_size`, whatever `H` and `W` are.
/// The forward pass returns only the pooled tensor; the arg-max indices are
/// kept inside the autograd node for gradient routing.
///
/// Input shape: `[B, C, H, W]`
/// Output shape: `[B, C, output_size.0, output_size.1]`
#[derive(Debug, Clone)]
pub struct AdaptiveMaxPool2d {
    /// Target spatial size as `(height, width)`.
    pub output_size: (usize, usize),
}

impl AdaptiveMaxPool2d {
    /// Build an `AdaptiveMaxPool2d` that produces `output_size` outputs.
    ///
    /// The value is stored as given; a zero component is rejected later, by
    /// the forward pass.
    pub fn new(output_size: (usize, usize)) -> Self {
        AdaptiveMaxPool2d { output_size }
    }
}
1535
1536impl<T: Float> Module<T> for AdaptiveMaxPool2d {
1537    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1538        adaptive_max_pool2d_forward(input, self.output_size)
1539    }
1540
1541    fn parameters(&self) -> Vec<&Parameter<T>> {
1542        vec![]
1543    }
1544
1545    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1546        vec![]
1547    }
1548
1549    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1550        vec![]
1551    }
1552
1553    fn train(&mut self) {}
1554    fn eval(&mut self) {}
1555
1556    fn is_training(&self) -> bool {
1557        false
1558    }
1559}
1560
1561/// Forward computation for adaptive max pooling.
1562fn adaptive_max_pool2d_forward<T: Float>(
1563    input: &Tensor<T>,
1564    output_size: (usize, usize),
1565) -> FerrotorchResult<Tensor<T>> {
1566    let (batch, channels, h, w) = validate_4d(input)?;
1567    let (out_h, out_w) = output_size;
1568
1569    if out_h == 0 || out_w == 0 {
1570        return Err(FerrotorchError::InvalidArgument {
1571            message: "adaptive output_size must be > 0".into(),
1572        });
1573    }
1574
1575    let input_device = input.device();
1576    let data = input.data_vec()?;
1577    let total = batch * channels * out_h * out_w;
1578    let mut output = vec![<T as num_traits::Zero>::zero(); total];
1579    let mut indices = vec![0usize; total];
1580    let neg_inf = T::from(-1e38).unwrap();
1581
1582    for b in 0..batch {
1583        for c in 0..channels {
1584            for oh in 0..out_h {
1585                let h_start = adaptive_start(oh, h, out_h);
1586                let h_end = adaptive_end(oh, h, out_h);
1587
1588                for ow in 0..out_w {
1589                    let w_start = adaptive_start(ow, w, out_w);
1590                    let w_end = adaptive_end(ow, w, out_w);
1591
1592                    let out_idx = ((b * channels + c) * out_h + oh) * out_w + ow;
1593                    let mut max_val = neg_inf;
1594                    let mut max_idx = 0usize;
1595
1596                    for ih in h_start..h_end {
1597                        for iw in w_start..w_end {
1598                            let in_idx = ((b * channels + c) * h + ih) * w + iw;
1599                            let val = data[in_idx];
1600                            if val > max_val {
1601                                max_val = val;
1602                                max_idx = in_idx;
1603                            }
1604                        }
1605                    }
1606
1607                    output[out_idx] = max_val;
1608                    indices[out_idx] = max_idx;
1609                }
1610            }
1611        }
1612    }
1613
1614    let out_shape = vec![batch, channels, out_h, out_w];
1615    let storage = TensorStorage::cpu(output);
1616
1617    if is_grad_enabled() && input.requires_grad() {
1618        Tensor::from_operation(
1619            storage,
1620            out_shape,
1621            Arc::new(AdaptiveMaxPool2dBackward {
1622                input: input.clone(),
1623                indices,
1624            }),
1625        )?
1626        .to(input_device)
1627    } else {
1628        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1629    }
1630}
1631
1632/// Backward for `AdaptiveMaxPool2d`.
1633#[derive(Debug)]
1634struct AdaptiveMaxPool2dBackward<T: Float> {
1635    input: Tensor<T>,
1636    indices: Vec<usize>,
1637}
1638
1639impl<T: Float> GradFn<T> for AdaptiveMaxPool2dBackward<T> {
1640    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1641        if !self.input.requires_grad() {
1642            return Ok(vec![None]);
1643        }
1644
1645        let go_data = grad_output.data_vec()?;
1646        let input_numel = self.input.numel();
1647        let mut grad_input = vec![<T as num_traits::Zero>::zero(); input_numel];
1648
1649        for (out_idx, &in_idx) in self.indices.iter().enumerate() {
1650            grad_input[in_idx] += go_data[out_idx];
1651        }
1652
1653        let grad_tensor = Tensor::from_storage(
1654            TensorStorage::cpu(grad_input),
1655            self.input.shape().to_vec(),
1656            false,
1657        )?;
1658        Ok(vec![Some(grad_tensor)])
1659    }
1660
1661    fn inputs(&self) -> Vec<&Tensor<T>> {
1662        vec![&self.input]
1663    }
1664
1665    fn name(&self) -> &'static str {
1666        "AdaptiveMaxPool2dBackward"
1667    }
1668}
1669
1670// ===========================================================================
1671// AdaptiveAvgPool1d — CL-315
1672// ===========================================================================
1673
/// Adaptive average pooling over a one-dimensional signal.
///
/// Window boundaries are derived from the input length so that the result
/// always has `output_size` elements along the last axis.
///
/// Input shape: `[B, C, L]`
/// Output shape: `[B, C, output_size]`
#[derive(Debug, Clone)]
pub struct AdaptiveAvgPool1d {
    /// Target length of the pooled axis.
    pub output_size: usize,
}

impl AdaptiveAvgPool1d {
    /// Build an `AdaptiveAvgPool1d` that produces `output_size` outputs.
    ///
    /// The value is stored as given; zero is rejected later, by the forward
    /// pass.
    pub fn new(output_size: usize) -> Self {
        AdaptiveAvgPool1d { output_size }
    }
}
1689
1690impl<T: Float> Module<T> for AdaptiveAvgPool1d {
1691    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1692        adaptive_avg_pool1d_forward(input, self.output_size)
1693    }
1694
1695    fn parameters(&self) -> Vec<&Parameter<T>> {
1696        vec![]
1697    }
1698
1699    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1700        vec![]
1701    }
1702
1703    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1704        vec![]
1705    }
1706
1707    fn train(&mut self) {}
1708    fn eval(&mut self) {}
1709
1710    fn is_training(&self) -> bool {
1711        false
1712    }
1713}
1714
1715/// Forward computation for 1D adaptive average pooling.
1716fn adaptive_avg_pool1d_forward<T: Float>(
1717    input: &Tensor<T>,
1718    output_size: usize,
1719) -> FerrotorchResult<Tensor<T>> {
1720    let (batch, channels, l) = validate_3d(input)?;
1721    let out_l = output_size;
1722
1723    if out_l == 0 {
1724        return Err(FerrotorchError::InvalidArgument {
1725            message: "adaptive output_size must be > 0".into(),
1726        });
1727    }
1728
1729    let input_device = input.device();
1730    let data = input.data_vec()?;
1731    let total = batch * channels * out_l;
1732    let mut output = vec![<T as num_traits::Zero>::zero(); total];
1733
1734    for b in 0..batch {
1735        for c in 0..channels {
1736            for ol in 0..out_l {
1737                let l_start = adaptive_start(ol, l, out_l);
1738                let l_end = adaptive_end(ol, l, out_l);
1739                let window = l_end - l_start;
1740                let mut sum = <T as num_traits::Zero>::zero();
1741
1742                for il in l_start..l_end {
1743                    let in_idx = (b * channels + c) * l + il;
1744                    sum += data[in_idx];
1745                }
1746
1747                let out_idx = (b * channels + c) * out_l + ol;
1748                output[out_idx] = sum / T::from(window).unwrap();
1749            }
1750        }
1751    }
1752
1753    let out_shape = vec![batch, channels, out_l];
1754    let storage = TensorStorage::cpu(output);
1755
1756    if is_grad_enabled() && input.requires_grad() {
1757        Tensor::from_operation(
1758            storage,
1759            out_shape,
1760            Arc::new(AdaptiveAvgPool1dBackward {
1761                input: input.clone(),
1762                output_size,
1763            }),
1764        )?
1765        .to(input_device)
1766    } else {
1767        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1768    }
1769}
1770
1771/// Backward for `AdaptiveAvgPool1d`.
1772#[derive(Debug)]
1773struct AdaptiveAvgPool1dBackward<T: Float> {
1774    input: Tensor<T>,
1775    output_size: usize,
1776}
1777
1778impl<T: Float> GradFn<T> for AdaptiveAvgPool1dBackward<T> {
1779    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1780        if !self.input.requires_grad() {
1781            return Ok(vec![None]);
1782        }
1783
1784        let go_data = grad_output.data_vec()?;
1785        let in_shape = self.input.shape();
1786        let (batch, channels, l) = (in_shape[0], in_shape[1], in_shape[2]);
1787        let out_l = self.output_size;
1788
1789        let mut grad_input = vec![<T as num_traits::Zero>::zero(); batch * channels * l];
1790
1791        for b in 0..batch {
1792            for c in 0..channels {
1793                for ol in 0..out_l {
1794                    let l_start = adaptive_start(ol, l, out_l);
1795                    let l_end = adaptive_end(ol, l, out_l);
1796                    let window = l_end - l_start;
1797                    let out_idx = (b * channels + c) * out_l + ol;
1798                    let grad_val = go_data[out_idx] / T::from(window).unwrap();
1799
1800                    for il in l_start..l_end {
1801                        let in_idx = (b * channels + c) * l + il;
1802                        grad_input[in_idx] += grad_val;
1803                    }
1804                }
1805            }
1806        }
1807
1808        let grad_tensor = Tensor::from_storage(
1809            TensorStorage::cpu(grad_input),
1810            self.input.shape().to_vec(),
1811            false,
1812        )?;
1813        Ok(vec![Some(grad_tensor)])
1814    }
1815
1816    fn inputs(&self) -> Vec<&Tensor<T>> {
1817        vec![&self.input]
1818    }
1819
1820    fn name(&self) -> &'static str {
1821        "AdaptiveAvgPool1dBackward"
1822    }
1823}
1824
1825// ===========================================================================
1826// AdaptiveAvgPool3d — CL-315
1827// ===========================================================================
1828
/// Adaptive average pooling over three spatial dimensions.
///
/// Input shape: `[B, C, D, H, W]`
/// Output shape: `[B, C, output_size.0, output_size.1, output_size.2]`
#[derive(Debug, Clone)]
pub struct AdaptiveAvgPool3d {
    /// Target `(depth, height, width)` of the pooled output.
    pub output_size: (usize, usize, usize),
}

impl AdaptiveAvgPool3d {
    /// Builds a layer that pools any input down to `output_size`.
    pub fn new(output_size: (usize, usize, usize)) -> Self {
        AdaptiveAvgPool3d { output_size }
    }
}
1844
1845impl<T: Float> Module<T> for AdaptiveAvgPool3d {
1846    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1847        adaptive_avg_pool3d_forward(input, self.output_size)
1848    }
1849
1850    fn parameters(&self) -> Vec<&Parameter<T>> {
1851        vec![]
1852    }
1853
1854    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1855        vec![]
1856    }
1857
1858    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1859        vec![]
1860    }
1861
1862    fn train(&mut self) {}
1863    fn eval(&mut self) {}
1864
1865    fn is_training(&self) -> bool {
1866        false
1867    }
1868}
1869
1870/// Forward computation for 3D adaptive average pooling.
1871fn adaptive_avg_pool3d_forward<T: Float>(
1872    input: &Tensor<T>,
1873    output_size: (usize, usize, usize),
1874) -> FerrotorchResult<Tensor<T>> {
1875    let (batch, channels, d, h, w) = validate_5d(input)?;
1876    let (out_d, out_h, out_w) = output_size;
1877
1878    if out_d == 0 || out_h == 0 || out_w == 0 {
1879        return Err(FerrotorchError::InvalidArgument {
1880            message: "adaptive output_size must be > 0".into(),
1881        });
1882    }
1883
1884    let input_device = input.device();
1885    let data = input.data_vec()?;
1886    let total = batch * channels * out_d * out_h * out_w;
1887    let mut output = vec![<T as num_traits::Zero>::zero(); total];
1888
1889    for b in 0..batch {
1890        for c in 0..channels {
1891            for od in 0..out_d {
1892                let d_start = adaptive_start(od, d, out_d);
1893                let d_end = adaptive_end(od, d, out_d);
1894
1895                for oh in 0..out_h {
1896                    let h_start = adaptive_start(oh, h, out_h);
1897                    let h_end = adaptive_end(oh, h, out_h);
1898
1899                    for ow in 0..out_w {
1900                        let w_start = adaptive_start(ow, w, out_w);
1901                        let w_end = adaptive_end(ow, w, out_w);
1902
1903                        let window_vol = (d_end - d_start) * (h_end - h_start) * (w_end - w_start);
1904                        let mut sum = <T as num_traits::Zero>::zero();
1905
1906                        for id in d_start..d_end {
1907                            for ih in h_start..h_end {
1908                                for iw in w_start..w_end {
1909                                    let in_idx = (((b * channels + c) * d + id) * h + ih) * w + iw;
1910                                    sum += data[in_idx];
1911                                }
1912                            }
1913                        }
1914
1915                        let out_idx = (((b * channels + c) * out_d + od) * out_h + oh) * out_w + ow;
1916                        output[out_idx] = sum / T::from(window_vol).unwrap();
1917                    }
1918                }
1919            }
1920        }
1921    }
1922
1923    let out_shape = vec![batch, channels, out_d, out_h, out_w];
1924    let storage = TensorStorage::cpu(output);
1925
1926    if is_grad_enabled() && input.requires_grad() {
1927        Tensor::from_operation(
1928            storage,
1929            out_shape,
1930            Arc::new(AdaptiveAvgPool3dBackward {
1931                input: input.clone(),
1932                output_size,
1933            }),
1934        )?
1935        .to(input_device)
1936    } else {
1937        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1938    }
1939}
1940
1941/// Backward for `AdaptiveAvgPool3d`.
1942#[derive(Debug)]
1943struct AdaptiveAvgPool3dBackward<T: Float> {
1944    input: Tensor<T>,
1945    output_size: (usize, usize, usize),
1946}
1947
1948impl<T: Float> GradFn<T> for AdaptiveAvgPool3dBackward<T> {
1949    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1950        if !self.input.requires_grad() {
1951            return Ok(vec![None]);
1952        }
1953
1954        let go_data = grad_output.data_vec()?;
1955        let in_shape = self.input.shape();
1956        let (batch, channels, d, h, w) = (
1957            in_shape[0],
1958            in_shape[1],
1959            in_shape[2],
1960            in_shape[3],
1961            in_shape[4],
1962        );
1963        let (out_d, out_h, out_w) = self.output_size;
1964
1965        let mut grad_input = vec![<T as num_traits::Zero>::zero(); batch * channels * d * h * w];
1966
1967        for b in 0..batch {
1968            for c in 0..channels {
1969                for od in 0..out_d {
1970                    let d_start = adaptive_start(od, d, out_d);
1971                    let d_end = adaptive_end(od, d, out_d);
1972
1973                    for oh in 0..out_h {
1974                        let h_start = adaptive_start(oh, h, out_h);
1975                        let h_end = adaptive_end(oh, h, out_h);
1976
1977                        for ow in 0..out_w {
1978                            let w_start = adaptive_start(ow, w, out_w);
1979                            let w_end = adaptive_end(ow, w, out_w);
1980
1981                            let window_vol =
1982                                (d_end - d_start) * (h_end - h_start) * (w_end - w_start);
1983                            let out_idx =
1984                                (((b * channels + c) * out_d + od) * out_h + oh) * out_w + ow;
1985                            let grad_val = go_data[out_idx] / T::from(window_vol).unwrap();
1986
1987                            for id in d_start..d_end {
1988                                for ih in h_start..h_end {
1989                                    for iw in w_start..w_end {
1990                                        let in_idx =
1991                                            (((b * channels + c) * d + id) * h + ih) * w + iw;
1992                                        grad_input[in_idx] += grad_val;
1993                                    }
1994                                }
1995                            }
1996                        }
1997                    }
1998                }
1999            }
2000        }
2001
2002        let grad_tensor = Tensor::from_storage(
2003            TensorStorage::cpu(grad_input),
2004            self.input.shape().to_vec(),
2005            false,
2006        )?;
2007        Ok(vec![Some(grad_tensor)])
2008    }
2009
2010    fn inputs(&self) -> Vec<&Tensor<T>> {
2011        vec![&self.input]
2012    }
2013
2014    fn name(&self) -> &'static str {
2015        "AdaptiveAvgPool3dBackward"
2016    }
2017}
2018
2019// ===========================================================================
2020// MaxUnpool2d — CL-315
2021// ===========================================================================
2022
/// Inverse of `MaxPool2d`.
///
/// Takes a pooled tensor together with the indices of the max positions and
/// scatters the values back into an output tensor of the requested
/// `output_size`. Positions that no index points to stay zero.
///
/// Commonly used in encoder-decoder architectures (e.g. SegNet), where the
/// pooling indices from the encoder are reused in the decoder.
///
/// Input shape: `[B, C, H, W]` (the pooled tensor)
/// Output shape: `[B, C, output_size.0, output_size.1]`
#[derive(Debug, Clone)]
pub struct MaxUnpool2d {
    /// Kernel size of the matching pooling layer.
    pub kernel_size: [usize; 2],
    /// Stride of the matching pooling layer.
    pub stride: [usize; 2],
    /// Padding of the matching pooling layer.
    pub padding: [usize; 2],
}

impl MaxUnpool2d {
    /// Create a new `MaxUnpool2d` layer.
    ///
    /// Passing `[0, 0]` as `stride` selects the default, which is
    /// `kernel_size`; any other value is stored unchanged.
    pub fn new(kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2]) -> Self {
        let effective_stride = match stride {
            [0, 0] => kernel_size,
            explicit => explicit,
        };
        MaxUnpool2d {
            kernel_size,
            stride: effective_stride,
            padding,
        }
    }
}
2058
2059impl MaxUnpool2d {
2060    /// Forward pass for `MaxUnpool2d`.
2061    ///
2062    /// # Arguments
2063    ///
2064    /// * `input` - The pooled tensor, shape `[B, C, H, W]`.
2065    /// * `indices` - Flat indices from the corresponding `MaxPool2d` forward.
2066    /// * `output_size` - The desired spatial output size `(H_out, W_out)`.
2067    pub fn forward_with_indices<T: Float>(
2068        &self,
2069        input: &Tensor<T>,
2070        indices: &[usize],
2071        output_size: (usize, usize),
2072    ) -> FerrotorchResult<Tensor<T>> {
2073        max_unpool2d_forward(input, indices, output_size)
2074    }
2075}
2076
2077/// Functional API for `MaxUnpool2d`.
2078///
2079/// Scatters `input` values into an output of shape
2080/// `[B, C, output_size.0, output_size.1]` using the given flat `indices`.
2081pub fn max_unpool2d<T: Float>(
2082    input: &Tensor<T>,
2083    indices: &[usize],
2084    output_size: (usize, usize),
2085) -> FerrotorchResult<Tensor<T>> {
2086    max_unpool2d_forward(input, indices, output_size)
2087}
2088
2089/// Forward computation for max unpooling.
2090fn max_unpool2d_forward<T: Float>(
2091    input: &Tensor<T>,
2092    indices: &[usize],
2093    output_size: (usize, usize),
2094) -> FerrotorchResult<Tensor<T>> {
2095    let (batch, channels, _h, _w) = validate_4d(input)?;
2096    let (out_h, out_w) = output_size;
2097
2098    if input.numel() != indices.len() {
2099        return Err(FerrotorchError::InvalidArgument {
2100            message: format!(
2101                "MaxUnpool2d: input numel ({}) != indices len ({})",
2102                input.numel(),
2103                indices.len()
2104            ),
2105        });
2106    }
2107
2108    let input_device = input.device();
2109    let data = input.data_vec()?;
2110    let output_numel = batch * channels * out_h * out_w;
2111    let mut output = vec![<T as num_traits::Zero>::zero(); output_numel];
2112
2113    // Scatter values to the positions indicated by indices.
2114    for (i, &idx) in indices.iter().enumerate() {
2115        if idx >= output_numel {
2116            return Err(FerrotorchError::InvalidArgument {
2117                message: format!(
2118                    "MaxUnpool2d: index {} out of bounds for output size {}",
2119                    idx, output_numel
2120                ),
2121            });
2122        }
2123        output[idx] = data[i];
2124    }
2125
2126    let out_shape = vec![batch, channels, out_h, out_w];
2127    let storage = TensorStorage::cpu(output);
2128
2129    if is_grad_enabled() && input.requires_grad() {
2130        Tensor::from_operation(
2131            storage,
2132            out_shape,
2133            Arc::new(MaxUnpool2dBackward {
2134                input: input.clone(),
2135                indices: indices.to_vec(),
2136            }),
2137        )?
2138        .to(input_device)
2139    } else {
2140        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
2141    }
2142}
2143
2144/// Backward for `MaxUnpool2d`.
2145///
2146/// The gradient simply gathers values from the upstream gradient at the
2147/// index positions (the reverse of the scatter in forward).
2148#[derive(Debug)]
2149struct MaxUnpool2dBackward<T: Float> {
2150    input: Tensor<T>,
2151    indices: Vec<usize>,
2152}
2153
2154impl<T: Float> GradFn<T> for MaxUnpool2dBackward<T> {
2155    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2156        if !self.input.requires_grad() {
2157            return Ok(vec![None]);
2158        }
2159
2160        let go_data = grad_output.data_vec()?;
2161        let mut grad_input = vec![<T as num_traits::Zero>::zero(); self.input.numel()];
2162
2163        // Gather: grad_input[i] = grad_output[indices[i]]
2164        for (i, &idx) in self.indices.iter().enumerate() {
2165            grad_input[i] = go_data[idx];
2166        }
2167
2168        let grad_tensor = Tensor::from_storage(
2169            TensorStorage::cpu(grad_input),
2170            self.input.shape().to_vec(),
2171            false,
2172        )?;
2173        Ok(vec![Some(grad_tensor)])
2174    }
2175
2176    fn inputs(&self) -> Vec<&Tensor<T>> {
2177        vec![&self.input]
2178    }
2179
2180    fn name(&self) -> &'static str {
2181        "MaxUnpool2dBackward"
2182    }
2183}
2184
2185// ===========================================================================
2186// AdaptiveMaxPool1d — CL-432
2187// ===========================================================================
2188
/// Adaptive max pooling over one spatial dimension.
///
/// Window boundaries are derived from the input length at call time so that
/// the output always has `output_size` positions. The forward pass returns the
/// pooled tensor; the positions of the maxima are kept on the autograd node
/// for gradient routing.
///
/// Input shape: `[B, C, L]`
/// Output shape: `[B, C, output_size]`
#[derive(Debug, Clone)]
pub struct AdaptiveMaxPool1d {
    /// Target length of the pooled output.
    pub output_size: usize,
}

impl AdaptiveMaxPool1d {
    /// Builds a layer that pools any input down to `output_size` positions.
    pub fn new(output_size: usize) -> Self {
        AdaptiveMaxPool1d { output_size }
    }
}
2208
2209impl<T: Float> Module<T> for AdaptiveMaxPool1d {
2210    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2211        adaptive_max_pool1d_forward(input, self.output_size)
2212    }
2213
2214    fn parameters(&self) -> Vec<&Parameter<T>> {
2215        vec![]
2216    }
2217
2218    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
2219        vec![]
2220    }
2221
2222    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
2223        vec![]
2224    }
2225
2226    fn train(&mut self) {}
2227    fn eval(&mut self) {}
2228
2229    fn is_training(&self) -> bool {
2230        false
2231    }
2232}
2233
2234/// Forward computation for 1D adaptive max pooling.
2235fn adaptive_max_pool1d_forward<T: Float>(
2236    input: &Tensor<T>,
2237    output_size: usize,
2238) -> FerrotorchResult<Tensor<T>> {
2239    let (batch, channels, l) = validate_3d(input)?;
2240    let out_l = output_size;
2241
2242    if out_l == 0 {
2243        return Err(FerrotorchError::InvalidArgument {
2244            message: "adaptive output_size must be > 0".into(),
2245        });
2246    }
2247
2248    let input_device = input.device();
2249    let data = input.data_vec()?;
2250    let total = batch * channels * out_l;
2251    let mut output = vec![<T as num_traits::Zero>::zero(); total];
2252    let mut indices = vec![0usize; total];
2253    let neg_inf = T::from(-1e38).unwrap();
2254
2255    for b in 0..batch {
2256        for c in 0..channels {
2257            for ol in 0..out_l {
2258                let l_start = adaptive_start(ol, l, out_l);
2259                let l_end = adaptive_end(ol, l, out_l);
2260
2261                let out_idx = (b * channels + c) * out_l + ol;
2262                let mut max_val = neg_inf;
2263                let mut max_idx = 0usize;
2264
2265                for il in l_start..l_end {
2266                    let in_idx = (b * channels + c) * l + il;
2267                    let val = data[in_idx];
2268                    if val > max_val {
2269                        max_val = val;
2270                        max_idx = in_idx;
2271                    }
2272                }
2273
2274                output[out_idx] = max_val;
2275                indices[out_idx] = max_idx;
2276            }
2277        }
2278    }
2279
2280    let out_shape = vec![batch, channels, out_l];
2281    let storage = TensorStorage::cpu(output);
2282
2283    if is_grad_enabled() && input.requires_grad() {
2284        Tensor::from_operation(
2285            storage,
2286            out_shape,
2287            Arc::new(AdaptiveMaxPool1dBackward {
2288                input: input.clone(),
2289                indices,
2290            }),
2291        )?
2292        .to(input_device)
2293    } else {
2294        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
2295    }
2296}
2297
2298/// Backward for `AdaptiveMaxPool1d`.
2299#[derive(Debug)]
2300struct AdaptiveMaxPool1dBackward<T: Float> {
2301    input: Tensor<T>,
2302    indices: Vec<usize>,
2303}
2304
2305impl<T: Float> GradFn<T> for AdaptiveMaxPool1dBackward<T> {
2306    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2307        if !self.input.requires_grad() {
2308            return Ok(vec![None]);
2309        }
2310
2311        let go_data = grad_output.data_vec()?;
2312        let input_numel = self.input.numel();
2313        let mut grad_input = vec![<T as num_traits::Zero>::zero(); input_numel];
2314
2315        for (out_idx, &in_idx) in self.indices.iter().enumerate() {
2316            grad_input[in_idx] += go_data[out_idx];
2317        }
2318
2319        let grad_tensor = Tensor::from_storage(
2320            TensorStorage::cpu(grad_input),
2321            self.input.shape().to_vec(),
2322            false,
2323        )?;
2324        Ok(vec![Some(grad_tensor)])
2325    }
2326
2327    fn inputs(&self) -> Vec<&Tensor<T>> {
2328        vec![&self.input]
2329    }
2330
2331    fn name(&self) -> &'static str {
2332        "AdaptiveMaxPool1dBackward"
2333    }
2334}
2335
2336// ===========================================================================
2337// AdaptiveMaxPool3d — CL-432
2338// ===========================================================================
2339
/// Adaptive max pooling over three spatial dimensions.
///
/// Window boundaries are derived from the input's spatial size at call time so
/// that the output always has the requested `output_size`.
///
/// Input shape: `[B, C, D, H, W]`
/// Output shape: `[B, C, output_size.0, output_size.1, output_size.2]`
#[derive(Debug, Clone)]
pub struct AdaptiveMaxPool3d {
    /// Target `(depth, height, width)` of the pooled output.
    pub output_size: (usize, usize, usize),
}

impl AdaptiveMaxPool3d {
    /// Builds a layer that pools any input down to `output_size`.
    pub fn new(output_size: (usize, usize, usize)) -> Self {
        AdaptiveMaxPool3d { output_size }
    }
}
2358
2359impl<T: Float> Module<T> for AdaptiveMaxPool3d {
2360    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2361        adaptive_max_pool3d_forward(input, self.output_size)
2362    }
2363
2364    fn parameters(&self) -> Vec<&Parameter<T>> {
2365        vec![]
2366    }
2367
2368    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
2369        vec![]
2370    }
2371
2372    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
2373        vec![]
2374    }
2375
2376    fn train(&mut self) {}
2377    fn eval(&mut self) {}
2378
2379    fn is_training(&self) -> bool {
2380        false
2381    }
2382}
2383
2384/// Forward computation for 3D adaptive max pooling.
2385fn adaptive_max_pool3d_forward<T: Float>(
2386    input: &Tensor<T>,
2387    output_size: (usize, usize, usize),
2388) -> FerrotorchResult<Tensor<T>> {
2389    let (batch, channels, d, h, w) = validate_5d(input)?;
2390    let (out_d, out_h, out_w) = output_size;
2391
2392    if out_d == 0 || out_h == 0 || out_w == 0 {
2393        return Err(FerrotorchError::InvalidArgument {
2394            message: "adaptive output_size must be > 0".into(),
2395        });
2396    }
2397
2398    let input_device = input.device();
2399    let data = input.data_vec()?;
2400    let total = batch * channels * out_d * out_h * out_w;
2401    let mut output = vec![<T as num_traits::Zero>::zero(); total];
2402    let mut indices = vec![0usize; total];
2403    let neg_inf = T::from(-1e38).unwrap();
2404
2405    for b in 0..batch {
2406        for c in 0..channels {
2407            for od in 0..out_d {
2408                let d_start = adaptive_start(od, d, out_d);
2409                let d_end = adaptive_end(od, d, out_d);
2410
2411                for oh in 0..out_h {
2412                    let h_start = adaptive_start(oh, h, out_h);
2413                    let h_end = adaptive_end(oh, h, out_h);
2414
2415                    for ow in 0..out_w {
2416                        let w_start = adaptive_start(ow, w, out_w);
2417                        let w_end = adaptive_end(ow, w, out_w);
2418
2419                        let out_idx = (((b * channels + c) * out_d + od) * out_h + oh) * out_w + ow;
2420                        let mut max_val = neg_inf;
2421                        let mut max_idx = 0usize;
2422
2423                        for id in d_start..d_end {
2424                            for ih in h_start..h_end {
2425                                for iw in w_start..w_end {
2426                                    let in_idx = (((b * channels + c) * d + id) * h + ih) * w + iw;
2427                                    let val = data[in_idx];
2428                                    if val > max_val {
2429                                        max_val = val;
2430                                        max_idx = in_idx;
2431                                    }
2432                                }
2433                            }
2434                        }
2435
2436                        output[out_idx] = max_val;
2437                        indices[out_idx] = max_idx;
2438                    }
2439                }
2440            }
2441        }
2442    }
2443
2444    let out_shape = vec![batch, channels, out_d, out_h, out_w];
2445    let storage = TensorStorage::cpu(output);
2446
2447    if is_grad_enabled() && input.requires_grad() {
2448        Tensor::from_operation(
2449            storage,
2450            out_shape,
2451            Arc::new(AdaptiveMaxPool3dBackward {
2452                input: input.clone(),
2453                indices,
2454            }),
2455        )?
2456        .to(input_device)
2457    } else {
2458        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
2459    }
2460}
2461
2462/// Backward for `AdaptiveMaxPool3d`.
2463#[derive(Debug)]
2464struct AdaptiveMaxPool3dBackward<T: Float> {
2465    input: Tensor<T>,
2466    indices: Vec<usize>,
2467}
2468
2469impl<T: Float> GradFn<T> for AdaptiveMaxPool3dBackward<T> {
2470    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2471        if !self.input.requires_grad() {
2472            return Ok(vec![None]);
2473        }
2474
2475        let go_data = grad_output.data_vec()?;
2476        let input_numel = self.input.numel();
2477        let mut grad_input = vec![<T as num_traits::Zero>::zero(); input_numel];
2478
2479        for (out_idx, &in_idx) in self.indices.iter().enumerate() {
2480            grad_input[in_idx] += go_data[out_idx];
2481        }
2482
2483        let grad_tensor = Tensor::from_storage(
2484            TensorStorage::cpu(grad_input),
2485            self.input.shape().to_vec(),
2486            false,
2487        )?;
2488        Ok(vec![Some(grad_tensor)])
2489    }
2490
2491    fn inputs(&self) -> Vec<&Tensor<T>> {
2492        vec![&self.input]
2493    }
2494
2495    fn name(&self) -> &'static str {
2496        "AdaptiveMaxPool3dBackward"
2497    }
2498}
2499
2500// ===========================================================================
2501// FractionalMaxPool2d — CL-432
2502// ===========================================================================
2503
/// Fractional max pooling layer (2D).
///
/// Applies stochastic pooling as described in "Fractional Max-Pooling" by
/// Ben Graham. The output spatial dimensions are determined by `output_size`,
/// and the pooling regions are drawn pseudo-randomly on every forward pass.
///
/// The layer stores no training flag: its `train`/`eval` methods are no-ops
/// and `is_training` always reports `false`, so there is currently no
/// deterministic evaluation mode — two forward passes over the same input may
/// use different regions.
///
/// Input shape: `[B, C, H, W]`
/// Output shape: `[B, C, output_size.0, output_size.1]`
#[derive(Debug, Clone)]
pub struct FractionalMaxPool2d {
    /// Target `(height, width)` of the pooled output.
    pub output_size: (usize, usize),
}

impl FractionalMaxPool2d {
    /// Create a new `FractionalMaxPool2d` targeting the given output spatial size.
    ///
    /// The size is stored as given; it is validated against the input when the
    /// layer is run.
    pub fn new(output_size: (usize, usize)) -> Self {
        Self { output_size }
    }
}
2524
/// Generate fractional pooling boundaries using a pseudo-random offset.
///
/// Produces `output_size + 1` boundaries in `[0, input_size]`; window `i` is
/// the half-open range `[boundaries[i], boundaries[i + 1])`.
///
/// When `output_size < input_size` the boundaries follow the sequence from the
/// Fractional Max-Pooling paper, `ceil(alpha * (i + u)) - ceil(alpha * u)` with
/// `alpha = input_size / output_size` and a single `u` in `[0, 1)` derived
/// from `seed`. Using one shared offset (instead of rounding every boundary
/// independently) keeps the boundaries strictly increasing, so every window is
/// non-empty and its length is either `floor(alpha)` or `ceil(alpha)`. Each
/// boundary is additionally clamped so that this still holds if floating-point
/// rounding nudges a value across an integer.
///
/// When `output_size >= input_size` the boundaries are `min(i, input_size)`,
/// i.e. one input position per window until the input is exhausted (later
/// windows are empty). For `output_size == 0` the single boundary
/// `input_size` is returned.
///
/// The result depends only on the arguments: the same `seed` always yields the
/// same boundaries, and a `seed` of zero is as valid as any other.
fn fractional_boundaries(input_size: usize, output_size: usize, seed: u64) -> Vec<usize> {
    if output_size == 0 {
        return vec![input_size];
    }
    if output_size >= input_size {
        // Each output bin covers at most one input position.
        return (0..=output_size).map(|i| i.min(input_size)).collect();
    }

    // SplitMix64 finaliser: spreads the seed over all 64 bits and, unlike a
    // bare xorshift step, does not get stuck on a zero seed.
    let mut mixed = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
    mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    mixed ^= mixed >> 31;
    // The top 53 bits map exactly onto an f64 in [0, 1).
    let u = (mixed >> 11) as f64 / (1u64 << 53) as f64;

    let ratio = input_size as f64 / output_size as f64;
    let offset = (ratio * u).ceil();

    let mut boundaries = Vec::with_capacity(output_size + 1);
    boundaries.push(0);
    for i in 1..output_size {
        let raw = ((ratio * (i as f64 + u)).ceil() - offset) as usize;
        // Keep at least one element in this window and leave at least one for
        // each of the `output_size - i` windows that follow. `lowest <= highest`
        // holds by induction because `input_size > output_size`.
        let lowest = boundaries[i - 1] + 1;
        let highest = input_size - (output_size - i);
        boundaries.push(raw.clamp(lowest, highest));
    }

    // The final boundary always closes the last window at the end of the input.
    boundaries.push(input_size);
    boundaries
}
2563
2564impl<T: Float> Module<T> for FractionalMaxPool2d {
2565    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2566        let (batch, channels, h, w) = validate_4d(input)?;
2567        let (out_h, out_w) = self.output_size;
2568
2569        if out_h == 0 || out_w == 0 {
2570            return Err(FerrotorchError::InvalidArgument {
2571                message: "FractionalMaxPool2d: output_size must be > 0".into(),
2572            });
2573        }
2574        if out_h > h || out_w > w {
2575            return Err(FerrotorchError::InvalidArgument {
2576                message: format!(
2577                    "FractionalMaxPool2d: output_size ({out_h}, {out_w}) must be <= input ({h}, {w})"
2578                ),
2579            });
2580        }
2581
2582        let input_device = input.device();
2583        let data = input.data_vec()?;
2584
2585        // Generate random boundaries using a per-forward seed.
2586        let seed = {
2587            use std::collections::hash_map::DefaultHasher;
2588            use std::hash::{Hash, Hasher};
2589            use std::time::SystemTime;
2590
2591            let mut hasher = DefaultHasher::new();
2592            SystemTime::now().hash(&mut hasher);
2593            std::thread::current().id().hash(&mut hasher);
2594            hasher.finish()
2595        };
2596
2597        let h_bounds = fractional_boundaries(h, out_h, seed);
2598        let w_bounds = fractional_boundaries(w, out_w, seed.wrapping_mul(2654435761));
2599
2600        let total = batch * channels * out_h * out_w;
2601        let mut output = vec![<T as num_traits::Zero>::zero(); total];
2602        let mut indices = vec![0usize; total];
2603        let neg_inf = T::from(-1e38).unwrap();
2604
2605        for b in 0..batch {
2606            for c in 0..channels {
2607                for oh in 0..out_h {
2608                    let h_start = h_bounds[oh];
2609                    let h_end = h_bounds[oh + 1];
2610
2611                    for ow in 0..out_w {
2612                        let w_start = w_bounds[ow];
2613                        let w_end = w_bounds[ow + 1];
2614
2615                        let out_idx = ((b * channels + c) * out_h + oh) * out_w + ow;
2616                        let mut max_val = neg_inf;
2617                        let mut max_idx = 0usize;
2618
2619                        for ih in h_start..h_end {
2620                            for iw in w_start..w_end {
2621                                let in_idx = ((b * channels + c) * h + ih) * w + iw;
2622                                let val = data[in_idx];
2623                                if val > max_val {
2624                                    max_val = val;
2625                                    max_idx = in_idx;
2626                                }
2627                            }
2628                        }
2629
2630                        output[out_idx] = max_val;
2631                        indices[out_idx] = max_idx;
2632                    }
2633                }
2634            }
2635        }
2636
2637        let out_shape = vec![batch, channels, out_h, out_w];
2638        let storage = TensorStorage::cpu(output);
2639
2640        if is_grad_enabled() && input.requires_grad() {
2641            Tensor::from_operation(
2642                storage,
2643                out_shape,
2644                Arc::new(FractionalMaxPool2dBackward {
2645                    input: input.clone(),
2646                    indices,
2647                }),
2648            )?
2649            .to(input_device)
2650        } else {
2651            Tensor::from_storage(storage, out_shape, false)?.to(input_device)
2652        }
2653    }
2654
2655    fn parameters(&self) -> Vec<&Parameter<T>> {
2656        vec![]
2657    }
2658
2659    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
2660        vec![]
2661    }
2662
2663    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
2664        vec![]
2665    }
2666
    /// No-op: this implementation keeps no training flag to switch on.
    fn train(&mut self) {}
    /// No-op: this implementation keeps no training flag to switch off.
    fn eval(&mut self) {}
2669
    /// Always reports `false`; `train()` and `eval()` do not change this.
    fn is_training(&self) -> bool {
        false
    }
2673}
2674
/// Backward for `FractionalMaxPool2d`.
///
/// Carries what the forward pass recorded so the gradient can be routed
/// back: the input tensor and, for every output element, the flat offset of
/// the input element that was selected as the window maximum.
#[derive(Debug)]
struct FractionalMaxPool2dBackward<T: Float> {
    /// Input of the forward pass; the produced gradient has its shape.
    input: Tensor<T>,
    /// `indices[out_idx]` is the flat input offset chosen as the maximum for
    /// output element `out_idx`.
    indices: Vec<usize>,
}
2681
2682impl<T: Float> GradFn<T> for FractionalMaxPool2dBackward<T> {
2683    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2684        if !self.input.requires_grad() {
2685            return Ok(vec![None]);
2686        }
2687
2688        let go_data = grad_output.data_vec()?;
2689        let input_numel = self.input.numel();
2690        let mut grad_input = vec![<T as num_traits::Zero>::zero(); input_numel];
2691
2692        for (out_idx, &in_idx) in self.indices.iter().enumerate() {
2693            grad_input[in_idx] += go_data[out_idx];
2694        }
2695
2696        let grad_tensor = Tensor::from_storage(
2697            TensorStorage::cpu(grad_input),
2698            self.input.shape().to_vec(),
2699            false,
2700        )?;
2701        Ok(vec![Some(grad_tensor)])
2702    }
2703
2704    fn inputs(&self) -> Vec<&Tensor<T>> {
2705        vec![&self.input]
2706    }
2707
2708    fn name(&self) -> &'static str {
2709        "FractionalMaxPool2dBackward"
2710    }
2711}
2712
2713// ===========================================================================
2714// LPPool1d — CL-432
2715// ===========================================================================
2716
/// 1D Lp norm pooling layer.
///
/// Each pooling window is reduced to `(sum(|x|^p))^(1/p)`.
///
/// Input shape: `[B, C, L]`
/// Output shape: `[B, C, L_out]`
///
/// With `p == 1` every output is the sum of the absolute values in its
/// window (sum pooling, which is proportional to average pooling of the
/// absolute values). With `p == 2` it is the L2 (Euclidean) norm of the
/// window.
///
/// Modelled on `torch.nn.LPPool1d`.
#[derive(Debug, Clone)]
pub struct LPPool1d {
    /// Exponent `p` of the Lp norm.
    pub norm_type: f64,
    /// Number of elements in each pooling window.
    pub kernel_size: usize,
    /// Step between the starts of consecutive windows.
    pub stride: usize,
}

impl LPPool1d {
    /// Builds an `LPPool1d` layer.
    ///
    /// # Arguments
    ///
    /// * `norm_type` - The exponent `p` for the Lp norm.
    /// * `kernel_size` - Size of the pooling window.
    /// * `stride` - Stride of the pooling window. A value of `0` means
    ///   "use `kernel_size`".
    pub fn new(norm_type: f64, kernel_size: usize, stride: usize) -> Self {
        Self {
            norm_type,
            kernel_size,
            // `0` is the "not specified" marker and falls back to the kernel size.
            stride: match stride {
                0 => kernel_size,
                explicit => explicit,
            },
        }
    }
}
2752
2753impl<T: Float> Module<T> for LPPool1d {
2754    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2755        lp_pool1d_forward(input, self.norm_type, self.kernel_size, self.stride)
2756    }
2757
2758    fn parameters(&self) -> Vec<&Parameter<T>> {
2759        vec![]
2760    }
2761
2762    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
2763        vec![]
2764    }
2765
2766    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
2767        vec![]
2768    }
2769
2770    fn train(&mut self) {}
2771    fn eval(&mut self) {}
2772
2773    fn is_training(&self) -> bool {
2774        false
2775    }
2776}
2777
2778/// Forward computation for 1D Lp norm pooling.
2779fn lp_pool1d_forward<T: Float>(
2780    input: &Tensor<T>,
2781    norm_type: f64,
2782    kernel_size: usize,
2783    stride: usize,
2784) -> FerrotorchResult<Tensor<T>> {
2785    let (batch, channels, l) = validate_3d(input)?;
2786    let out_l = validate_pool_params_1d(l, kernel_size, stride, 0)?;
2787    let p_t = T::from(norm_type).unwrap();
2788    let inv_p = T::from(1.0 / norm_type).unwrap();
2789
2790    let input_device = input.device();
2791    let data = input.data_vec()?;
2792    let total = batch * channels * out_l;
2793    let mut output = vec![<T as num_traits::Zero>::zero(); total];
2794
2795    for b in 0..batch {
2796        for c in 0..channels {
2797            for ol in 0..out_l {
2798                let l_start = ol * stride;
2799                let l_end = (l_start + kernel_size).min(l);
2800
2801                let mut sum = <T as num_traits::Zero>::zero();
2802                for il in l_start..l_end {
2803                    let in_idx = (b * channels + c) * l + il;
2804                    sum += data[in_idx].abs().powf(p_t);
2805                }
2806
2807                let out_idx = (b * channels + c) * out_l + ol;
2808                output[out_idx] = sum.powf(inv_p);
2809            }
2810        }
2811    }
2812
2813    let out_shape = vec![batch, channels, out_l];
2814    let storage = TensorStorage::cpu(output);
2815
2816    if is_grad_enabled() && input.requires_grad() {
2817        Tensor::from_operation(
2818            storage,
2819            out_shape,
2820            Arc::new(LPPool1dBackward {
2821                input: input.clone(),
2822                norm_type,
2823                kernel_size,
2824                stride,
2825            }),
2826        )?
2827        .to(input_device)
2828    } else {
2829        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
2830    }
2831}
2832
/// Backward for `LPPool1d`.
///
/// For `y = (sum_j |x_j|^p)^(1/p)`, the gradient is:
/// `dy/dx_i = y^(1-p) * |x_i|^(p-1) * sign(x_i)`
/// and, because `sign(x_i) * |x_i| = x_i`, this can be rewritten as:
/// `dy/dx_i = x_i * |x_i|^(p-2) / y^(p-1)`
#[derive(Debug)]
struct LPPool1dBackward<T: Float> {
    /// Input of the forward pass; window outputs are recomputed from it.
    input: Tensor<T>,
    /// Exponent `p` of the Lp norm used in the forward pass.
    norm_type: f64,
    /// Pooling window length used in the forward pass.
    kernel_size: usize,
    /// Window step used in the forward pass.
    stride: usize,
}
2846
2847impl<T: Float> GradFn<T> for LPPool1dBackward<T> {
2848    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2849        if !self.input.requires_grad() {
2850            return Ok(vec![None]);
2851        }
2852
2853        let in_shape = self.input.shape();
2854        let (batch, channels, l) = (in_shape[0], in_shape[1], in_shape[2]);
2855        let out_l = (l - self.kernel_size) / self.stride + 1;
2856
2857        let input_data = self.input.data_vec()?;
2858        let go_data = grad_output.data_vec()?;
2859
2860        let p_t = T::from(self.norm_type).unwrap();
2861        let inv_p = T::from(1.0 / self.norm_type).unwrap();
2862        let p_minus_1 = T::from(self.norm_type - 1.0).unwrap();
2863        let p_minus_2 = T::from(self.norm_type - 2.0).unwrap();
2864        let eps = T::from(1e-12).unwrap();
2865
2866        let mut grad_input = vec![<T as num_traits::Zero>::zero(); self.input.numel()];
2867
2868        for b in 0..batch {
2869            for c in 0..channels {
2870                for ol in 0..out_l {
2871                    let l_start = ol * self.stride;
2872                    let l_end = (l_start + self.kernel_size).min(l);
2873
2874                    // Recompute the output value for this window.
2875                    let mut sum = <T as num_traits::Zero>::zero();
2876                    for il in l_start..l_end {
2877                        let in_idx = (b * channels + c) * l + il;
2878                        sum += input_data[in_idx].abs().powf(p_t);
2879                    }
2880                    let y = sum.powf(inv_p);
2881                    let y_p_minus_1 = y.powf(p_minus_1) + eps;
2882
2883                    let out_idx = (b * channels + c) * out_l + ol;
2884                    let go = go_data[out_idx];
2885
2886                    for il in l_start..l_end {
2887                        let in_idx = (b * channels + c) * l + il;
2888                        let x = input_data[in_idx];
2889                        // dy/dx_i = x_i * |x_i|^(p-2) / y^(p-1)
2890                        let grad_val = x * x.abs().powf(p_minus_2) / y_p_minus_1;
2891                        grad_input[in_idx] += go * grad_val;
2892                    }
2893                }
2894            }
2895        }
2896
2897        let grad_tensor = Tensor::from_storage(
2898            TensorStorage::cpu(grad_input),
2899            self.input.shape().to_vec(),
2900            false,
2901        )?;
2902        Ok(vec![Some(grad_tensor)])
2903    }
2904
2905    fn inputs(&self) -> Vec<&Tensor<T>> {
2906        vec![&self.input]
2907    }
2908
2909    fn name(&self) -> &'static str {
2910        "LPPool1dBackward"
2911    }
2912}
2913
2914// ===========================================================================
2915// LPPool2d — CL-432
2916// ===========================================================================
2917
/// 2D Lp norm pooling layer.
///
/// Each pooling window is reduced to `(sum(|x|^p))^(1/p)`.
///
/// Input shape: `[B, C, H, W]`
/// Output shape: `[B, C, H_out, W_out]`
///
/// Modelled on `torch.nn.LPPool2d`.
#[derive(Debug, Clone)]
pub struct LPPool2d {
    /// Exponent `p` of the Lp norm.
    pub norm_type: f64,
    /// Pooling window size `[kH, kW]`.
    pub kernel_size: [usize; 2],
    /// Window step `[sH, sW]`.
    pub stride: [usize; 2],
}

impl LPPool2d {
    /// Builds an `LPPool2d` layer.
    ///
    /// # Arguments
    ///
    /// * `norm_type` - The exponent `p` for the Lp norm.
    /// * `kernel_size` - Size of the pooling window `[kH, kW]`.
    /// * `stride` - Stride of the pooling window `[sH, sW]`. Any element equal
    ///   to `0` is replaced by the matching `kernel_size` element.
    pub fn new(norm_type: f64, kernel_size: [usize; 2], stride: [usize; 2]) -> Self {
        // `0` is the "not specified" marker for a stride component.
        let resolve = |requested: usize, fallback: usize| match requested {
            0 => fallback,
            explicit => explicit,
        };
        let [k_h, k_w] = kernel_size;
        let [s_h, s_w] = stride;

        Self {
            norm_type,
            kernel_size,
            stride: [resolve(s_h, k_h), resolve(s_w, k_w)],
        }
    }
}
2961
2962impl<T: Float> Module<T> for LPPool2d {
2963    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2964        lp_pool2d_forward(input, self.norm_type, self.kernel_size, self.stride)
2965    }
2966
2967    fn parameters(&self) -> Vec<&Parameter<T>> {
2968        vec![]
2969    }
2970
2971    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
2972        vec![]
2973    }
2974
2975    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
2976        vec![]
2977    }
2978
2979    fn train(&mut self) {}
2980    fn eval(&mut self) {}
2981
2982    fn is_training(&self) -> bool {
2983        false
2984    }
2985}
2986
2987/// Forward computation for 2D Lp norm pooling.
2988fn lp_pool2d_forward<T: Float>(
2989    input: &Tensor<T>,
2990    norm_type: f64,
2991    kernel_size: [usize; 2],
2992    stride: [usize; 2],
2993) -> FerrotorchResult<Tensor<T>> {
2994    let (batch, channels, h, w) = validate_4d(input)?;
2995    let out_h = validate_pool_params_1d(h, kernel_size[0], stride[0], 0)?;
2996    let out_w = validate_pool_params_1d(w, kernel_size[1], stride[1], 0)?;
2997    let p_t = T::from(norm_type).unwrap();
2998    let inv_p = T::from(1.0 / norm_type).unwrap();
2999
3000    let input_device = input.device();
3001    let data = input.data_vec()?;
3002    let total = batch * channels * out_h * out_w;
3003    let mut output = vec![<T as num_traits::Zero>::zero(); total];
3004
3005    for b in 0..batch {
3006        for c in 0..channels {
3007            for oh in 0..out_h {
3008                let h_start = oh * stride[0];
3009                let h_end = (h_start + kernel_size[0]).min(h);
3010
3011                for ow in 0..out_w {
3012                    let w_start = ow * stride[1];
3013                    let w_end = (w_start + kernel_size[1]).min(w);
3014
3015                    let mut sum = <T as num_traits::Zero>::zero();
3016                    for ih in h_start..h_end {
3017                        for iw in w_start..w_end {
3018                            let in_idx = ((b * channels + c) * h + ih) * w + iw;
3019                            sum += data[in_idx].abs().powf(p_t);
3020                        }
3021                    }
3022
3023                    let out_idx = ((b * channels + c) * out_h + oh) * out_w + ow;
3024                    output[out_idx] = sum.powf(inv_p);
3025                }
3026            }
3027        }
3028    }
3029
3030    let out_shape = vec![batch, channels, out_h, out_w];
3031    let storage = TensorStorage::cpu(output);
3032
3033    if is_grad_enabled() && input.requires_grad() {
3034        Tensor::from_operation(
3035            storage,
3036            out_shape,
3037            Arc::new(LPPool2dBackward {
3038                input: input.clone(),
3039                norm_type,
3040                kernel_size,
3041                stride,
3042            }),
3043        )?
3044        .to(input_device)
3045    } else {
3046        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
3047    }
3048}
3049
/// Backward for `LPPool2d`.
///
/// Stores the forward input and pooling configuration so that each window's
/// output can be recomputed when the gradient is evaluated.
#[derive(Debug)]
struct LPPool2dBackward<T: Float> {
    /// Input of the forward pass; window outputs are recomputed from it.
    input: Tensor<T>,
    /// Exponent `p` of the Lp norm used in the forward pass.
    norm_type: f64,
    /// Pooling window size `[kH, kW]` used in the forward pass.
    kernel_size: [usize; 2],
    /// Window step `[sH, sW]` used in the forward pass.
    stride: [usize; 2],
}
3058
3059impl<T: Float> GradFn<T> for LPPool2dBackward<T> {
3060    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
3061        if !self.input.requires_grad() {
3062            return Ok(vec![None]);
3063        }
3064
3065        let in_shape = self.input.shape();
3066        let (batch, channels, h, w) = (in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
3067        let out_h = (h - self.kernel_size[0]) / self.stride[0] + 1;
3068        let out_w = (w - self.kernel_size[1]) / self.stride[1] + 1;
3069
3070        let input_data = self.input.data_vec()?;
3071        let go_data = grad_output.data_vec()?;
3072
3073        let p_t = T::from(self.norm_type).unwrap();
3074        let inv_p = T::from(1.0 / self.norm_type).unwrap();
3075        let p_minus_1 = T::from(self.norm_type - 1.0).unwrap();
3076        let p_minus_2 = T::from(self.norm_type - 2.0).unwrap();
3077        let eps = T::from(1e-12).unwrap();
3078
3079        let mut grad_input = vec![<T as num_traits::Zero>::zero(); self.input.numel()];
3080
3081        for b in 0..batch {
3082            for c in 0..channels {
3083                for oh in 0..out_h {
3084                    let h_start = oh * self.stride[0];
3085                    let h_end = (h_start + self.kernel_size[0]).min(h);
3086
3087                    for ow in 0..out_w {
3088                        let w_start = ow * self.stride[1];
3089                        let w_end = (w_start + self.kernel_size[1]).min(w);
3090
3091                        // Recompute output for this window.
3092                        let mut sum = <T as num_traits::Zero>::zero();
3093                        for ih in h_start..h_end {
3094                            for iw in w_start..w_end {
3095                                let in_idx = ((b * channels + c) * h + ih) * w + iw;
3096                                sum += input_data[in_idx].abs().powf(p_t);
3097                            }
3098                        }
3099                        let y = sum.powf(inv_p);
3100                        let y_p_minus_1 = y.powf(p_minus_1) + eps;
3101
3102                        let out_idx = ((b * channels + c) * out_h + oh) * out_w + ow;
3103                        let go = go_data[out_idx];
3104
3105                        for ih in h_start..h_end {
3106                            for iw in w_start..w_end {
3107                                let in_idx = ((b * channels + c) * h + ih) * w + iw;
3108                                let x = input_data[in_idx];
3109                                let grad_val = x * x.abs().powf(p_minus_2) / y_p_minus_1;
3110                                grad_input[in_idx] += go * grad_val;
3111                            }
3112                        }
3113                    }
3114                }
3115            }
3116        }
3117
3118        let grad_tensor = Tensor::from_storage(
3119            TensorStorage::cpu(grad_input),
3120            self.input.shape().to_vec(),
3121            false,
3122        )?;
3123        Ok(vec![Some(grad_tensor)])
3124    }
3125
3126    fn inputs(&self) -> Vec<&Tensor<T>> {
3127        vec![&self.input]
3128    }
3129
3130    fn name(&self) -> &'static str {
3131        "LPPool2dBackward"
3132    }
3133}
3134
3135// ===========================================================================
3136// Public functional API
3137// ===========================================================================
3138
/// Functional 1D max pooling. See [`MaxPool1d`] for details.
///
/// Thin wrapper: `input`, `kernel_size`, `stride` and `padding` are passed
/// unchanged to `max_pool1d_forward`, whose result (or error) is returned.
pub fn max_pool1d<T: Float>(
    input: &Tensor<T>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
) -> FerrotorchResult<Tensor<T>> {
    max_pool1d_forward(input, kernel_size, stride, padding)
}
3148
/// Functional 2D max pooling. See [`MaxPool2d`] for details.
///
/// Thin wrapper: `input`, `kernel_size`, `stride` and `padding` are passed
/// unchanged to `max_pool2d_forward`, whose result (or error) is returned.
pub fn max_pool2d<T: Float>(
    input: &Tensor<T>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
) -> FerrotorchResult<Tensor<T>> {
    max_pool2d_forward(input, kernel_size, stride, padding)
}
3158
/// Functional 3D max pooling. See [`MaxPool3d`] for details.
///
/// Thin wrapper: `input`, `kernel_size`, `stride` and `padding` are passed
/// unchanged to `max_pool3d_forward`, whose result (or error) is returned.
pub fn max_pool3d<T: Float>(
    input: &Tensor<T>,
    kernel_size: [usize; 3],
    stride: [usize; 3],
    padding: [usize; 3],
) -> FerrotorchResult<Tensor<T>> {
    max_pool3d_forward(input, kernel_size, stride, padding)
}
3168
/// Functional 1D average pooling. See [`AvgPool1d`] for details.
///
/// Thin wrapper: `input`, `kernel_size`, `stride` and `padding` are passed
/// unchanged to `avg_pool1d_forward`, whose result (or error) is returned.
pub fn avg_pool1d<T: Float>(
    input: &Tensor<T>,
    kernel_size: usize,
    stride: usize,
    padding: usize,
) -> FerrotorchResult<Tensor<T>> {
    avg_pool1d_forward(input, kernel_size, stride, padding)
}
3178
/// Functional 2D average pooling. See [`AvgPool2d`] for details.
///
/// Thin wrapper: `input`, `kernel_size`, `stride` and `padding` are passed
/// unchanged to `avg_pool2d_forward`, whose result (or error) is returned.
pub fn avg_pool2d<T: Float>(
    input: &Tensor<T>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
) -> FerrotorchResult<Tensor<T>> {
    avg_pool2d_forward(input, kernel_size, stride, padding)
}
3188
/// Functional 3D average pooling. See [`AvgPool3d`] for details.
///
/// Thin wrapper: `input`, `kernel_size`, `stride` and `padding` are passed
/// unchanged to `avg_pool3d_forward`, whose result (or error) is returned.
pub fn avg_pool3d<T: Float>(
    input: &Tensor<T>,
    kernel_size: [usize; 3],
    stride: [usize; 3],
    padding: [usize; 3],
) -> FerrotorchResult<Tensor<T>> {
    avg_pool3d_forward(input, kernel_size, stride, padding)
}
3198
/// Functional 1D adaptive average pooling. See [`AdaptiveAvgPool1d`] for details.
///
/// Thin wrapper: `input` and `output_size` are passed unchanged to
/// `adaptive_avg_pool1d_forward`, whose result (or error) is returned.
pub fn adaptive_avg_pool1d<T: Float>(
    input: &Tensor<T>,
    output_size: usize,
) -> FerrotorchResult<Tensor<T>> {
    adaptive_avg_pool1d_forward(input, output_size)
}
3206
/// Functional 2D adaptive average pooling. See [`AdaptiveAvgPool2d`] for details.
///
/// Thin wrapper: `input` and `output_size` are passed unchanged to
/// `adaptive_avg_pool2d_forward`, whose result (or error) is returned.
pub fn adaptive_avg_pool2d<T: Float>(
    input: &Tensor<T>,
    output_size: (usize, usize),
) -> FerrotorchResult<Tensor<T>> {
    adaptive_avg_pool2d_forward(input, output_size)
}
3214
/// Functional 3D adaptive average pooling. See [`AdaptiveAvgPool3d`] for details.
///
/// Thin wrapper: `input` and `output_size` are passed unchanged to
/// `adaptive_avg_pool3d_forward`, whose result (or error) is returned.
pub fn adaptive_avg_pool3d<T: Float>(
    input: &Tensor<T>,
    output_size: (usize, usize, usize),
) -> FerrotorchResult<Tensor<T>> {
    adaptive_avg_pool3d_forward(input, output_size)
}
3222
/// Functional 2D adaptive max pooling. See [`AdaptiveMaxPool2d`] for details.
///
/// Thin wrapper: `input` and `output_size` are passed unchanged to
/// `adaptive_max_pool2d_forward`, whose result (or error) is returned.
pub fn adaptive_max_pool2d<T: Float>(
    input: &Tensor<T>,
    output_size: (usize, usize),
) -> FerrotorchResult<Tensor<T>> {
    adaptive_max_pool2d_forward(input, output_size)
}
3230
/// Functional 1D adaptive max pooling. See [`AdaptiveMaxPool1d`] for details.
///
/// Thin wrapper: `input` and `output_size` are passed unchanged to
/// `adaptive_max_pool1d_forward`, whose result (or error) is returned.
pub fn adaptive_max_pool1d<T: Float>(
    input: &Tensor<T>,
    output_size: usize,
) -> FerrotorchResult<Tensor<T>> {
    adaptive_max_pool1d_forward(input, output_size)
}
3238
/// Functional 3D adaptive max pooling. See [`AdaptiveMaxPool3d`] for details.
///
/// Thin wrapper: `input` and `output_size` are passed unchanged to
/// `adaptive_max_pool3d_forward`, whose result (or error) is returned.
pub fn adaptive_max_pool3d<T: Float>(
    input: &Tensor<T>,
    output_size: (usize, usize, usize),
) -> FerrotorchResult<Tensor<T>> {
    adaptive_max_pool3d_forward(input, output_size)
}
3246
/// Functional 1D Lp norm pooling. See [`LPPool1d`] for details.
///
/// Thin wrapper: all arguments are passed unchanged to `lp_pool1d_forward`.
/// Note that, unlike `LPPool1d::new`, this function does not replace a
/// `stride` of `0` with `kernel_size`; the value is forwarded as given.
pub fn lp_pool1d<T: Float>(
    input: &Tensor<T>,
    norm_type: f64,
    kernel_size: usize,
    stride: usize,
) -> FerrotorchResult<Tensor<T>> {
    lp_pool1d_forward(input, norm_type, kernel_size, stride)
}
3256
/// Functional 2D Lp norm pooling. See [`LPPool2d`] for details.
///
/// Thin wrapper: all arguments are passed unchanged to `lp_pool2d_forward`.
/// Note that, unlike `LPPool2d::new`, this function does not replace `0`
/// entries of `stride` with the matching `kernel_size` entry; the values are
/// forwarded as given.
pub fn lp_pool2d<T: Float>(
    input: &Tensor<T>,
    norm_type: f64,
    kernel_size: [usize; 2],
    stride: [usize; 2],
) -> FerrotorchResult<Tensor<T>> {
    lp_pool2d_forward(input, norm_type, kernel_size, stride)
}
3266
3267// ===========================================================================
3268// Tests
3269// ===========================================================================
3270
3271#[cfg(test)]
3272mod tests {
3273    use super::*;
3274
3275    /// Create a leaf 4D tensor from flat data.
3276    fn leaf_4d(data: &[f32], shape: [usize; 4], requires_grad: bool) -> Tensor<f32> {
3277        Tensor::from_storage(
3278            TensorStorage::cpu(data.to_vec()),
3279            shape.to_vec(),
3280            requires_grad,
3281        )
3282        .unwrap()
3283    }
3284
3285    // -----------------------------------------------------------------------
3286    // MaxPool2d tests
3287    // -----------------------------------------------------------------------
3288
3289    #[test]
3290    fn test_maxpool2d_output_shape() {
3291        // Input: [1, 1, 4, 4], kernel 2x2, stride 2, no padding
3292        // Output: [1, 1, 2, 2]
3293        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
3294        let pool = MaxPool2d::new([2, 2], [2, 2], [0, 0]);
3295        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3296        assert_eq!(out.shape(), &[1, 1, 2, 2]);
3297    }
3298
3299    #[test]
3300    fn test_maxpool2d_output_shape_with_padding() {
3301        // Input: [2, 3, 5, 5], kernel 3x3, stride 1, padding 1
3302        // H_out = (5 + 2*1 - 3) / 1 + 1 = 5
3303        let input = leaf_4d(&[0.0; 150], [2, 3, 5, 5], false);
3304        let pool = MaxPool2d::new([3, 3], [1, 1], [1, 1]);
3305        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3306        assert_eq!(out.shape(), &[2, 3, 5, 5]);
3307    }
3308
3309    #[test]
3310    fn test_maxpool2d_forward_correctness() {
3311        // Input [1, 1, 4, 4]:
3312        //  1  2  3  4
3313        //  5  6  7  8
3314        //  9 10 11 12
3315        // 13 14 15 16
3316        //
3317        // kernel 2x2, stride 2 => output [1, 1, 2, 2]:
3318        //  6  8
3319        // 14 16
3320        #[rustfmt::skip]
3321        let data: Vec<f32> = vec![
3322             1.0,  2.0,  3.0,  4.0,
3323             5.0,  6.0,  7.0,  8.0,
3324             9.0, 10.0, 11.0, 12.0,
3325            13.0, 14.0, 15.0, 16.0,
3326        ];
3327        let input = leaf_4d(&data, [1, 1, 4, 4], false);
3328        let pool = MaxPool2d::new([2, 2], [2, 2], [0, 0]);
3329        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3330        assert_eq!(out.data().unwrap(), &[6.0, 8.0, 14.0, 16.0]);
3331    }
3332
3333    #[test]
3334    fn test_maxpool2d_forward_stride1() {
3335        // Input [1, 1, 3, 3]:
3336        // 1 3 2
3337        // 4 6 5
3338        // 7 9 8
3339        //
3340        // kernel 2x2, stride 1 => output [1, 1, 2, 2]:
3341        //  max(1,3,4,6)=6  max(3,2,6,5)=6
3342        //  max(4,6,7,9)=9  max(6,5,9,8)=9
3343        #[rustfmt::skip]
3344        let data: Vec<f32> = vec![
3345            1.0, 3.0, 2.0,
3346            4.0, 6.0, 5.0,
3347            7.0, 9.0, 8.0,
3348        ];
3349        let input = leaf_4d(&data, [1, 1, 3, 3], false);
3350        let pool = MaxPool2d::new([2, 2], [1, 1], [0, 0]);
3351        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3352        assert_eq!(out.data().unwrap(), &[6.0, 6.0, 9.0, 9.0]);
3353    }
3354
3355    #[test]
3356    fn test_maxpool2d_backward() {
3357        // Input [1, 1, 4, 4], kernel 2x2, stride 2
3358        // Max indices: (1,1)=6, (1,3)=8, (3,1)=14, (3,3)=16
3359        #[rustfmt::skip]
3360        let data: Vec<f32> = vec![
3361             1.0,  2.0,  3.0,  4.0,
3362             5.0,  6.0,  7.0,  8.0,
3363             9.0, 10.0, 11.0, 12.0,
3364            13.0, 14.0, 15.0, 16.0,
3365        ];
3366        let input = leaf_4d(&data, [1, 1, 4, 4], true);
3367        let out = max_pool2d(&input, [2, 2], [2, 2], [0, 0]).unwrap();
3368
3369        // Manually construct a scalar loss = sum(out) for backward.
3370        let out_data = out.data().unwrap().to_vec();
3371        let total: f32 = out_data.iter().sum();
3372        let loss = Tensor::from_operation(
3373            TensorStorage::cpu(vec![total]),
3374            vec![],
3375            Arc::new(SumBackward { input: out }),
3376        )
3377        .unwrap();
3378        loss.backward().unwrap();
3379
3380        let grad = input.grad().unwrap().unwrap();
3381        let g = grad.data().unwrap();
3382        // Gradient should be 1.0 at max positions, 0.0 elsewhere.
3383        #[rustfmt::skip]
3384        let expected: Vec<f32> = vec![
3385            0.0, 0.0, 0.0, 0.0,
3386            0.0, 1.0, 0.0, 1.0,
3387            0.0, 0.0, 0.0, 0.0,
3388            0.0, 1.0, 0.0, 1.0,
3389        ];
3390        for (i, (&got, &exp)) in g.iter().zip(expected.iter()).enumerate() {
3391            assert!(
3392                (got - exp).abs() < 1e-6,
3393                "grad[{i}]: expected {exp}, got {got}"
3394            );
3395        }
3396    }
3397
3398    // -----------------------------------------------------------------------
3399    // AvgPool2d tests
3400    // -----------------------------------------------------------------------
3401
3402    #[test]
3403    fn test_avgpool2d_output_shape() {
3404        let input = leaf_4d(&[0.0; 48], [1, 3, 4, 4], false);
3405        let pool = AvgPool2d::new([2, 2], [2, 2], [0, 0]);
3406        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3407        assert_eq!(out.shape(), &[1, 3, 2, 2]);
3408    }
3409
3410    #[test]
3411    fn test_avgpool2d_forward_correctness() {
3412        // Input [1, 1, 4, 4]:
3413        //  1  2  3  4
3414        //  5  6  7  8
3415        //  9 10 11 12
3416        // 13 14 15 16
3417        //
3418        // kernel 2x2, stride 2 => output [1, 1, 2, 2]:
3419        //  avg(1,2,5,6)=3.5    avg(3,4,7,8)=5.5
3420        //  avg(9,10,13,14)=11.5 avg(11,12,15,16)=13.5
3421        #[rustfmt::skip]
3422        let data: Vec<f32> = vec![
3423             1.0,  2.0,  3.0,  4.0,
3424             5.0,  6.0,  7.0,  8.0,
3425             9.0, 10.0, 11.0, 12.0,
3426            13.0, 14.0, 15.0, 16.0,
3427        ];
3428        let input = leaf_4d(&data, [1, 1, 4, 4], false);
3429        let pool = AvgPool2d::new([2, 2], [2, 2], [0, 0]);
3430        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3431        let d = out.data().unwrap();
3432        assert!((d[0] - 3.5).abs() < 1e-6);
3433        assert!((d[1] - 5.5).abs() < 1e-6);
3434        assert!((d[2] - 11.5).abs() < 1e-6);
3435        assert!((d[3] - 13.5).abs() < 1e-6);
3436    }
3437
3438    #[test]
3439    fn test_avgpool2d_forward_stride1() {
3440        // Input [1, 1, 3, 3]:
3441        // 1 2 3
3442        // 4 5 6
3443        // 7 8 9
3444        //
3445        // kernel 2x2, stride 1 => output [1, 1, 2, 2]:
3446        //  avg(1,2,4,5)=3.0  avg(2,3,5,6)=4.0
3447        //  avg(4,5,7,8)=6.0  avg(5,6,8,9)=7.0
3448        #[rustfmt::skip]
3449        let data: Vec<f32> = vec![
3450            1.0, 2.0, 3.0,
3451            4.0, 5.0, 6.0,
3452            7.0, 8.0, 9.0,
3453        ];
3454        let input = leaf_4d(&data, [1, 1, 3, 3], false);
3455        let pool = AvgPool2d::new([2, 2], [1, 1], [0, 0]);
3456        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3457        let d = out.data().unwrap();
3458        assert!((d[0] - 3.0).abs() < 1e-6);
3459        assert!((d[1] - 4.0).abs() < 1e-6);
3460        assert!((d[2] - 6.0).abs() < 1e-6);
3461        assert!((d[3] - 7.0).abs() < 1e-6);
3462    }
3463
3464    #[test]
3465    fn test_avgpool2d_backward() {
3466        // Input [1, 1, 4, 4], kernel 2x2, stride 2
3467        // Each output element distributes grad / 4 to its 4 input positions.
3468        // With grad_output = all 1s, each input position that's covered gets 0.25.
3469        #[rustfmt::skip]
3470        let data: Vec<f32> = vec![
3471             1.0,  2.0,  3.0,  4.0,
3472             5.0,  6.0,  7.0,  8.0,
3473             9.0, 10.0, 11.0, 12.0,
3474            13.0, 14.0, 15.0, 16.0,
3475        ];
3476        let input = leaf_4d(&data, [1, 1, 4, 4], true);
3477        let out = avg_pool2d(&input, [2, 2], [2, 2], [0, 0]).unwrap();
3478
3479        let out_data = out.data().unwrap().to_vec();
3480        let total: f32 = out_data.iter().sum();
3481        let loss = Tensor::from_operation(
3482            TensorStorage::cpu(vec![total]),
3483            vec![],
3484            Arc::new(SumBackward { input: out }),
3485        )
3486        .unwrap();
3487        loss.backward().unwrap();
3488
3489        let grad = input.grad().unwrap().unwrap();
3490        let g = grad.data().unwrap();
3491        // Every input position is covered by exactly one window (stride = kernel_size).
3492        // grad = 1.0 / 4 = 0.25 for all positions.
3493        for (i, &val) in g.iter().enumerate() {
3494            assert!(
3495                (val - 0.25).abs() < 1e-6,
3496                "grad[{i}]: expected 0.25, got {val}"
3497            );
3498        }
3499    }
3500
3501    // -----------------------------------------------------------------------
3502    // AdaptiveAvgPool2d tests
3503    // -----------------------------------------------------------------------
3504
3505    #[test]
3506    fn test_adaptive_avgpool2d_output_shape() {
3507        let input = leaf_4d(&[0.0; 75], [1, 3, 5, 5], false);
3508        let pool = AdaptiveAvgPool2d::new((1, 1));
3509        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3510        assert_eq!(out.shape(), &[1, 3, 1, 1]);
3511    }
3512
3513    #[test]
3514    fn test_adaptive_avgpool2d_global() {
3515        // Global average pooling: output (1, 1) => mean of entire spatial plane.
3516        // Input [1, 1, 2, 2]: 1, 2, 3, 4 => mean = 2.5
3517        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
3518        let input = leaf_4d(&data, [1, 1, 2, 2], false);
3519        let pool = AdaptiveAvgPool2d::new((1, 1));
3520        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3521        assert_eq!(out.shape(), &[1, 1, 1, 1]);
3522        assert!((out.data().unwrap()[0] - 2.5).abs() < 1e-6);
3523    }
3524
3525    #[test]
3526    fn test_adaptive_avgpool2d_identity() {
3527        // Output size matches input => identity.
3528        #[rustfmt::skip]
3529        let data: Vec<f32> = vec![
3530            1.0, 2.0, 3.0,
3531            4.0, 5.0, 6.0,
3532            7.0, 8.0, 9.0,
3533        ];
3534        let input = leaf_4d(&data, [1, 1, 3, 3], false);
3535        let pool = AdaptiveAvgPool2d::new((3, 3));
3536        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3537        assert_eq!(out.shape(), &[1, 1, 3, 3]);
3538        let d = out.data().unwrap();
3539        for (i, (&got, &exp)) in d.iter().zip(data.iter()).enumerate() {
3540            assert!(
3541                (got - exp).abs() < 1e-6,
3542                "output[{i}]: expected {exp}, got {got}"
3543            );
3544        }
3545    }
3546
3547    #[test]
3548    fn test_adaptive_avgpool2d_2x2() {
3549        // Input [1, 1, 4, 4] => output (2, 2).
3550        // PyTorch adaptive formula:
3551        //   h_start(0) = 0*4/2=0, h_end(0) = ceil(1*4/2)=2
3552        //   h_start(1) = 1*4/2=2, h_end(1) = ceil(2*4/2)=4
3553        //   Same for w. So windows are [0..2, 0..2], [0..2, 2..4], [2..4, 0..2], [2..4, 2..4].
3554        #[rustfmt::skip]
3555        let data: Vec<f32> = vec![
3556             1.0,  2.0,  3.0,  4.0,
3557             5.0,  6.0,  7.0,  8.0,
3558             9.0, 10.0, 11.0, 12.0,
3559            13.0, 14.0, 15.0, 16.0,
3560        ];
3561        let input = leaf_4d(&data, [1, 1, 4, 4], false);
3562        let pool = AdaptiveAvgPool2d::new((2, 2));
3563        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3564        assert_eq!(out.shape(), &[1, 1, 2, 2]);
3565        let d = out.data().unwrap();
3566        // Window [0..2, 0..2]: avg(1,2,5,6) = 3.5
3567        assert!((d[0] - 3.5).abs() < 1e-6);
3568        // Window [0..2, 2..4]: avg(3,4,7,8) = 5.5
3569        assert!((d[1] - 5.5).abs() < 1e-6);
3570        // Window [2..4, 0..2]: avg(9,10,13,14) = 11.5
3571        assert!((d[2] - 11.5).abs() < 1e-6);
3572        // Window [2..4, 2..4]: avg(11,12,15,16) = 13.5
3573        assert!((d[3] - 13.5).abs() < 1e-6);
3574    }
3575
3576    #[test]
3577    fn test_adaptive_avgpool2d_backward() {
3578        // Input [1, 1, 4, 4] => output (1, 1) = global avg.
3579        // loss = output[0] (scalar).
3580        // d(loss)/d(input[i]) = 1/16 for all i.
3581        let data: Vec<f32> = (1..=16).map(|x| x as f32).collect();
3582        let input = leaf_4d(&data, [1, 1, 4, 4], true);
3583        let out = adaptive_avg_pool2d(&input, (1, 1)).unwrap();
3584
3585        // out is [1, 1, 1, 1], so item() works after reshape to scalar.
3586        let out_val = out.data().unwrap()[0];
3587        let loss = Tensor::from_operation(
3588            TensorStorage::cpu(vec![out_val]),
3589            vec![],
3590            Arc::new(SumBackward { input: out }),
3591        )
3592        .unwrap();
3593        loss.backward().unwrap();
3594
3595        let grad = input.grad().unwrap().unwrap();
3596        let g = grad.data().unwrap();
3597        let expected = 1.0 / 16.0;
3598        for (i, &val) in g.iter().enumerate() {
3599            assert!(
3600                (val - expected).abs() < 1e-6,
3601                "grad[{i}]: expected {expected}, got {val}"
3602            );
3603        }
3604    }
3605
3606    // -----------------------------------------------------------------------
3607    // Error handling tests
3608    // -----------------------------------------------------------------------
3609
3610    #[test]
3611    fn test_pooling_rejects_3d_input() {
3612        let input =
3613            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![0.0; 12]), vec![2, 3, 2], false)
3614                .unwrap();
3615        assert!(max_pool2d(&input, [2, 2], [1, 1], [0, 0]).is_err());
3616        assert!(avg_pool2d(&input, [2, 2], [1, 1], [0, 0]).is_err());
3617        assert!(adaptive_avg_pool2d(&input, (1, 1)).is_err());
3618    }
3619
3620    #[test]
3621    fn test_pooling_zero_kernel_rejected() {
3622        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
3623        assert!(max_pool2d(&input, [0, 2], [1, 1], [0, 0]).is_err());
3624        assert!(avg_pool2d(&input, [2, 0], [1, 1], [0, 0]).is_err());
3625    }
3626
3627    #[test]
3628    fn test_pooling_zero_stride_defaults_to_kernel() {
3629        let _input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
3630        let pool = MaxPool2d::new([2, 2], [0, 0], [0, 0]);
3631        assert_eq!(pool.stride, [2, 2]);
3632    }
3633
3634    #[test]
3635    fn test_maxpool2d_zero_parameters() {
3636        let pool = MaxPool2d::new([2, 2], [2, 2], [0, 0]);
3637        let params: Vec<&Parameter<f32>> = Module::<f32>::parameters(&pool);
3638        assert!(params.is_empty());
3639    }
3640
3641    #[test]
3642    fn test_avgpool2d_zero_parameters() {
3643        let pool = AvgPool2d::new([2, 2], [2, 2], [0, 0]);
3644        let params: Vec<&Parameter<f32>> = Module::<f32>::parameters(&pool);
3645        assert!(params.is_empty());
3646    }
3647
3648    #[test]
3649    fn test_adaptive_avgpool2d_zero_parameters() {
3650        let pool = AdaptiveAvgPool2d::new((1, 1));
3651        let params: Vec<&Parameter<f32>> = Module::<f32>::parameters(&pool);
3652        assert!(params.is_empty());
3653    }
3654
3655    #[test]
3656    fn test_maxpool2d_batch_channels() {
3657        // Verify pooling works independently per batch/channel.
3658        // [2, 2, 4, 4] with known data.
3659        let mut data = Vec::with_capacity(64);
3660        for b in 0..2 {
3661            for c in 0..2 {
3662                let offset = (b * 2 + c) as f32 * 100.0;
3663                for i in 0..16 {
3664                    data.push(offset + i as f32);
3665                }
3666            }
3667        }
3668        let input = leaf_4d(&data, [2, 2, 4, 4], false);
3669        let out = max_pool2d(&input, [2, 2], [2, 2], [0, 0]).unwrap();
3670        assert_eq!(out.shape(), &[2, 2, 2, 2]);
3671
3672        let d = out.data().unwrap();
3673        // Batch 0, Channel 0: offset=0, max of [0..3,4..7] etc.
3674        assert!((d[0] - 5.0).abs() < 1e-6); // max(0,1,4,5)=5
3675        assert!((d[1] - 7.0).abs() < 1e-6); // max(2,3,6,7)=7
3676    }
3677
3678    // -----------------------------------------------------------------------
3679    // Helper backward node for tests
3680    // -----------------------------------------------------------------------
3681
3682    /// Sum reduction backward for test use.
3683    /// loss = sum(input); d(loss)/d(input_i) = 1.
3684    #[derive(Debug)]
3685    struct SumBackward<T: Float> {
3686        input: Tensor<T>,
3687    }
3688
3689    impl<T: Float> GradFn<T> for SumBackward<T> {
3690        fn backward(&self, _grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
3691            let ones_data = vec![<T as num_traits::One>::one(); self.input.numel()];
3692            let ones = Tensor::from_storage(
3693                TensorStorage::cpu(ones_data),
3694                self.input.shape().to_vec(),
3695                false,
3696            )?;
3697            Ok(vec![Some(ones)])
3698        }
3699
3700        fn inputs(&self) -> Vec<&Tensor<T>> {
3701            vec![&self.input]
3702        }
3703
3704        fn name(&self) -> &'static str {
3705            "SumBackward"
3706        }
3707    }
3708
3709    /// Create a leaf 3D tensor from flat data.
3710    fn leaf_3d(data: &[f32], shape: [usize; 3], requires_grad: bool) -> Tensor<f32> {
3711        Tensor::from_storage(
3712            TensorStorage::cpu(data.to_vec()),
3713            shape.to_vec(),
3714            requires_grad,
3715        )
3716        .unwrap()
3717    }
3718
3719    /// Create a leaf 5D tensor from flat data.
3720    fn leaf_5d(data: &[f32], shape: [usize; 5], requires_grad: bool) -> Tensor<f32> {
3721        Tensor::from_storage(
3722            TensorStorage::cpu(data.to_vec()),
3723            shape.to_vec(),
3724            requires_grad,
3725        )
3726        .unwrap()
3727    }
3728
3729    // -----------------------------------------------------------------------
3730    // AdaptiveMaxPool1d tests — CL-432
3731    // -----------------------------------------------------------------------
3732
3733    #[test]
3734    fn test_adaptive_max_pool1d_output_shape() {
3735        let pool = AdaptiveMaxPool1d::new(3);
3736        let input = leaf_3d(&[0.0; 20], [2, 2, 5], false);
3737        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3738        assert_eq!(out.shape(), &[2, 2, 3]);
3739    }
3740
3741    #[test]
3742    fn test_adaptive_max_pool1d_correctness() {
3743        // [1, 1, 6]: [1, 5, 3, 7, 2, 8]
3744        // output_size=3 => windows ~ [0,2), [2,4), [4,6)
3745        // max(1,5)=5, max(3,7)=7, max(2,8)=8
3746        let data: Vec<f32> = vec![1.0, 5.0, 3.0, 7.0, 2.0, 8.0];
3747        let input = leaf_3d(&data, [1, 1, 6], false);
3748        let out = adaptive_max_pool1d(&input, 3).unwrap();
3749        let d = out.data().unwrap();
3750        assert_eq!(d.len(), 3);
3751        assert!((d[0] - 5.0).abs() < 1e-6);
3752        assert!((d[1] - 7.0).abs() < 1e-6);
3753        assert!((d[2] - 8.0).abs() < 1e-6);
3754    }
3755
3756    #[test]
3757    fn test_adaptive_max_pool1d_backward() {
3758        let data: Vec<f32> = vec![1.0, 5.0, 3.0, 7.0, 2.0, 8.0];
3759        let input = leaf_3d(&data, [1, 1, 6], true);
3760        let out = adaptive_max_pool1d(&input, 3).unwrap();
3761
3762        let out_data = out.data().unwrap().to_vec();
3763        let total: f32 = out_data.iter().sum();
3764        let loss = Tensor::from_operation(
3765            TensorStorage::cpu(vec![total]),
3766            vec![],
3767            Arc::new(SumBackward { input: out.clone() }),
3768        )
3769        .unwrap();
3770        loss.backward().unwrap();
3771
3772        let grad = input.grad().unwrap().unwrap();
3773        assert_eq!(grad.shape(), &[1, 1, 6]);
3774        let gd = grad.data().unwrap();
3775        // Gradient should route to max positions only.
3776        // Max at idx 1 (val=5), idx 3 (val=7), idx 5 (val=8).
3777        assert!((gd[0]).abs() < 1e-6); // not max
3778        assert!((gd[1] - 1.0).abs() < 1e-6); // max
3779        assert!((gd[2]).abs() < 1e-6);
3780        assert!((gd[3] - 1.0).abs() < 1e-6); // max
3781        assert!((gd[4]).abs() < 1e-6);
3782        assert!((gd[5] - 1.0).abs() < 1e-6); // max
3783    }
3784
3785    #[test]
3786    fn test_adaptive_max_pool1d_zero_output_size() {
3787        let pool = AdaptiveMaxPool1d::new(0);
3788        let input = leaf_3d(&[1.0; 6], [1, 1, 6], false);
3789        assert!(Module::<f32>::forward(&pool, &input).is_err());
3790    }
3791
3792    // -----------------------------------------------------------------------
3793    // AdaptiveMaxPool3d tests — CL-432
3794    // -----------------------------------------------------------------------
3795
3796    #[test]
3797    fn test_adaptive_max_pool3d_output_shape() {
3798        let pool = AdaptiveMaxPool3d::new((2, 2, 2));
3799        let input = leaf_5d(&[0.0; 2 * 3 * 4 * 4 * 4], [2, 3, 4, 4, 4], false);
3800        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3801        assert_eq!(out.shape(), &[2, 3, 2, 2, 2]);
3802    }
3803
3804    #[test]
3805    fn test_adaptive_max_pool3d_correctness_single_output() {
3806        // Global max pool: output_size = (1, 1, 1).
3807        let mut data = vec![0.0f32; 2 * 2 * 2];
3808        data[5] = 10.0; // the max
3809        let input = leaf_5d(&data, [1, 1, 2, 2, 2], false);
3810        let out = adaptive_max_pool3d(&input, (1, 1, 1)).unwrap();
3811        assert_eq!(out.shape(), &[1, 1, 1, 1, 1]);
3812        assert!((out.data().unwrap()[0] - 10.0).abs() < 1e-6);
3813    }
3814
3815    #[test]
3816    fn test_adaptive_max_pool3d_backward() {
3817        let mut data = vec![1.0f32; 2 * 2 * 2 * 2];
3818        data[0] = 10.0; // max in first channel region
3819        let input = leaf_5d(&data, [1, 2, 2, 2, 2], true);
3820        let out = adaptive_max_pool3d(&input, (1, 1, 1)).unwrap();
3821
3822        let out_data = out.data().unwrap().to_vec();
3823        let total: f32 = out_data.iter().sum();
3824        let loss = Tensor::from_operation(
3825            TensorStorage::cpu(vec![total]),
3826            vec![],
3827            Arc::new(SumBackward { input: out.clone() }),
3828        )
3829        .unwrap();
3830        loss.backward().unwrap();
3831
3832        let grad = input.grad().unwrap().unwrap();
3833        assert_eq!(grad.shape(), &[1, 2, 2, 2, 2]);
3834    }
3835
3836    #[test]
3837    fn test_adaptive_max_pool3d_zero_output_size() {
3838        let pool = AdaptiveMaxPool3d::new((0, 1, 1));
3839        let input = leaf_5d(&[1.0; 8], [1, 1, 2, 2, 2], false);
3840        assert!(Module::<f32>::forward(&pool, &input).is_err());
3841    }
3842
3843    // -----------------------------------------------------------------------
3844    // FractionalMaxPool2d tests — CL-432
3845    // -----------------------------------------------------------------------
3846
3847    #[test]
3848    fn test_fractional_maxpool2d_output_shape() {
3849        let pool = FractionalMaxPool2d::new((3, 3));
3850        let input = leaf_4d(&[0.0; 2 * 3 * 8 * 8], [2, 3, 8, 8], false);
3851        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3852        assert_eq!(out.shape(), &[2, 3, 3, 3]);
3853    }
3854
3855    #[test]
3856    fn test_fractional_maxpool2d_values_from_input() {
3857        // All output values should be present in the input.
3858        #[rustfmt::skip]
3859        let data: Vec<f32> = (0..36).map(|i| i as f32).collect();
3860        let input = leaf_4d(&data, [1, 1, 6, 6], false);
3861        let pool = FractionalMaxPool2d::new((3, 3));
3862        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3863        let out_data = out.data().unwrap();
3864        for &v in out_data.iter() {
3865            assert!(data.contains(&v), "output value {v} not found in input");
3866        }
3867    }
3868
3869    #[test]
3870    fn test_fractional_maxpool2d_backward() {
3871        let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
3872        let input = leaf_4d(&data, [1, 1, 8, 8], true);
3873        let pool = FractionalMaxPool2d::new((4, 4));
3874        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3875
3876        let out_data = out.data().unwrap().to_vec();
3877        let total: f32 = out_data.iter().sum();
3878        let loss = Tensor::from_operation(
3879            TensorStorage::cpu(vec![total]),
3880            vec![],
3881            Arc::new(SumBackward { input: out.clone() }),
3882        )
3883        .unwrap();
3884        loss.backward().unwrap();
3885
3886        let grad = input.grad().unwrap().unwrap();
3887        assert_eq!(grad.shape(), &[1, 1, 8, 8]);
3888        // At least some positions should have non-zero gradient.
3889        let gd = grad.data().unwrap();
3890        let non_zero = gd.iter().filter(|&&g| g != 0.0).count();
3891        assert!(
3892            non_zero > 0,
3893            "backward should route gradient to max positions"
3894        );
3895    }
3896
3897    #[test]
3898    fn test_fractional_maxpool2d_output_larger_than_input() {
3899        let pool = FractionalMaxPool2d::new((5, 5));
3900        let input = leaf_4d(&[1.0; 16], [1, 1, 4, 4], false);
3901        assert!(Module::<f32>::forward(&pool, &input).is_err());
3902    }
3903
3904    #[test]
3905    fn test_fractional_maxpool2d_no_parameters() {
3906        let pool = FractionalMaxPool2d::new((3, 3));
3907        assert!(Module::<f32>::parameters(&pool).is_empty());
3908    }
3909
3910    // -----------------------------------------------------------------------
3911    // LPPool1d tests — CL-432
3912    // -----------------------------------------------------------------------
3913
3914    #[test]
3915    fn test_lppool1d_output_shape() {
3916        let pool = LPPool1d::new(2.0, 3, 2);
3917        let input = leaf_3d(&[0.0; 2 * 3 * 8], [2, 3, 8], false);
3918        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
3919        // out_l = (8 - 3) / 2 + 1 = 3
3920        assert_eq!(out.shape(), &[2, 3, 3]);
3921    }
3922
3923    #[test]
3924    fn test_lppool1d_l2_correctness() {
3925        // L2 pool with kernel=2, stride=2 over [3, 4] => sqrt(9+16) = 5
3926        let data: Vec<f32> = vec![3.0, 4.0, 1.0, 0.0];
3927        let input = leaf_3d(&data, [1, 1, 4], false);
3928        let out = lp_pool1d(&input, 2.0, 2, 2).unwrap();
3929        let d = out.data().unwrap();
3930        assert_eq!(d.len(), 2);
3931        assert!(
3932            (d[0] - 5.0).abs() < 1e-5,
3933            "L2 pool of [3,4] = {}, expected 5",
3934            d[0]
3935        );
3936        assert!(
3937            (d[1] - 1.0).abs() < 1e-5,
3938            "L2 pool of [1,0] = {}, expected 1",
3939            d[1]
3940        );
3941    }
3942
3943    #[test]
3944    fn test_lppool1d_l1_correctness() {
3945        // L1 pool with kernel=2, stride=2 over [3, -4] => (|3|+|-4|)^1 = 7
3946        let data: Vec<f32> = vec![3.0, -4.0];
3947        let input = leaf_3d(&data, [1, 1, 2], false);
3948        let out = lp_pool1d(&input, 1.0, 2, 2).unwrap();
3949        let d = out.data().unwrap();
3950        assert!((d[0] - 7.0).abs() < 1e-5, "L1 pool = {}, expected 7", d[0]);
3951    }
3952
3953    #[test]
3954    fn test_lppool1d_default_stride() {
3955        // stride=0 should default to kernel_size.
3956        let pool = LPPool1d::new(2.0, 3, 0);
3957        assert_eq!(pool.stride, 3);
3958    }
3959
3960    #[test]
3961    fn test_lppool1d_backward() {
3962        let data: Vec<f32> = vec![3.0, 4.0, 1.0, 2.0];
3963        let input = leaf_3d(&data, [1, 1, 4], true);
3964        let out = lp_pool1d(&input, 2.0, 2, 2).unwrap();
3965
3966        let out_data = out.data().unwrap().to_vec();
3967        let total: f32 = out_data.iter().sum();
3968        let loss = Tensor::from_operation(
3969            TensorStorage::cpu(vec![total]),
3970            vec![],
3971            Arc::new(SumBackward { input: out.clone() }),
3972        )
3973        .unwrap();
3974        loss.backward().unwrap();
3975
3976        let grad = input.grad().unwrap().unwrap();
3977        assert_eq!(grad.shape(), &[1, 1, 4]);
3978        // All gradient values should be finite and non-NaN.
3979        let gd = grad.data().unwrap();
3980        for (i, &g) in gd.iter().enumerate() {
3981            assert!(g.is_finite(), "gradient[{i}] = {g} is not finite");
3982        }
3983    }
3984
3985    #[test]
3986    fn test_lppool1d_no_parameters() {
3987        let pool = LPPool1d::new(2.0, 3, 2);
3988        assert!(Module::<f32>::parameters(&pool).is_empty());
3989    }
3990
3991    // -----------------------------------------------------------------------
3992    // LPPool2d tests — CL-432
3993    // -----------------------------------------------------------------------
3994
3995    #[test]
3996    fn test_lppool2d_output_shape() {
3997        let pool = LPPool2d::new(2.0, [2, 2], [2, 2]);
3998        let input = leaf_4d(&[0.0; 2 * 3 * 4 * 4], [2, 3, 4, 4], false);
3999        let out: Tensor<f32> = Module::<f32>::forward(&pool, &input).unwrap();
4000        assert_eq!(out.shape(), &[2, 3, 2, 2]);
4001    }
4002
4003    #[test]
4004    fn test_lppool2d_l2_correctness() {
4005        // L2 pool with kernel 2x2, stride 2:
4006        // [1, 2, 3, 4] => sqrt(1+4+9+16) = sqrt(30)
4007        #[rustfmt::skip]
4008        let data: Vec<f32> = vec![
4009            1.0, 2.0,
4010            3.0, 4.0,
4011        ];
4012        let input = leaf_4d(&data, [1, 1, 2, 2], false);
4013        let out = lp_pool2d(&input, 2.0, [2, 2], [2, 2]).unwrap();
4014        let d = out.data().unwrap();
4015        assert_eq!(d.len(), 1);
4016        let expected = (1.0f32 + 4.0 + 9.0 + 16.0).sqrt();
4017        assert!(
4018            (d[0] - expected).abs() < 1e-5,
4019            "L2 pool = {}, expected {expected}",
4020            d[0]
4021        );
4022    }
4023
4024    #[test]
4025    fn test_lppool2d_default_stride() {
4026        let pool = LPPool2d::new(2.0, [3, 3], [0, 0]);
4027        assert_eq!(pool.stride, [3, 3]);
4028    }
4029
4030    #[test]
4031    fn test_lppool2d_backward() {
4032        #[rustfmt::skip]
4033        let data: Vec<f32> = vec![
4034            1.0, 2.0, 3.0, 4.0,
4035            5.0, 6.0, 7.0, 8.0,
4036            9.0, 10.0, 11.0, 12.0,
4037            13.0, 14.0, 15.0, 16.0,
4038        ];
4039        let input = leaf_4d(&data, [1, 1, 4, 4], true);
4040        let out = lp_pool2d(&input, 2.0, [2, 2], [2, 2]).unwrap();
4041
4042        let out_data = out.data().unwrap().to_vec();
4043        let total: f32 = out_data.iter().sum();
4044        let loss = Tensor::from_operation(
4045            TensorStorage::cpu(vec![total]),
4046            vec![],
4047            Arc::new(SumBackward { input: out.clone() }),
4048        )
4049        .unwrap();
4050        loss.backward().unwrap();
4051
4052        let grad = input.grad().unwrap().unwrap();
4053        assert_eq!(grad.shape(), &[1, 1, 4, 4]);
4054        let gd = grad.data().unwrap();
4055        for (i, &g) in gd.iter().enumerate() {
4056            assert!(g.is_finite(), "gradient[{i}] = {g} is not finite");
4057        }
4058    }
4059
4060    #[test]
4061    fn test_lppool2d_no_parameters() {
4062        let pool = LPPool2d::new(2.0, [2, 2], [2, 2]);
4063        assert!(Module::<f32>::parameters(&pool).is_empty());
4064    }
4065}