use crate::TorshResult;
use torsh_core::TorshError;
use torsh_tensor::{creation::zeros, Tensor};
/// 2D max pooling with configurable kernel, stride, padding and dilation.
///
/// Operates on 4D tensors shaped `(batch, channels, height, width)`.
pub struct SparseMaxPool2d {
    /// Pooling window size as `(height, width)`.
    kernel_size: (usize, usize),
    /// Window step as `(height, width)`; `new` defaults this to `kernel_size`.
    stride: (usize, usize),
    /// Implicit border added on each side as `(height, width)`; padded
    /// positions never contribute candidate values to the max.
    padding: (usize, usize),
    /// Spacing between kernel elements as `(height, width)`.
    dilation: (usize, usize),
}
impl SparseMaxPool2d {
    /// Creates a 2D max pooling layer.
    ///
    /// Defaults: `stride` = `kernel_size`, `padding` = `(0, 0)`,
    /// `dilation` = `(1, 1)`.
    pub fn new(
        kernel_size: (usize, usize),
        stride: Option<(usize, usize)>,
        padding: Option<(usize, usize)>,
        dilation: Option<(usize, usize)>,
    ) -> Self {
        Self {
            kernel_size,
            stride: stride.unwrap_or(kernel_size),
            padding: padding.unwrap_or((0, 0)),
            dilation: dilation.unwrap_or((1, 1)),
        }
    }

    /// Applies max pooling to a 4D input `(batch, channels, height, width)`.
    ///
    /// Padding positions never contribute values; a window covering only
    /// padding is written as `0.0`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when the input is not 4D, when
    /// `kernel_size` or `stride` contains a zero, or when the dilated
    /// kernel is larger than the padded input (the previous open-coded
    /// formula underflowed `usize` in that case).
    pub fn forward(&self, input: &Tensor) -> TorshResult<Tensor> {
        let input_shape = input.shape();
        if input_shape.ndim() != 4 {
            return Err(TorshError::InvalidArgument(
                "Input must be 4D tensor (batch_size, channels, height, width)".to_string(),
            ));
        }
        // Guard the arithmetic below: kernel == 0 would underflow `k - 1`,
        // stride == 0 would divide by zero.
        if self.kernel_size.0 == 0 || self.kernel_size.1 == 0 {
            return Err(TorshError::InvalidArgument(
                "kernel_size components must be non-zero".to_string(),
            ));
        }
        if self.stride.0 == 0 || self.stride.1 == 0 {
            return Err(TorshError::InvalidArgument(
                "stride components must be non-zero".to_string(),
            ));
        }
        let dims = input_shape.dims();
        let (batch_size, channels) = (dims[0], dims[1]);
        let (input_height, input_width) = (dims[2], dims[3]);
        // Effective kernel extent after dilation.
        let span_h = self.dilation.0 * (self.kernel_size.0 - 1) + 1;
        let span_w = self.dilation.1 * (self.kernel_size.1 - 1) + 1;
        // checked_sub prevents a usize underflow (debug panic / release wrap)
        // when the dilated kernel exceeds the padded input.
        let output_height = (input_height + 2 * self.padding.0)
            .checked_sub(span_h)
            .map(|d| d / self.stride.0 + 1)
            .ok_or_else(|| {
                TorshError::InvalidArgument(format!(
                    "Dilated kernel height {span_h} exceeds padded input height {}",
                    input_height + 2 * self.padding.0
                ))
            })?;
        let output_width = (input_width + 2 * self.padding.1)
            .checked_sub(span_w)
            .map(|d| d / self.stride.1 + 1)
            .ok_or_else(|| {
                TorshError::InvalidArgument(format!(
                    "Dilated kernel width {span_w} exceeds padded input width {}",
                    input_width + 2 * self.padding.1
                ))
            })?;
        let mut output = zeros::<f32>(&[batch_size, channels, output_height, output_width])?;
        for b in 0..batch_size {
            for c in 0..channels {
                self.max_pool_channel(
                    input,
                    &mut output,
                    b,
                    c,
                    input_height,
                    input_width,
                    output_height,
                    output_width,
                )?;
            }
        }
        Ok(output)
    }

