use std::collections::HashMap;
use crate::autograd::AutogradError;
use crate::nn::{Module, Parameter};
use crate::tensor::Tensor;
/// Adaptive 2-D average pooling layer configured by a fixed target output size.
///
/// The layer stores only the requested output height and width; it owns no
/// trainable parameters. The pooling itself is delegated to the tensor's
/// `adaptive_avg_pool2d` method.
///
/// Derives the common traits (`Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`) so the
/// configuration can be printed, compared, and copied cheaply — it is just two
/// `usize` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdaptiveAvgPool2d {
    /// Target output height passed to the pooling kernel.
    pub output_h: usize,
    /// Target output width passed to the pooling kernel.
    pub output_w: usize,
}
impl AdaptiveAvgPool2d {
pub fn new(output_h: usize, output_w: usize) -> Self {
Self { output_h, output_w }
}
pub fn forward(&self, input: &Tensor) -> Tensor {
input.adaptive_avg_pool2d(self.output_h, self.output_w)
}
}
/// `AdaptiveAvgPool2d` holds no trainable state, so every `Module` hook is a no-op.
impl Module for AdaptiveAvgPool2d {
    /// Returns an empty list: this layer owns no [`Parameter`]s.
    fn parameters(&self) -> Vec<Parameter> {
        Vec::new()
    }

    /// Returns an empty map regardless of `_prefix`; there are no tensors to export.
    fn state_dict(&self, _prefix: &str) -> HashMap<String, Tensor> {
        HashMap::default()
    }

    /// Ignores `_dict` and `_prefix` entirely and always succeeds, since there is
    /// nothing to restore.
    fn load_state_dict(
        &mut self,
        _dict: &HashMap<String, Tensor>,
        _prefix: &str,
    ) -> Result<(), AutogradError> {
        Ok(())
    }
}