use crate::utils::function_context;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Reduces all spatial positions of every `(batch, channel)` pair to one value.
///
/// The input is interpreted as `(batch, channel, spatial...)`: every dimension after the
/// second is flattened into a single spatial axis of length `spatial_size`. For each
/// `(b, c)` pair, that pair's spatial values are folded left-to-right with `reduce_fn`,
/// starting from `init_value`. The accumulated value is then passed to `finalize_fn`
/// together with `spatial_size`.
///
/// # Arguments
/// * `input` - tensor with at least 3 dimensions.
/// * `operation_name` - name used to build the error context.
/// * `init_value` - starting accumulator for every `(b, c)` pair.
/// * `reduce_fn` - `(accumulator, value) -> accumulator`.
/// * `finalize_fn` - `(accumulator, spatial_size) -> output value`.
///
/// # Returns
/// A tensor of shape `[batch, channels]` created on `input.device()`.
///
/// # Errors
/// Returns a configuration error in two cases:
/// * `input` has fewer than 3 dimensions.
/// * The flattened input data holds fewer than `batch * channels * spatial_size`
///   elements.
///
/// Errors from `Tensor::to_vec` and `Tensor::from_data` are propagated.
fn global_pool_generic<F, G>(
    input: &Tensor,
    operation_name: &str,
    init_value: f32,
    reduce_fn: F,
    finalize_fn: G,
) -> TorshResult<Tensor>
where
    F: Fn(f32, f32) -> f32,
    G: Fn(f32, usize) -> f32,
{
    let context = function_context(operation_name);
    let shape = input.shape();
    let dims = shape.dims();
    if dims.len() < 3 {
        return Err(TorshError::config_error_with_context(
            "Input must have at least 3 dimensions (batch, channel, spatial)",
            &context,
        ));
    }
    let batch_size = dims[0];
    let channels = dims[1];
    let spatial_size: usize = dims[2..].iter().product();
    let num_outputs = batch_size * channels;

    // NOTE(review): assumes `to_vec` returns elements in row-major (contiguous) order,
    // so each (b, c) pair owns one consecutive run of `spatial_size` values — confirm.
    let input_data = input.to_vec()?;

    // Report a short buffer as an error instead of panicking on an out-of-range index.
    if input_data.len() < num_outputs * spatial_size {
        return Err(TorshError::config_error_with_context(
            "Input data is shorter than its shape implies",
            &context,
        ));
    }

    // Each output slot is built directly from its own slice, with no placeholder fill.
    // When `spatial_size == 0` every slice is empty and the fold yields `init_value`.
    let output_data: Vec<f32> = (0..num_outputs)
        .map(|pair| {
            let start = pair * spatial_size;
            let accumulated = input_data[start..start + spatial_size]
                .iter()
                .fold(init_value, |acc, &value| reduce_fn(acc, value));
            finalize_fn(accumulated, spatial_size)
        })
        .collect();

    Tensor::from_data(output_data, vec![batch_size, channels], input.device())
}
/// Global average pooling over every spatial dimension.
///
/// For an input laid out as `[batch, channels, spatial...]`, produces a
/// `[batch, channels]` tensor. Each entry is the sum of that channel's spatial values
/// divided by the number of spatial elements.
///
/// If a spatial dimension is zero, the division is `0.0 / 0.0` and yields `NaN`.
///
/// # Errors
/// Fails when `input` has fewer than 3 dimensions. Also propagates any failure from
/// reading the input data or building the output tensor.
pub fn global_avg_pool(input: &Tensor) -> TorshResult<Tensor> {
    let add = |running: f32, value: f32| running + value;
    let mean = |total: f32, spatial_len: usize| total / spatial_len as f32;
    global_pool_generic(input, "global_avg_pool", 0.0, add, mean)
}
/// Global max pooling over every spatial dimension.
///
/// For an input laid out as `[batch, channels, spatial...]`, produces a
/// `[batch, channels]` tensor. Each entry is the largest of that channel's spatial
/// values; the search starts from `f32::NEG_INFINITY`.
///
/// Because comparison uses `f32::max`, which returns the non-NaN operand, `NaN` inputs
/// are skipped. A channel that is empty or contains only `NaN` yields
/// `f32::NEG_INFINITY`.
///
/// # Errors
/// Fails when `input` has fewer than 3 dimensions. Also propagates any failure from
/// reading the input data or building the output tensor.
pub fn global_max_pool(input: &Tensor) -> TorshResult<Tensor> {
    global_pool_generic(
        input,
        "global_max_pool",
        f32::NEG_INFINITY,
        f32::max,
        |peak, _spatial_len| peak,
    )
}