    /// Pools one (batch, channel) plane of `input` into `output`.
    #[allow(clippy::too_many_arguments)]
    fn max_pool_channel(
        &self,
        input: &Tensor,
        output: &mut Tensor,
        batch_idx: usize,
        channel_idx: usize,
        input_height: usize,
        input_width: usize,
        output_height: usize,
        output_width: usize,
    ) -> TorshResult<()> {
        for out_h in 0..output_height {
            for out_w in 0..output_width {
                let mut max_val = f32::NEG_INFINITY;
                let mut found_value = false;
                for kh in 0..self.kernel_size.0 {
                    for kw in 0..self.kernel_size.1 {
                        // Coordinates in the (virtually) padded input.
                        let in_h = out_h * self.stride.0 + kh * self.dilation.0;
                        let in_w = out_w * self.stride.1 + kw * self.dilation.1;
                        // Skip positions inside the padding border.
                        if in_h >= self.padding.0 && in_w >= self.padding.1 {
                            let padded_in_h = in_h - self.padding.0;
                            let padded_in_w = in_w - self.padding.1;
                            if padded_in_h < input_height && padded_in_w < input_width {
                                let val = input.get(&[
                                    batch_idx,
                                    channel_idx,
                                    padded_in_h,
                                    padded_in_w,
                                ])?;
                                if val > max_val || !found_value {
                                    max_val = val;
                                    found_value = true;
                                }
                            }
                        }
                    }
                }
                // A window with no real elements yields 0.0.
                output.set(
                    &[batch_idx, channel_idx, out_h, out_w],
                    if found_value { max_val } else { 0.0 },
                )?;
            }
        }
        Ok(())
    }
}
/// 2D average pooling with configurable kernel, stride and padding.
///
/// Operates on 4D tensors shaped `(batch, channels, height, width)`.
pub struct SparseAvgPool2d {
    /// Pooling window size as `(height, width)`.
    kernel_size: (usize, usize),
    /// Window step as `(height, width)`; `new` defaults this to `kernel_size`.
    stride: (usize, usize),
    /// Implicit border added on each side as `(height, width)`.
    padding: (usize, usize),
    /// When true, padding positions inside a window are counted in the
    /// averaging denominator (they still contribute no value to the sum).
    count_include_pad: bool,
}
impl SparseAvgPool2d {
    /// Creates a 2D average pooling layer.
    ///
    /// Defaults: `stride` = `kernel_size`, `padding` = `(0, 0)`.
    /// `count_include_pad` controls whether padding positions are counted
    /// in the averaging denominator.
    pub fn new(
        kernel_size: (usize, usize),
        stride: Option<(usize, usize)>,
        padding: Option<(usize, usize)>,
        count_include_pad: bool,
    ) -> Self {
        Self {
            kernel_size,
            stride: stride.unwrap_or(kernel_size),
            padding: padding.unwrap_or((0, 0)),
            count_include_pad,
        }
    }

