#![warn(missing_docs)]
use std::{error::Error, str::FromStr};
use crate::{Tensor, TensorElement};
use rayon::prelude::*;
/// Max-pooling over the last two axes of `x`.
///
/// Each `stride` x `stride` window (windows step by `stride`, so they do not overlap)
/// is collapsed to its largest element. Elements are compared with
/// [`PartialOrd::partial_cmp`]. The window walk and output shape are handled by the
/// private `pool` helper, which also receives `padding`.
///
/// # Panics
///
/// Panics if two elements of a window cannot be ordered (`partial_cmp` returns
/// `None`, e.g. a NaN among floating-point values).
pub fn MaxPool<'a, T>(x: &Tensor<'a, T>, stride: usize, padding: usize) -> Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
{
    // An empty window produces `None`, which `pool` skips rather than writing.
    pool(x, stride, padding, |window: &[T]| {
        window
            .par_iter()
            .cloned()
            .max_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap())
    })
}
/// Min-pooling over the last two axes of `x`.
///
/// Every non-overlapping `stride` x `stride` window is reduced to its smallest
/// element, using [`PartialOrd::partial_cmp`] for ordering. The private `pool`
/// helper drives the window iteration, applies `padding` and builds the output tensor.
///
/// # Panics
///
/// Panics when a window holds two elements with no defined order (`partial_cmp`
/// yields `None`, as happens with a floating-point NaN).
pub fn MinPool<'a, T>(x: &Tensor<'a, T>, stride: usize, padding: usize) -> Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
{
    pool(x, stride, padding, |cells: &[T]| {
        // `min_by` returns `None` for an empty window; `pool` leaves that cell unset.
        cells
            .par_iter()
            .cloned()
            .min_by(|first, second| first.partial_cmp(second).unwrap())
    })
}
/// Average-pooling over the last two axes of `x`.
///
/// Each non-overlapping `stride` x `stride` window becomes the arithmetic mean of
/// its elements. The divisor is the actual number of elements returned for the
/// window, not `stride * stride`. Window iteration, `padding` and output
/// construction are delegated to the private `pool` helper.
///
/// # Panics
///
/// Panics if the window length cannot be converted to `T` (`T::from` returns `None`).
///
/// NOTE(review): an empty window would divide a zero sum by zero. This yields NaN for
/// floats or panics for integers. Confirm `get_vec_slice` never returns an empty slice.
pub fn AvgPool<'a, T>(x: &Tensor<'a, T>, stride: usize, padding: usize) -> Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
{
    pool(x, stride, padding, |window: &[T]| {
        let total: T = window.par_iter().cloned().sum();
        let count = T::from(window.len()).unwrap();
        Some(total / count)
    })
}
/// Shared sliding-window driver behind `MaxPool`, `MinPool` and `AvgPool`.
///
/// The window is square with side `stride` and advances by `stride` along both of
/// the last two axes of `x`, so windows never overlap. For output cell `(i, j)`, the
/// window anchored at `(i * stride, j * stride)` is fetched with
/// `Tensor::get_vec_slice` and handed to `reduce`. A `Some(v)` result is written
/// into the output; `None` leaves the cell as created by `Tensor::zeros`.
///
/// The result is always a 2-D tensor of shape `[out_rows, out_cols]`. Each extent is
/// `(dim + 2 * padding - stride) / stride + 1`, computed from the corresponding
/// trailing axis of `x`.
///
/// # Panics
///
/// * if `stride` is zero;
/// * if `x` has fewer than two axes;
/// * if either trailing axis, enlarged by `2 * padding`, is shorter than `stride`
///   (the window would not fit even once).
///
/// NOTE(review): `padding` only enlarges the output extent. The input is not padded
/// and window anchors are not shifted by `padding`, so with `padding > 0` the last
/// windows reach past the input edge. Whether that is safe depends on how
/// `get_vec_slice` treats out-of-range windows — confirm. Likewise, leading axes
/// (if any) are not iterated; only two indices are passed to `get_vec_slice`.
fn pool<'a, T, F>(x: &Tensor<'a, T>, stride: usize, padding: usize, mut reduce: F) -> Tensor<'a, T>
where
    // Called sequentially on this thread, so no `Send`/`Sync`/`'static` is required.
    F: FnMut(&[T]) -> Option<T>,
    T: TensorElement,
    <T as FromStr>::Err: Error,
{
    assert!(stride > 0, "pool: stride must be greater than zero");

    // Only the last two axes are pooled; walk the shape from the back.
    let mut trailing = x.shape.iter().rev();
    let in_cols = *trailing
        .next()
        .expect("pool: input tensor must have at least two axes");
    let in_rows = *trailing
        .next()
        .expect("pool: input tensor must have at least two axes");

    // Output extent along one axis. The subtraction is guarded first: an axis smaller
    // than the window would otherwise underflow `usize`. Debug builds would panic;
    // release builds would wrap to an enormous shape handed to `Tensor::zeros`.
    let out_extent = |dim: usize| -> usize {
        let padded = dim + 2 * padding;
        assert!(
            padded >= stride,
            "pool: window of size {} does not fit an axis of size {} with padding {}",
            stride,
            dim,
            padding
        );
        (padded - stride) / stride + 1
    };
    let out_rows = out_extent(in_rows);
    let out_cols = out_extent(in_cols);

    let mut pooled = Tensor::zeros(vec![out_rows, out_cols]);
    for i in 0..out_rows {
        for j in 0..out_cols {
            let window = x.get_vec_slice(vec![i * stride, j * stride], stride, stride);
            if let Some(value) = reduce(&window) {
                pooled.set(vec![i, j], value);
            }
        }
    }
    pooled
}