    /// Applies average pooling to a 4D input `(batch, channels, height, width)`.
    ///
    /// A window with a zero denominator is written as `0.0`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when the input is not 4D, when
    /// `kernel_size` or `stride` contains a zero, or when the kernel is
    /// larger than the padded input (the previous open-coded formula
    /// underflowed `usize` in that case).
    pub fn forward(&self, input: &Tensor) -> TorshResult<Tensor> {
        let input_shape = input.shape();
        if input_shape.ndim() != 4 {
            return Err(TorshError::InvalidArgument(
                "Input must be 4D tensor (batch_size, channels, height, width)".to_string(),
            ));
        }
        // kernel == 0 makes every window empty; stride == 0 divides by zero.
        if self.kernel_size.0 == 0 || self.kernel_size.1 == 0 {
            return Err(TorshError::InvalidArgument(
                "kernel_size components must be non-zero".to_string(),
            ));
        }
        if self.stride.0 == 0 || self.stride.1 == 0 {
            return Err(TorshError::InvalidArgument(
                "stride components must be non-zero".to_string(),
            ));
        }
        let dims = input_shape.dims();
        let (batch_size, channels) = (dims[0], dims[1]);
        let (input_height, input_width) = (dims[2], dims[3]);
        // checked_sub prevents a usize underflow (debug panic / release wrap)
        // when the kernel exceeds the padded input.
        let output_height = (input_height + 2 * self.padding.0)
            .checked_sub(self.kernel_size.0)
            .map(|d| d / self.stride.0 + 1)
            .ok_or_else(|| {
                TorshError::InvalidArgument(format!(
                    "Kernel height {} exceeds padded input height {}",
                    self.kernel_size.0,
                    input_height + 2 * self.padding.0
                ))
            })?;
        let output_width = (input_width + 2 * self.padding.1)
            .checked_sub(self.kernel_size.1)
            .map(|d| d / self.stride.1 + 1)
            .ok_or_else(|| {
                TorshError::InvalidArgument(format!(
                    "Kernel width {} exceeds padded input width {}",
                    self.kernel_size.1,
                    input_width + 2 * self.padding.1
                ))
            })?;
        let mut output = zeros::<f32>(&[batch_size, channels, output_height, output_width])?;
        for b in 0..batch_size {
            for c in 0..channels {
                self.avg_pool_channel(
                    input,
                    &mut output,
                    b,
                    c,
                    input_height,
                    input_width,
                    output_height,
                    output_width,
                )?;
            }
        }
        Ok(output)
    }

    /// Pools one (batch, channel) plane of `input` into `output`.
    #[allow(clippy::too_many_arguments)]
    fn avg_pool_channel(
        &self,
        input: &Tensor,
        output: &mut Tensor,
        batch_idx: usize,
        channel_idx: usize,
        input_height: usize,
        input_width: usize,
        output_height: usize,
        output_width: usize,
    ) -> TorshResult<()> {
        for out_h in 0..output_height {
            for out_w in 0..output_width {
                let mut sum = 0.0;
                let mut count = 0;
                for kh in 0..self.kernel_size.0 {
                    for kw in 0..self.kernel_size.1 {
                        // Coordinates in the (virtually) padded input.
                        let in_h = out_h * self.stride.0 + kh;
                        let in_w = out_w * self.stride.1 + kw;
                        if in_h >= self.padding.0 && in_w >= self.padding.1 {
                            let padded_in_h = in_h - self.padding.0;
                            let padded_in_w = in_w - self.padding.1;
                            if padded_in_h < input_height && padded_in_w < input_width {
                                let val = input.get(&[
                                    batch_idx,
                                    channel_idx,
                                    padded_in_h,
                                    padded_in_w,
                                ])?;
                                sum += val;
                                count += 1;
                            } else if self.count_include_pad {
                                // Beyond the far edge: pad position counted
                                // in the denominator only.
                                count += 1;
                            }
                        } else if self.count_include_pad {
                            // Before the near edge: likewise denominator-only.
                            count += 1;
                        }
                    }
                }
                let avg = if count > 0 { sum / count as f32 } else { 0.0 };
                output.set(&[batch_idx, channel_idx, out_h, out_w], avg)?;
            }
        }
        Ok(())
    }
}
/// 1D max pooling with configurable kernel, stride, padding and dilation.
///
/// Operates on 3D tensors shaped `(batch, channels, length)`.
pub struct SparseMaxPool1d {
    /// Pooling window length.
    kernel_size: usize,
    /// Window step; `new` defaults this to `kernel_size`.
    stride: usize,
    /// Implicit border added on each side; padded positions never
    /// contribute candidate values to the max.
    padding: usize,
    /// Spacing between kernel elements.
    dilation: usize,
}
impl SparseMaxPool1d {
    /// Creates a 1D max pooling layer.
    ///
    /// Defaults: `stride` = `kernel_size`, `padding` = `0`, `dilation` = `1`.
    pub fn new(
        kernel_size: usize,
        stride: Option<usize>,
        padding: Option<usize>,
        dilation: Option<usize>,
    ) -> Self {
        Self {
            kernel_size,
            stride: stride.unwrap_or(kernel_size),
            padding: padding.unwrap_or(0),
            dilation: dilation.unwrap_or(1),
        }
    }

    /// Applies max pooling to a 3D input `(batch, channels, length)`.
    ///
    /// Padding positions never contribute values; a window covering only
    /// padding is written as `0.0`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when the input is not 3D, when
    /// `kernel_size` or `stride` is zero, or when the dilated kernel is
    /// longer than the padded input (the previous open-coded formula
    /// underflowed `usize` in that case).
    pub fn forward(&self, input: &Tensor) -> TorshResult<Tensor> {
        let input_shape = input.shape();
        if input_shape.ndim() != 3 {
            return Err(TorshError::InvalidArgument(
                "Input must be 3D tensor (batch_size, channels, length)".to_string(),
            ));
        }
        // Guard the arithmetic below: kernel == 0 would underflow `k - 1`,
        // stride == 0 would divide by zero.
        if self.kernel_size == 0 {
            return Err(TorshError::InvalidArgument(
                "kernel_size must be non-zero".to_string(),
            ));
        }
        if self.stride == 0 {
            return Err(TorshError::InvalidArgument(
                "stride must be non-zero".to_string(),
            ));
        }
        let dims = input_shape.dims();
        let (batch_size, channels, input_length) = (dims[0], dims[1], dims[2]);
        // Effective kernel extent after dilation.
        let span = self.dilation * (self.kernel_size - 1) + 1;
        // checked_sub prevents a usize underflow (debug panic / release wrap)
        // when the dilated kernel exceeds the padded input.
        let output_length = (input_length + 2 * self.padding)
            .checked_sub(span)
            .map(|d| d / self.stride + 1)
            .ok_or_else(|| {
                TorshError::InvalidArgument(format!(
                    "Dilated kernel length {span} exceeds padded input length {}",
                    input_length + 2 * self.padding
                ))
            })?;
        let output = zeros::<f32>(&[batch_size, channels, output_length])?;
        for b in 0..batch_size {
            for c in 0..channels {
                for out_pos in 0..output_length {
                    let mut max_val = f32::NEG_INFINITY;
                    let mut found_value = false;
                    for k in 0..self.kernel_size {
                        // Position in the (virtually) padded input.
                        let in_pos = out_pos * self.stride + k * self.dilation;
                        // Skip positions inside the padding border.
                        if in_pos >= self.padding {
                            let padded_in_pos = in_pos - self.padding;
                            if padded_in_pos < input_length {
                                let val = input.get(&[b, c, padded_in_pos])?;
                                if val > max_val || !found_value {
                                    max_val = val;
                                    found_value = true;
                                }
                            }
                        }
                    }
                    // A window with no real elements yields 0.0.
                    output.set(&[b, c, out_pos], if found_value { max_val } else { 0.0 })?;
                }
            }
        }
        Ok(output)
    }
}
/// 1D average pooling with configurable kernel, stride and padding.
///
/// Operates on 3D tensors shaped `(batch, channels, length)`.
pub struct SparseAvgPool1d {
    /// Pooling window length.
    kernel_size: usize,
    /// Window step; `new` defaults this to `kernel_size`.
    stride: usize,
    /// Implicit border added on each side.
    padding: usize,
    /// When true, padding positions inside a window are counted in the
    /// averaging denominator (they still contribute no value to the sum).
    count_include_pad: bool,
}
impl SparseAvgPool1d {
    /// Creates a 1D average pooling layer.
    ///
    /// Defaults: `stride` = `kernel_size`, `padding` = `0`.
    /// `count_include_pad` controls whether padding positions are counted
    /// in the averaging denominator.
    pub fn new(
        kernel_size: usize,
        stride: Option<usize>,
        padding: Option<usize>,
        count_include_pad: bool,
    ) -> Self {
        Self {
            kernel_size,
            stride: stride.unwrap_or(kernel_size),
            padding: padding.unwrap_or(0),
            count_include_pad,
        }
    }

    /// Applies average pooling to a 3D input `(batch, channels, length)`.
    ///
    /// A window with a zero denominator is written as `0.0`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when the input is not 3D, when
    /// `kernel_size` or `stride` is zero, or when the kernel is longer
    /// than the padded input (the previous open-coded formula underflowed
    /// `usize` in that case).
    pub fn forward(&self, input: &Tensor) -> TorshResult<Tensor> {
        let input_shape = input.shape();
        if input_shape.ndim() != 3 {
            return Err(TorshError::InvalidArgument(
                "Input must be 3D tensor (batch_size, channels, length)".to_string(),
            ));
        }
        // kernel == 0 makes every window empty; stride == 0 divides by zero.
        if self.kernel_size == 0 {
            return Err(TorshError::InvalidArgument(
                "kernel_size must be non-zero".to_string(),
            ));
        }
        if self.stride == 0 {
            return Err(TorshError::InvalidArgument(
                "stride must be non-zero".to_string(),
            ));
        }
        let dims = input_shape.dims();
        let (batch_size, channels, input_length) = (dims[0], dims[1], dims[2]);
        // checked_sub prevents a usize underflow (debug panic / release wrap)
        // when the kernel exceeds the padded input.
        let output_length = (input_length + 2 * self.padding)
            .checked_sub(self.kernel_size)
            .map(|d| d / self.stride + 1)
            .ok_or_else(|| {
                TorshError::InvalidArgument(format!(
                    "Kernel length {} exceeds padded input length {}",
                    self.kernel_size,
                    input_length + 2 * self.padding
                ))
            })?;
        let output = zeros::<f32>(&[batch_size, channels, output_length])?;
        for b in 0..batch_size {
            for c in 0..channels {
                for out_pos in 0..output_length {
                    let mut sum = 0.0;
                    let mut count = 0;
                    for k in 0..self.kernel_size {
                        // Position in the (virtually) padded input.
                        let in_pos = out_pos * self.stride + k;
                        if in_pos >= self.padding {
                            let padded_in_pos = in_pos - self.padding;
                            if padded_in_pos < input_length {
                                let val = input.get(&[b, c, padded_in_pos])?;
                                sum += val;
                                count += 1;
                            } else if self.count_include_pad {
                                // Beyond the far edge: pad position counted
                                // in the denominator only.
                                count += 1;
                            }
                        } else if self.count_include_pad {
                            // Before the near edge: likewise denominator-only.
                            count += 1;
                        }
                    }
                    let avg = if count > 0 { sum / count as f32 } else { 0.0 };
                    output.set(&[b, c, out_pos], avg)?;
                }
            }
        }
        Ok(output)
    }
}
/// Adaptive 2D max pooling: reduces any input height/width to a fixed
/// target size by max-pooling over proportionally sized windows.
pub struct SparseAdaptiveMaxPool2d {
    /// Target spatial size as `(height, width)`.
    output_size: (usize, usize),
}
impl SparseAdaptiveMaxPool2d {
    /// Builds an adaptive max pooling layer targeting `output_size`.
    pub fn new(output_size: (usize, usize)) -> Self {
        Self { output_size }
    }

    /// Max-pools each adaptive window of a 4D input
    /// `(batch, channels, height, width)` down to `output_size`.
    ///
    /// Window bounds follow the floor/ceil scheme: row window for output
    /// row `oh` is `floor(oh*H/OH) .. ceil((oh+1)*H/OH)`. A window that
    /// yields no elements is written as `0.0`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when the input is not 4D.
    pub fn forward(&self, input: &Tensor) -> TorshResult<Tensor> {
        let shape = input.shape();
        if shape.ndim() != 4 {
            return Err(TorshError::InvalidArgument(
                "Input must be 4D tensor (batch_size, channels, height, width)".to_string(),
            ));
        }
        let dims = shape.dims();
        let (batch, chans) = (dims[0], dims[1]);
        let (in_h, in_w) = (dims[2], dims[3]);
        let (target_h, target_w) = self.output_size;
        let output = zeros::<f32>(&[batch, chans, target_h, target_w])?;
        for b in 0..batch {
            for c in 0..chans {
                for oh in 0..target_h {
                    // Row window for this output row (invariant over ow).
                    let h0 = oh * in_h / target_h;
                    let h1 = ((oh + 1) * in_h).div_ceil(target_h);
                    for ow in 0..target_w {
                        let w0 = ow * in_w / target_w;
                        let w1 = ((ow + 1) * in_w).div_ceil(target_w);
                        // None until the first in-bounds element is seen.
                        let mut best: Option<f32> = None;
                        for h in h0..h1 {
                            for w in w0..w1 {
                                // Bounds check kept as a safety net.
                                if h < in_h && w < in_w {
                                    let v = input.get(&[b, c, h, w])?;
                                    best = match best {
                                        Some(m) if v > m => Some(v),
                                        Some(m) => Some(m),
                                        None => Some(v),
                                    };
                                }
                            }
                        }
                        output.set(&[b, c, oh, ow], best.unwrap_or(0.0))?;
                    }
                }
            }
        }
        Ok(output)
    }
}
/// Adaptive 2D average pooling: reduces any input height/width to a fixed
/// target size by averaging over proportionally sized windows.
pub struct SparseAdaptiveAvgPool2d {
    /// Target spatial size as `(height, width)`.
    output_size: (usize, usize),
}
impl SparseAdaptiveAvgPool2d {
    /// Builds an adaptive average pooling layer targeting `output_size`.
    pub fn new(output_size: (usize, usize)) -> Self {
        Self { output_size }
    }

    /// Averages each adaptive window of a 4D input
    /// `(batch, channels, height, width)` down to `output_size`.
    ///
    /// Window bounds follow the floor/ceil scheme: row window for output
    /// row `oh` is `floor(oh*H/OH) .. ceil((oh+1)*H/OH)`. A window that
    /// yields no elements is written as `0.0`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when the input is not 4D.
    pub fn forward(&self, input: &Tensor) -> TorshResult<Tensor> {
        let shape = input.shape();
        if shape.ndim() != 4 {
            return Err(TorshError::InvalidArgument(
                "Input must be 4D tensor (batch_size, channels, height, width)".to_string(),
            ));
        }
        let dims = shape.dims();
        let (batch, chans) = (dims[0], dims[1]);
        let (in_h, in_w) = (dims[2], dims[3]);
        let (target_h, target_w) = self.output_size;
        let output = zeros::<f32>(&[batch, chans, target_h, target_w])?;
        for b in 0..batch {
            for c in 0..chans {
                for oh in 0..target_h {
                    // Row window for this output row (invariant over ow).
                    let h0 = oh * in_h / target_h;
                    let h1 = ((oh + 1) * in_h).div_ceil(target_h);
                    for ow in 0..target_w {
                        let w0 = ow * in_w / target_w;
                        let w1 = ((ow + 1) * in_w).div_ceil(target_w);
                        let mut acc = 0.0f32;
                        let mut n = 0usize;
                        for h in h0..h1 {
                            for w in w0..w1 {
                                // Bounds check kept as a safety net.
                                if h < in_h && w < in_w {
                                    acc += input.get(&[b, c, h, w])?;
                                    n += 1;
                                }
                            }
                        }
                        let mean = if n == 0 { 0.0 } else { acc / n as f32 };
                        output.set(&[b, c, oh, ow], mean)?;
                    }
                }
            }
        }
        Ok(output)
    }
}