use torsh_core::error::Result;
use torsh_tensor::Tensor;
pub fn max_pool1d(
input: &Tensor,
kernel_size: usize,
stride: Option<usize>,
padding: Option<usize>,
dilation: Option<usize>,
) -> Result<Tensor> {
let input_shape_obj = input.shape();
let input_shape = input_shape_obj.dims();
if input_shape.len() != 3 {
return Err(torsh_core::error::TorshError::ShapeMismatch {
expected: vec![0, 0, 0],
got: input_shape.to_vec(),
});
}
let stride = stride.unwrap_or(kernel_size);
let padding = padding.unwrap_or(0);
let dilation = dilation.unwrap_or(1);
let [batch, channels, length] = [input_shape[0], input_shape[1], input_shape[2]];
let out_length = (length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1;
let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_length];
let input_data = input.to_vec()?;
for b in 0..batch {
for c in 0..channels {
for out_l in 0..out_length {
let l_start = out_l * stride;
let mut max_val = f32::NEG_INFINITY;
for k in 0..kernel_size {
let l = l_start + k * dilation;
if l >= padding && l < length + padding {
let input_l = l - padding;
let input_idx = b * (channels * length) + c * length + input_l;
max_val = max_val.max(input_data[input_idx]);
}
}
let output_idx = b * (channels * out_length) + c * out_length + out_l;
output_data[output_idx] = max_val;
}
}
}
let output_shape = vec![batch, channels, out_length];
Tensor::from_vec(output_data, &output_shape)
}
pub fn max_pool2d(
input: &Tensor,
kernel_size: (usize, usize),
stride: Option<(usize, usize)>,
padding: Option<(usize, usize)>,
dilation: Option<(usize, usize)>,
) -> Result<Tensor> {
let input_shape_obj = input.shape();
let input_shape = input_shape_obj.dims();
if input_shape.len() != 4 {
return Err(torsh_core::error::TorshError::ShapeMismatch {
expected: vec![0, 0, 0, 0],
got: input_shape.to_vec(),
});
}
let stride = stride.unwrap_or(kernel_size);
let padding = padding.unwrap_or((0, 0));
let dilation = dilation.unwrap_or((1, 1));
let [batch, channels, height, width] = [
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
];
let out_height = (height + 2 * padding.0 - dilation.0 * (kernel_size.0 - 1) - 1) / stride.0 + 1;
let out_width = (width + 2 * padding.1 - dilation.1 * (kernel_size.1 - 1) - 1) / stride.1 + 1;
let mut output_data = vec![f32::NEG_INFINITY; batch * channels * out_height * out_width];
let input_data = input.to_vec()?;
for b in 0..batch {
for c in 0..channels {
for out_h in 0..out_height {
for out_w in 0..out_width {
let h_start = out_h * stride.0;
let w_start = out_w * stride.1;
let mut max_val = f32::NEG_INFINITY;
for kh in 0..kernel_size.0 {
for kw in 0..kernel_size.1 {
let h = h_start + kh * dilation.0;
let w = w_start + kw * dilation.1;
if h >= padding.0
&& h < height + padding.0
&& w >= padding.1
&& w < width + padding.1
{
let input_h = h - padding.0;
let input_w = w - padding.1;
let input_idx = b * (channels * height * width)
+ c * (height * width)
+ input_h * width
+ input_w;
max_val = max_val.max(input_data[input_idx]);
}
}
}
let output_idx = b * (channels * out_height * out_width)
+ c * (out_height * out_width)
+ out_h * out_width
+ out_w;
output_data[output_idx] = max_val;
}
}
}
}
let output_shape = vec![batch, channels, out_height, out_width];
Tensor::from_vec(output_data, &output_shape)
}
pub fn max_pool3d(
input: &Tensor,
kernel_size: (usize, usize, usize),
stride: Option<(usize, usize, usize)>,
padding: Option<(usize, usize, usize)>,
dilation: Option<(usize, usize, usize)>,
) -> Result<Tensor> {
let input_shape_obj = input.shape();
let input_shape = input_shape_obj.dims();
if input_shape.len() != 5 {
return Err(torsh_core::error::TorshError::ShapeMismatch {
expected: vec![0, 0, 0, 0, 0],
got: input_shape.to_vec(),
});
}
let stride = stride.unwrap_or(kernel_size);
let padding = padding.unwrap_or((0, 0, 0));
let dilation = dilation.unwrap_or((1, 1, 1));
let [batch, channels, depth, height, width] = [
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
input_shape[4],
];
let out_depth = (depth + 2 * padding.0 - dilation.0 * (kernel_size.0 - 1) - 1) / stride.0 + 1;
let out_height = (height + 2 * padding.1 - dilation.1 * (kernel_size.1 - 1) - 1) / stride.1 + 1;
let out_width = (width + 2 * padding.2 - dilation.2 * (kernel_size.2 - 1) - 1) / stride.2 + 1;
let mut output_data =
vec![f32::NEG_INFINITY; batch * channels * out_depth * out_height * out_width];
let input_data = input.to_vec()?;
for b in 0..batch {
for c in 0..channels {
for out_d in 0..out_depth {
for out_h in 0..out_height {
for out_w in 0..out_width {
let d_start = out_d * stride.0;
let h_start = out_h * stride.1;
let w_start = out_w * stride.2;
let mut max_val = f32::NEG_INFINITY;
for kd in 0..kernel_size.0 {
for kh in 0..kernel_size.1 {
for kw in 0..kernel_size.2 {
let d = d_start + kd * dilation.0;
let h = h_start + kh * dilation.1;
let w = w_start + kw * dilation.2;
if d >= padding.0
&& d < depth + padding.0
&& h >= padding.1
&& h < height + padding.1
&& w >= padding.2
&& w < width + padding.2
{
let input_d = d - padding.0;
let input_h = h - padding.1;
let input_w = w - padding.2;
let input_idx = b * (channels * depth * height * width)
+ c * (depth * height * width)
+ input_d * (height * width)
+ input_h * width
+ input_w;
max_val = max_val.max(input_data[input_idx]);
}
}
}
}
let output_idx = b * (channels * out_depth * out_height * out_width)
+ c * (out_depth * out_height * out_width)
+ out_d * (out_height * out_width)
+ out_h * out_width
+ out_w;
output_data[output_idx] = max_val;
}
}
}
}
}
let output_shape = vec![batch, channels, out_depth, out_height, out_width];
Tensor::from_vec(output_data, &output_shape)
}
pub fn avg_pool1d(
input: &Tensor,
kernel_size: usize,
stride: Option<usize>,
padding: Option<usize>,
count_include_pad: bool,
) -> Result<Tensor> {
let input_shape_obj = input.shape();
let input_shape = input_shape_obj.dims();
if input_shape.len() != 3 {
return Err(torsh_core::error::TorshError::ShapeMismatch {
expected: vec![0, 0, 0],
got: input_shape.to_vec(),
});
}
let stride = stride.unwrap_or(kernel_size);
let padding = padding.unwrap_or(0);
let [batch, channels, length] = [input_shape[0], input_shape[1], input_shape[2]];
let out_length = (length + 2 * padding - kernel_size) / stride + 1;
let mut output_data = vec![0.0f32; batch * channels * out_length];
let input_data = input.to_vec()?;
for b in 0..batch {
for c in 0..channels {
for out_l in 0..out_length {
let l_start = out_l * stride;
let mut sum = 0.0f32;
let mut count = 0;
for k in 0..kernel_size {
let l = l_start + k;
if l < padding || l >= length + padding {
if count_include_pad {
count += 1;
}
} else {
let input_l = l - padding;
let input_idx = b * (channels * length) + c * length + input_l;
sum += input_data[input_idx];
count += 1;
}
}
let avg = if count > 0 { sum / count as f32 } else { 0.0 };
let output_idx = b * (channels * out_length) + c * out_length + out_l;
output_data[output_idx] = avg;
}
}
}
let output_shape = vec![batch, channels, out_length];
Tensor::from_vec(output_data, &output_shape)
}
pub fn avg_pool2d(
input: &Tensor,
kernel_size: (usize, usize),
stride: Option<(usize, usize)>,
padding: Option<(usize, usize)>,
count_include_pad: bool,
) -> Result<Tensor> {
let input_shape_obj = input.shape();
let input_shape = input_shape_obj.dims();
if input_shape.len() != 4 {
return Err(torsh_core::error::TorshError::ShapeMismatch {
expected: vec![0, 0, 0, 0],
got: input_shape.to_vec(),
});
}
let stride = stride.unwrap_or(kernel_size);
let padding = padding.unwrap_or((0, 0));
let [batch, channels, height, width] = [
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
];
let out_height = (height + 2 * padding.0 - kernel_size.0) / stride.0 + 1;
let out_width = (width + 2 * padding.1 - kernel_size.1) / stride.1 + 1;
let mut output_data = vec![0.0f32; batch * channels * out_height * out_width];
let input_data = input.to_vec()?;
for b in 0..batch {
for c in 0..channels {
for out_h in 0..out_height {
for out_w in 0..out_width {
let h_start = out_h * stride.0;
let w_start = out_w * stride.1;
let mut sum = 0.0f32;
let mut count = 0;
for kh in 0..kernel_size.0 {
for kw in 0..kernel_size.1 {
let h = h_start + kh;
let w = w_start + kw;
if h < padding.0
|| h >= height + padding.0
|| w < padding.1
|| w >= width + padding.1
{
if count_include_pad {
count += 1;
}
} else {
let input_h = h - padding.0;
let input_w = w - padding.1;
let input_idx = b * (channels * height * width)
+ c * (height * width)
+ input_h * width
+ input_w;
sum += input_data[input_idx];
count += 1;
}
}
}
let avg = if count > 0 { sum / count as f32 } else { 0.0 };
let output_idx = b * (channels * out_height * out_width)
+ c * (out_height * out_width)
+ out_h * out_width
+ out_w;
output_data[output_idx] = avg;
}
}
}
}
let output_shape = vec![batch, channels, out_height, out_width];
Tensor::from_vec(output_data, &output_shape)
}
pub fn avg_pool3d(
input: &Tensor,
kernel_size: (usize, usize, usize),
stride: Option<(usize, usize, usize)>,
padding: Option<(usize, usize, usize)>,
count_include_pad: bool,
) -> Result<Tensor> {
let input_shape_obj = input.shape();
let input_shape = input_shape_obj.dims();
if input_shape.len() != 5 {
return Err(torsh_core::error::TorshError::ShapeMismatch {
expected: vec![0, 0, 0, 0, 0],
got: input_shape.to_vec(),
});
}
let stride = stride.unwrap_or(kernel_size);
let padding = padding.unwrap_or((0, 0, 0));
let [batch, channels, depth, height, width] = [
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
input_shape[4],
];
let out_depth = (depth + 2 * padding.0 - kernel_size.0) / stride.0 + 1;
let out_height = (height + 2 * padding.1 - kernel_size.1) / stride.1 + 1;
let out_width = (width + 2 * padding.2 - kernel_size.2) / stride.2 + 1;
let mut output_data = vec![0.0f32; batch * channels * out_depth * out_height * out_width];
let input_data = input.to_vec()?;
for b in 0..batch {
for c in 0..channels {
for out_d in 0..out_depth {
for out_h in 0..out_height {
for out_w in 0..out_width {
let d_start = out_d * stride.0;
let h_start = out_h * stride.1;
let w_start = out_w * stride.2;
let mut sum = 0.0f32;
let mut count = 0;
for kd in 0..kernel_size.0 {
for kh in 0..kernel_size.1 {
for kw in 0..kernel_size.2 {
let d = d_start + kd;
let h = h_start + kh;
let w = w_start + kw;
if d < padding.0
|| d >= depth + padding.0
|| h < padding.1
|| h >= height + padding.1
|| w < padding.2
|| w >= width + padding.2
{
if count_include_pad {
count += 1;
}
} else {
let input_d = d - padding.0;
let input_h = h - padding.1;
let input_w = w - padding.2;
let input_idx = b * (channels * depth * height * width)
+ c * (depth * height * width)
+ input_d * (height * width)
+ input_h * width
+ input_w;
sum += input_data[input_idx];
count += 1;
}
}
}
}
let avg = if count > 0 { sum / count as f32 } else { 0.0 };
let output_idx = b * (channels * out_depth * out_height * out_width)
+ c * (out_depth * out_height * out_width)
+ out_d * (out_height * out_width)
+ out_h * out_width
+ out_w;
output_data[output_idx] = avg;
}
}
}
}
}
let output_shape = vec![batch, channels, out_depth, out_height, out_width];
Tensor::from_vec(output_data, &output_shape)
}
/// Adaptive 1D average pooling over the last dimension of `input`.
///
/// All leading dimensions are treated as independent rows. Output position
/// `i` averages the input range
/// `floor(i * L / O) .. ceil((i + 1) * L / O)`, where `L` is the input length
/// and `O` is `output_size`.
///
/// The upper bound is rounded up. The previous floor-based bound produced
/// empty windows (and therefore silent `0.0` outputs) whenever
/// `output_size > L`, and narrower windows than the usual adaptive-pooling
/// definition when `L` is not a multiple of `output_size`.
///
/// When `output_size` equals the input length a clone of `input` is returned.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` has no dimensions.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn adaptive_avg_pool1d(input: &Tensor, output_size: usize) -> Result<Tensor> {
    /// Half-open input range covered by output bin `index` (floor start, ceil end).
    fn bin_bounds(index: usize, in_size: usize, out_size: usize) -> (usize, usize) {
        let start = (index * in_size) / out_size;
        let end = ((index + 1) * in_size + out_size - 1) / out_size;
        (start, end)
    }

    let input_shape_obj = input.shape();
    let input_shape = input_shape_obj.dims();
    if input_shape.is_empty() {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0],
            got: input_shape.to_vec(),
        });
    }
    let last = input_shape.len() - 1;
    let input_length = input_shape[last];
    if output_size == input_length {
        return Ok(input.clone());
    }
    let mut output_shape = input_shape.to_vec();
    output_shape[last] = output_size;
    let rows: usize = input_shape[..last].iter().product();
    let mut output_data = vec![0.0f32; rows * output_size];
    let input_data = input.to_vec()?;
    for row in 0..rows {
        let src = row * input_length;
        for out_l in 0..output_size {
            let (start, end) = bin_bounds(out_l, input_length, output_size);
            let window = &input_data[src + start..src + end];
            // The window is only empty when the input length itself is zero.
            let avg = if window.is_empty() {
                0.0
            } else {
                window.iter().fold(0.0f32, |acc, &v| acc + v) / window.len() as f32
            };
            output_data[row * output_size + out_l] = avg;
        }
    }
    Tensor::from_vec(output_data, &output_shape)
}
/// Adaptive 2D average pooling over the last two dimensions of `input`.
///
/// `output_size` is `(height, width)`; a `None` component keeps the input
/// size along that axis. All leading dimensions are treated as independent
/// planes. Along each axis, output position `i` covers the input range
/// `floor(i * L / O) .. ceil((i + 1) * L / O)`.
///
/// The upper bound is rounded up. The previous floor-based bound produced
/// empty windows (and therefore silent `0.0` outputs) whenever the output was
/// larger than the input along an axis.
///
/// When the requested size equals the input size a clone of `input` is returned.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` has fewer than two dimensions.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn adaptive_avg_pool2d(
    input: &Tensor,
    output_size: (Option<usize>, Option<usize>),
) -> Result<Tensor> {
    /// Half-open input range covered by output bin `index` (floor start, ceil end).
    fn bin_bounds(index: usize, in_size: usize, out_size: usize) -> (usize, usize) {
        let start = (index * in_size) / out_size;
        let end = ((index + 1) * in_size + out_size - 1) / out_size;
        (start, end)
    }

    let input_shape_obj = input.shape();
    let input_shape = input_shape_obj.dims();
    if input_shape.len() < 2 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0],
            got: input_shape.to_vec(),
        });
    }
    let height_idx = input_shape.len() - 2;
    let width_idx = input_shape.len() - 1;
    let input_height = input_shape[height_idx];
    let input_width = input_shape[width_idx];
    let output_height = output_size.0.unwrap_or(input_height);
    let output_width = output_size.1.unwrap_or(input_width);
    if output_height == input_height && output_width == input_width {
        return Ok(input.clone());
    }
    let mut output_shape = input_shape.to_vec();
    output_shape[height_idx] = output_height;
    output_shape[width_idx] = output_width;
    let planes: usize = input_shape[..height_idx].iter().product();
    let mut output_data = vec![0.0f32; planes * output_height * output_width];
    let input_data = input.to_vec()?;
    for plane in 0..planes {
        let src = plane * (input_height * input_width);
        let dst = plane * (output_height * output_width);
        for out_h in 0..output_height {
            let (h_start, h_end) = bin_bounds(out_h, input_height, output_height);
            for out_w in 0..output_width {
                let (w_start, w_end) = bin_bounds(out_w, input_width, output_width);
                let mut sum = 0.0f32;
                for in_h in h_start..h_end {
                    let row = src + in_h * input_width;
                    for &v in &input_data[row + w_start..row + w_end] {
                        sum += v;
                    }
                }
                let count = (h_end - h_start) * (w_end - w_start);
                // The window is only empty when an input extent is zero.
                let avg = if count > 0 { sum / count as f32 } else { 0.0 };
                output_data[dst + out_h * output_width + out_w] = avg;
            }
        }
    }
    Tensor::from_vec(output_data, &output_shape)
}
/// Adaptive 3D average pooling over the last three dimensions of `input`.
///
/// `output_size` is `(depth, height, width)`; a `None` component keeps the
/// input size along that axis. All leading dimensions are treated as
/// independent volumes. Along each axis, output position `i` covers the input
/// range `floor(i * L / O) .. ceil((i + 1) * L / O)`.
///
/// The upper bound is rounded up. The previous floor-based bound produced
/// empty windows (and therefore silent `0.0` outputs) whenever the output was
/// larger than the input along an axis.
///
/// When the requested size equals the input size a clone of `input` is returned.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` has fewer than three dimensions.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn adaptive_avg_pool3d(
    input: &Tensor,
    output_size: (Option<usize>, Option<usize>, Option<usize>),
) -> Result<Tensor> {
    /// Half-open input range covered by output bin `index` (floor start, ceil end).
    fn bin_bounds(index: usize, in_size: usize, out_size: usize) -> (usize, usize) {
        let start = (index * in_size) / out_size;
        let end = ((index + 1) * in_size + out_size - 1) / out_size;
        (start, end)
    }

    let input_shape_obj = input.shape();
    let input_shape = input_shape_obj.dims();
    if input_shape.len() < 3 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0],
            got: input_shape.to_vec(),
        });
    }
    let depth_idx = input_shape.len() - 3;
    let height_idx = input_shape.len() - 2;
    let width_idx = input_shape.len() - 1;
    let input_depth = input_shape[depth_idx];
    let input_height = input_shape[height_idx];
    let input_width = input_shape[width_idx];
    let output_depth = output_size.0.unwrap_or(input_depth);
    let output_height = output_size.1.unwrap_or(input_height);
    let output_width = output_size.2.unwrap_or(input_width);
    if output_depth == input_depth && output_height == input_height && output_width == input_width
    {
        return Ok(input.clone());
    }
    let mut output_shape = input_shape.to_vec();
    output_shape[depth_idx] = output_depth;
    output_shape[height_idx] = output_height;
    output_shape[width_idx] = output_width;
    let volumes: usize = input_shape[..depth_idx].iter().product();
    let in_plane = input_height * input_width;
    let out_plane = output_height * output_width;
    let mut output_data = vec![0.0f32; volumes * output_depth * out_plane];
    let input_data = input.to_vec()?;
    for volume in 0..volumes {
        let src = volume * (input_depth * in_plane);
        let dst = volume * (output_depth * out_plane);
        for out_d in 0..output_depth {
            let (d_start, d_end) = bin_bounds(out_d, input_depth, output_depth);
            for out_h in 0..output_height {
                let (h_start, h_end) = bin_bounds(out_h, input_height, output_height);
                for out_w in 0..output_width {
                    let (w_start, w_end) = bin_bounds(out_w, input_width, output_width);
                    let mut sum = 0.0f32;
                    for in_d in d_start..d_end {
                        for in_h in h_start..h_end {
                            let row = src + in_d * in_plane + in_h * input_width;
                            for &v in &input_data[row + w_start..row + w_end] {
                                sum += v;
                            }
                        }
                    }
                    let count = (d_end - d_start) * (h_end - h_start) * (w_end - w_start);
                    // The window is only empty when an input extent is zero.
                    let avg = if count > 0 { sum / count as f32 } else { 0.0 };
                    output_data[dst + out_d * out_plane + out_h * output_width + out_w] = avg;
                }
            }
        }
    }
    Tensor::from_vec(output_data, &output_shape)
}
/// Adaptive 1D max pooling over the last dimension of `input`.
///
/// All leading dimensions are treated as independent rows. Output position
/// `i` takes the maximum over the input range
/// `floor(i * L / O) .. ceil((i + 1) * L / O)`, where `L` is the input length
/// and `O` is `output_size`.
///
/// The upper bound is rounded up. The previous floor-based bound produced
/// empty windows (and therefore `f32::NEG_INFINITY` outputs) whenever
/// `output_size > L`.
///
/// When `output_size` equals the input length a clone of `input` is returned.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` has no dimensions.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn adaptive_max_pool1d(input: &Tensor, output_size: usize) -> Result<Tensor> {
    /// Half-open input range covered by output bin `index` (floor start, ceil end).
    fn bin_bounds(index: usize, in_size: usize, out_size: usize) -> (usize, usize) {
        let start = (index * in_size) / out_size;
        let end = ((index + 1) * in_size + out_size - 1) / out_size;
        (start, end)
    }

    let input_shape_obj = input.shape();
    let input_shape = input_shape_obj.dims();
    if input_shape.is_empty() {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0],
            got: input_shape.to_vec(),
        });
    }
    let last = input_shape.len() - 1;
    let input_length = input_shape[last];
    if output_size == input_length {
        return Ok(input.clone());
    }
    let mut output_shape = input_shape.to_vec();
    output_shape[last] = output_size;
    let rows: usize = input_shape[..last].iter().product();
    let mut output_data = vec![f32::NEG_INFINITY; rows * output_size];
    let input_data = input.to_vec()?;
    for row in 0..rows {
        let src = row * input_length;
        for out_l in 0..output_size {
            let (start, end) = bin_bounds(out_l, input_length, output_size);
            let max_val = input_data[src + start..src + end]
                .iter()
                .fold(f32::NEG_INFINITY, |best, &v| best.max(v));
            output_data[row * output_size + out_l] = max_val;
        }
    }
    Tensor::from_vec(output_data, &output_shape)
}
/// Adaptive 2D max pooling over the last two dimensions of `input`.
///
/// `output_size` is `(height, width)`; a `None` component keeps the input
/// size along that axis. All leading dimensions are treated as independent
/// planes. Along each axis, output position `i` covers the input range
/// `floor(i * L / O) .. ceil((i + 1) * L / O)`.
///
/// The upper bound is rounded up. The previous floor-based bound produced
/// empty windows (and therefore `f32::NEG_INFINITY` outputs) whenever the
/// output was larger than the input along an axis.
///
/// When the requested size equals the input size a clone of `input` is returned.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` has fewer than two dimensions.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn adaptive_max_pool2d(
    input: &Tensor,
    output_size: (Option<usize>, Option<usize>),
) -> Result<Tensor> {
    /// Half-open input range covered by output bin `index` (floor start, ceil end).
    fn bin_bounds(index: usize, in_size: usize, out_size: usize) -> (usize, usize) {
        let start = (index * in_size) / out_size;
        let end = ((index + 1) * in_size + out_size - 1) / out_size;
        (start, end)
    }

    let input_shape_obj = input.shape();
    let input_shape = input_shape_obj.dims();
    if input_shape.len() < 2 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0],
            got: input_shape.to_vec(),
        });
    }
    let height_idx = input_shape.len() - 2;
    let width_idx = input_shape.len() - 1;
    let input_height = input_shape[height_idx];
    let input_width = input_shape[width_idx];
    let output_height = output_size.0.unwrap_or(input_height);
    let output_width = output_size.1.unwrap_or(input_width);
    if output_height == input_height && output_width == input_width {
        return Ok(input.clone());
    }
    let mut output_shape = input_shape.to_vec();
    output_shape[height_idx] = output_height;
    output_shape[width_idx] = output_width;
    let planes: usize = input_shape[..height_idx].iter().product();
    let mut output_data = vec![f32::NEG_INFINITY; planes * output_height * output_width];
    let input_data = input.to_vec()?;
    for plane in 0..planes {
        let src = plane * (input_height * input_width);
        let dst = plane * (output_height * output_width);
        for out_h in 0..output_height {
            let (h_start, h_end) = bin_bounds(out_h, input_height, output_height);
            for out_w in 0..output_width {
                let (w_start, w_end) = bin_bounds(out_w, input_width, output_width);
                let mut max_val = f32::NEG_INFINITY;
                for in_h in h_start..h_end {
                    let row = src + in_h * input_width;
                    for &v in &input_data[row + w_start..row + w_end] {
                        max_val = max_val.max(v);
                    }
                }
                output_data[dst + out_h * output_width + out_w] = max_val;
            }
        }
    }
    Tensor::from_vec(output_data, &output_shape)
}
/// Adaptive 3D max pooling over the last three dimensions of `input`.
///
/// `output_size` is `(depth, height, width)`; a `None` component keeps the
/// input size along that axis. All leading dimensions are treated as
/// independent volumes. Along each axis, output position `i` covers the input
/// range `floor(i * L / O) .. ceil((i + 1) * L / O)`.
///
/// The upper bound is rounded up. The previous floor-based bound produced
/// empty windows (and therefore `f32::NEG_INFINITY` outputs) whenever the
/// output was larger than the input along an axis.
///
/// When the requested size equals the input size a clone of `input` is returned.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` has fewer than three dimensions.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn adaptive_max_pool3d(
    input: &Tensor,
    output_size: (Option<usize>, Option<usize>, Option<usize>),
) -> Result<Tensor> {
    /// Half-open input range covered by output bin `index` (floor start, ceil end).
    fn bin_bounds(index: usize, in_size: usize, out_size: usize) -> (usize, usize) {
        let start = (index * in_size) / out_size;
        let end = ((index + 1) * in_size + out_size - 1) / out_size;
        (start, end)
    }

    let input_shape_obj = input.shape();
    let input_shape = input_shape_obj.dims();
    if input_shape.len() < 3 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0],
            got: input_shape.to_vec(),
        });
    }
    let depth_idx = input_shape.len() - 3;
    let height_idx = input_shape.len() - 2;
    let width_idx = input_shape.len() - 1;
    let input_depth = input_shape[depth_idx];
    let input_height = input_shape[height_idx];
    let input_width = input_shape[width_idx];
    let output_depth = output_size.0.unwrap_or(input_depth);
    let output_height = output_size.1.unwrap_or(input_height);
    let output_width = output_size.2.unwrap_or(input_width);
    if output_depth == input_depth && output_height == input_height && output_width == input_width
    {
        return Ok(input.clone());
    }
    let mut output_shape = input_shape.to_vec();
    output_shape[depth_idx] = output_depth;
    output_shape[height_idx] = output_height;
    output_shape[width_idx] = output_width;
    let volumes: usize = input_shape[..depth_idx].iter().product();
    let in_plane = input_height * input_width;
    let out_plane = output_height * output_width;
    let mut output_data = vec![f32::NEG_INFINITY; volumes * output_depth * out_plane];
    let input_data = input.to_vec()?;
    for volume in 0..volumes {
        let src = volume * (input_depth * in_plane);
        let dst = volume * (output_depth * out_plane);
        for out_d in 0..output_depth {
            let (d_start, d_end) = bin_bounds(out_d, input_depth, output_depth);
            for out_h in 0..output_height {
                let (h_start, h_end) = bin_bounds(out_h, input_height, output_height);
                for out_w in 0..output_width {
                    let (w_start, w_end) = bin_bounds(out_w, input_width, output_width);
                    let mut max_val = f32::NEG_INFINITY;
                    for in_d in d_start..d_end {
                        for in_h in h_start..h_end {
                            let row = src + in_d * in_plane + in_h * input_width;
                            for &v in &input_data[row + w_start..row + w_end] {
                                max_val = max_val.max(v);
                            }
                        }
                    }
                    output_data[dst + out_d * out_plane + out_h * output_width + out_w] = max_val;
                }
            }
        }
    }
    Tensor::from_vec(output_data, &output_shape)
}
/// Averages each `(batch, channel)` row of a `[batch, channels, length]`
/// tensor, producing a `[batch, channels, 1]` tensor.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` is not 3-dimensional.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn global_avg_pool1d(input: &Tensor) -> Result<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();
    if dims.len() != 3 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0],
            got: dims.to_vec(),
        });
    }
    let (batch, channels, length) = (dims[0], dims[1], dims[2]);
    let output_shape = vec![batch, channels, 1];
    let mut means = vec![0.0f32; batch * channels];
    let values = input.to_vec()?;
    // Row `b * channels + c` occupies `length` consecutive elements of `values`.
    for (row, mean) in means.iter_mut().enumerate() {
        let start = row * length;
        let total = values[start..start + length]
            .iter()
            .fold(0.0f32, |acc, &v| acc + v);
        *mean = total / length as f32;
    }
    Tensor::from_vec(means, &output_shape)
}
/// Averages each `(batch, channel)` plane of a `[batch, channels, height, width]`
/// tensor, producing a `[batch, channels, 1, 1]` tensor.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` is not 4-dimensional.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn global_avg_pool2d(input: &Tensor) -> Result<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();
    if dims.len() != 4 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0, 0],
            got: dims.to_vec(),
        });
    }
    let (batch, channels, height, width) = (dims[0], dims[1], dims[2], dims[3]);
    let output_shape = vec![batch, channels, 1, 1];
    let mut means = vec![0.0f32; batch * channels];
    let values = input.to_vec()?;
    let plane = height * width;
    let divisor = plane as f32;
    // Plane `b * channels + c` occupies `height * width` consecutive elements.
    for (slot, mean) in means.iter_mut().enumerate() {
        let start = slot * plane;
        let total = values[start..start + plane]
            .iter()
            .fold(0.0f32, |acc, &v| acc + v);
        *mean = total / divisor;
    }
    Tensor::from_vec(means, &output_shape)
}
/// Averages each `(batch, channel)` volume of a
/// `[batch, channels, depth, height, width]` tensor, producing a
/// `[batch, channels, 1, 1, 1]` tensor.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` is not 5-dimensional.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn global_avg_pool3d(input: &Tensor) -> Result<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();
    if dims.len() != 5 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0, 0, 0],
            got: dims.to_vec(),
        });
    }
    let (batch, channels, depth, height, width) = (dims[0], dims[1], dims[2], dims[3], dims[4]);
    let output_shape = vec![batch, channels, 1, 1, 1];
    let mut means = vec![0.0f32; batch * channels];
    let values = input.to_vec()?;
    let volume = depth * height * width;
    let divisor = volume as f32;
    // Volume `b * channels + c` occupies `depth * height * width` consecutive elements.
    for (slot, mean) in means.iter_mut().enumerate() {
        let start = slot * volume;
        let total = values[start..start + volume]
            .iter()
            .fold(0.0f32, |acc, &v| acc + v);
        *mean = total / divisor;
    }
    Tensor::from_vec(means, &output_shape)
}
/// Takes the maximum of each `(batch, channel)` row of a
/// `[batch, channels, length]` tensor, producing a `[batch, channels, 1]`
/// tensor. A row with no elements yields `f32::NEG_INFINITY`.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` is not 3-dimensional.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn global_max_pool1d(input: &Tensor) -> Result<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();
    if dims.len() != 3 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0],
            got: dims.to_vec(),
        });
    }
    let (batch, channels, length) = (dims[0], dims[1], dims[2]);
    let output_shape = vec![batch, channels, 1];
    let mut maxima = vec![f32::NEG_INFINITY; batch * channels];
    let values = input.to_vec()?;
    // Row `b * channels + c` occupies `length` consecutive elements of `values`.
    for (row, best) in maxima.iter_mut().enumerate() {
        let start = row * length;
        *best = values[start..start + length]
            .iter()
            .fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
    }
    Tensor::from_vec(maxima, &output_shape)
}
/// Takes the maximum of each `(batch, channel)` plane of a
/// `[batch, channels, height, width]` tensor, producing a
/// `[batch, channels, 1, 1]` tensor. A plane with no elements yields
/// `f32::NEG_INFINITY`.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` is not 4-dimensional.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn global_max_pool2d(input: &Tensor) -> Result<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();
    if dims.len() != 4 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0, 0],
            got: dims.to_vec(),
        });
    }
    let (batch, channels, height, width) = (dims[0], dims[1], dims[2], dims[3]);
    let output_shape = vec![batch, channels, 1, 1];
    let mut maxima = vec![f32::NEG_INFINITY; batch * channels];
    let values = input.to_vec()?;
    let plane = height * width;
    // Plane `b * channels + c` occupies `height * width` consecutive elements.
    for (slot, best) in maxima.iter_mut().enumerate() {
        let start = slot * plane;
        *best = values[start..start + plane]
            .iter()
            .fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
    }
    Tensor::from_vec(maxima, &output_shape)
}
/// Takes the maximum of each `(batch, channel)` volume of a
/// `[batch, channels, depth, height, width]` tensor, producing a
/// `[batch, channels, 1, 1, 1]` tensor. A volume with no elements yields
/// `f32::NEG_INFINITY`.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` is not 5-dimensional.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn global_max_pool3d(input: &Tensor) -> Result<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();
    if dims.len() != 5 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0, 0, 0],
            got: dims.to_vec(),
        });
    }
    let (batch, channels, depth, height, width) = (dims[0], dims[1], dims[2], dims[3], dims[4]);
    let output_shape = vec![batch, channels, 1, 1, 1];
    let mut maxima = vec![f32::NEG_INFINITY; batch * channels];
    let values = input.to_vec()?;
    let volume = depth * height * width;
    // Volume `b * channels + c` occupies `depth * height * width` consecutive elements.
    for (slot, best) in maxima.iter_mut().enumerate() {
        let start = slot * volume;
        *best = values[start..start + volume]
            .iter()
            .fold(f32::NEG_INFINITY, |acc, &v| acc.max(v));
    }
    Tensor::from_vec(maxima, &output_shape)
}
pub fn pad(
input: &Tensor,
padding: &[(usize, usize)],
mode: &str,
value: Option<f32>,
) -> Result<Tensor> {
let input_shape_obj = input.shape();
let input_shape = input_shape_obj.dims();
match (input_shape.len(), padding.len()) {
(3, 1) => {
let (pad_left, pad_right) = padding[0];
match mode {
"reflect" => reflection_pad1d(input, (pad_left, pad_right)),
"replicate" => replication_pad1d(input, (pad_left, pad_right)),
"constant" | "zero" => {
let fill_value = if mode == "zero" {
0.0
} else {
value.unwrap_or(0.0)
};
let [batch, channels, length] =
[input_shape[0], input_shape[1], input_shape[2]];
let new_length = length + pad_left + pad_right;
let output_shape = vec![batch, channels, new_length];
let mut output_data = vec![fill_value; batch * channels * new_length];
let input_data = input.to_vec()?;
for b in 0..batch {
for c in 0..channels {
for l in 0..length {
let input_idx = b * (channels * length) + c * length + l;
let output_idx =
b * (channels * new_length) + c * new_length + (l + pad_left);
output_data[output_idx] = input_data[input_idx];
}
}
}
Tensor::from_vec(output_data, &output_shape)
}
_ => Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Unsupported padding mode: {}",
mode
))),
}
}
(4, 2) => {
let (pad_left, pad_right) = padding[1];
let (pad_top, pad_bottom) = padding[0];
match mode {
"reflect" => reflection_pad2d(input, (pad_left, pad_right, pad_top, pad_bottom)),
"replicate" => replication_pad2d(input, (pad_left, pad_right, pad_top, pad_bottom)),
"constant" | "zero" => {
let fill_value = if mode == "zero" {
0.0
} else {
value.unwrap_or(0.0)
};
if fill_value == 0.0 {
zero_pad2d(input, (pad_left, pad_right, pad_top, pad_bottom))
} else {
let [batch, channels, height, width] = [
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
];
let new_height = height + pad_top + pad_bottom;
let new_width = width + pad_left + pad_right;
let output_shape = vec![batch, channels, new_height, new_width];
let mut output_data =
vec![fill_value; batch * channels * new_height * new_width];
let input_data = input.to_vec()?;
for b in 0..batch {
for c in 0..channels {
for h in 0..height {
for w in 0..width {
let input_idx = b * (channels * height * width)
+ c * (height * width)
+ h * width
+ w;
let output_h = h + pad_top;
let output_w = w + pad_left;
let output_idx = b * (channels * new_height * new_width)
+ c * (new_height * new_width)
+ output_h * new_width
+ output_w;
output_data[output_idx] = input_data[input_idx];
}
}
}
}
Tensor::from_vec(output_data, &output_shape)
}
}
_ => Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Unsupported padding mode: {}",
mode
))),
}
}
_ => Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Unsupported combination: tensor shape length {} with padding length {}. \
Expected 3D tensor with 1 padding pair or 4D tensor with 2 padding pairs.",
input_shape.len(),
padding.len()
))),
}
}
/// Pads the last dimension of a `[batch, channels, length]` tensor by
/// mirroring the signal about its first and last elements.
///
/// `padding` is `(left, right)`. The edge element itself is not repeated: the
/// position just left of the input copies input index `1`, and the position
/// just right of the input copies index `length - 2`.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` is not 3-dimensional.
/// * `InvalidArgument` if either padding amount is `>= length`.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn reflection_pad1d(input: &Tensor, padding: (usize, usize)) -> Result<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();
    if dims.len() != 3 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0],
            got: dims.to_vec(),
        });
    }
    let (batch, channels, length) = (dims[0], dims[1], dims[2]);
    let (pad_left, pad_right) = padding;
    if pad_left >= length || pad_right >= length {
        return Err(torsh_core::error::TorshError::InvalidArgument(format!(
            "Padding size ({}, {}) cannot be >= input length ({}) for reflection padding",
            pad_left, pad_right, length
        )));
    }
    let new_length = length + pad_left + pad_right;
    let output_shape = vec![batch, channels, new_length];
    let mut output_data = vec![0.0f32; batch * channels * new_length];
    let input_data = input.to_vec()?;

    // The output-to-input position mapping is the same for every row, so build it once.
    let right_edge = pad_left + length;
    let sources: Vec<usize> = (0..new_length)
        .map(|l| {
            if l < pad_left {
                pad_left - l
            } else if l >= right_edge {
                length - 2 - (l - right_edge)
            } else {
                l - pad_left
            }
        })
        .collect();

    for row in 0..batch * channels {
        let src_base = row * length;
        let dst_base = row * new_length;
        for (l, &src_l) in sources.iter().enumerate() {
            output_data[dst_base + l] = input_data[src_base + src_l];
        }
    }
    Tensor::from_vec(output_data, &output_shape)
}
/// Pads the last two dimensions of a `[batch, channels, height, width]`
/// tensor by mirroring the image about its border rows and columns.
///
/// `padding` is `(left, right, top, bottom)`. The border row/column itself is
/// not repeated: the first padded position next to an edge copies the
/// element one step inside that edge.
///
/// # Errors
///
/// * `ShapeMismatch` if `input` is not 4-dimensional.
/// * `InvalidArgument` if a horizontal padding amount is `>= width` or a
///   vertical padding amount is `>= height`.
/// * Any error reported by `Tensor::to_vec` or `Tensor::from_vec`.
pub fn reflection_pad2d(input: &Tensor, padding: (usize, usize, usize, usize)) -> Result<Tensor> {
    /// Maps a coordinate in the padded axis back to the mirrored source coordinate.
    fn mirror(pos: usize, pad_before: usize, size: usize) -> usize {
        if pos < pad_before {
            pad_before - pos
        } else if pos >= pad_before + size {
            size - 2 - (pos - (pad_before + size))
        } else {
            pos - pad_before
        }
    }

    let shape = input.shape();
    let dims = shape.dims();
    if dims.len() != 4 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0, 0],
            got: dims.to_vec(),
        });
    }
    let (batch, channels, height, width) = (dims[0], dims[1], dims[2], dims[3]);
    let (pad_left, pad_right, pad_top, pad_bottom) = padding;
    if pad_left >= width || pad_right >= width {
        return Err(torsh_core::error::TorshError::InvalidArgument(format!(
            "Horizontal padding ({}, {}) cannot be >= input width ({}) for reflection padding",
            pad_left, pad_right, width
        )));
    }
    if pad_top >= height || pad_bottom >= height {
        return Err(torsh_core::error::TorshError::InvalidArgument(format!(
            "Vertical padding ({}, {}) cannot be >= input height ({}) for reflection padding",
            pad_top, pad_bottom, height
        )));
    }
    let new_height = height + pad_top + pad_bottom;
    let new_width = width + pad_left + pad_right;
    let output_shape = vec![batch, channels, new_height, new_width];
    let mut output_data = vec![0.0f32; batch * channels * new_height * new_width];
    let input_data = input.to_vec()?;

    // Row and column mappings are independent of the plane, so build them once.
    let row_sources: Vec<usize> = (0..new_height)
        .map(|h| mirror(h, pad_top, height))
        .collect();
    let col_sources: Vec<usize> = (0..new_width)
        .map(|w| mirror(w, pad_left, width))
        .collect();

    for plane in 0..batch * channels {
        let src_plane = plane * (height * width);
        let dst_plane = plane * (new_height * new_width);
        for (h, &src_h) in row_sources.iter().enumerate() {
            let src_row = src_plane + src_h * width;
            let dst_row = dst_plane + h * new_width;
            for (w, &src_w) in col_sources.iter().enumerate() {
                output_data[dst_row + w] = input_data[src_row + src_w];
            }
        }
    }
    Tensor::from_vec(output_data, &output_shape)
}
pub fn replication_pad1d(input: &Tensor, padding: (usize, usize)) -> Result<Tensor> {
let input_shape_obj = input.shape();
let input_shape = input_shape_obj.dims();
if input_shape.len() != 3 {
return Err(torsh_core::error::TorshError::ShapeMismatch {
expected: vec![0, 0, 0],
got: input_shape.to_vec(),
});
}
let (pad_left, pad_right) = padding;
let [batch, channels, length] = [input_shape[0], input_shape[1], input_shape[2]];
let new_length = length + pad_left + pad_right;
let output_shape = vec![batch, channels, new_length];
let mut output_data = vec![0.0f32; batch * channels * new_length];
let input_data = input.to_vec()?;
for b in 0..batch {
for c in 0..channels {
for l in 0..new_length {
let input_idx = b * (channels * length) + c * length;
let output_idx = b * (channels * new_length) + c * new_length + l;
let src_l = if l < pad_left {
0
} else if l >= pad_left + length {
length - 1
} else {
l - pad_left
};
output_data[output_idx] = input_data[input_idx + src_l];
}
}
}
Tensor::from_vec(output_data, &output_shape)
}
pub fn replication_pad2d(input: &Tensor, padding: (usize, usize, usize, usize)) -> Result<Tensor> {
let input_shape_obj = input.shape();
let input_shape = input_shape_obj.dims();
if input_shape.len() != 4 {
return Err(torsh_core::error::TorshError::ShapeMismatch {
expected: vec![0, 0, 0, 0],
got: input_shape.to_vec(),
});
}
let (pad_left, pad_right, pad_top, pad_bottom) = padding;
let [batch, channels, height, width] = [
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
];
let new_height = height + pad_top + pad_bottom;
let new_width = width + pad_left + pad_right;
let output_shape = vec![batch, channels, new_height, new_width];
let mut output_data = vec![0.0f32; batch * channels * new_height * new_width];
let input_data = input.to_vec()?;
for b in 0..batch {
for c in 0..channels {
for h in 0..new_height {
for w in 0..new_width {
let src_h = if h < pad_top {
0
} else if h >= pad_top + height {
height - 1
} else {
h - pad_top
};
let src_w = if w < pad_left {
0
} else if w >= pad_left + width {
width - 1
} else {
w - pad_left
};
let input_idx = b * (channels * height * width)
+ c * (height * width)
+ src_h * width
+ src_w;
let output_idx = b * (channels * new_height * new_width)
+ c * (new_height * new_width)
+ h * new_width
+ w;
output_data[output_idx] = input_data[input_idx];
}
}
}
}
Tensor::from_vec(output_data, &output_shape)
}
/// Pads the last two dimensions of a 4-D tensor with zeros.
///
/// `padding` is `(pad_left, pad_right, pad_top, pad_bottom)`; left/right apply
/// to the last dimension (width) and top/bottom to the one before it (height).
/// The output buffer starts zero-filled and the input is copied into its
/// interior, offset by `pad_top` rows and `pad_left` columns.
///
/// # Errors
///
/// * `ShapeMismatch` when `input` does not have exactly four dimensions.
/// * Any error returned by `input.to_vec()` or `Tensor::from_vec`.
pub fn zero_pad2d(input: &Tensor, padding: (usize, usize, usize, usize)) -> Result<Tensor> {
    let shape = input.shape();
    let dims = shape.dims();
    if dims.len() != 4 {
        return Err(torsh_core::error::TorshError::ShapeMismatch {
            expected: vec![0, 0, 0, 0],
            got: dims.to_vec(),
        });
    }
    let (batch, channels, height, width) = (dims[0], dims[1], dims[2], dims[3]);
    let (pad_left, pad_right, pad_top, pad_bottom) = padding;
    let new_height = height + pad_top + pad_bottom;
    let new_width = width + pad_left + pad_right;
    let output_shape = vec![batch, channels, new_height, new_width];
    let mut output_data = vec![0.0f32; batch * channels * new_height * new_width];
    let input_data = input.to_vec()?;
    // Batch and channel are flattened into one plane index; each input row is
    // then copied as a contiguous run into the interior of the padded plane.
    for plane in 0..batch * channels {
        let src_plane = plane * (height * width);
        let dst_plane = plane * (new_height * new_width);
        for row in 0..height {
            let src = src_plane + row * width;
            let dst = dst_plane + (row + pad_top) * new_width + pad_left;
            output_data[dst..dst + width].copy_from_slice(&input_data[src..src + width]);
        }
    }
    Tensor::from_vec(output_data, &output_shape)
}
/// Computes the output length of a pooling window sliding along one dimension.
///
/// Evaluates
/// `(input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1`
/// using integer (floor) division.
///
/// Returns `0` when the dilated window is larger than the padded input, so no
/// window position fits. The plain formula would underflow in that case.
///
/// # Panics
///
/// Panics if `stride` or `kernel_size` is zero.
pub fn pool_output_size(
    input_size: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    assert!(stride > 0, "pool_output_size: stride must be greater than zero");
    assert!(
        kernel_size > 0,
        "pool_output_size: kernel_size must be greater than zero"
    );
    let padded = input_size + 2 * padding;
    // Number of input positions covered by one dilated window. An overflow
    // here means the window is unrepresentably large, i.e. it cannot fit.
    let span = dilation
        .checked_mul(kernel_size - 1)
        .and_then(|reach| reach.checked_add(1));
    // `checked_sub` replaces the unsigned underflow of the plain formula.
    match span.and_then(|span| padded.checked_sub(span)) {
        Some(slack) => slack / stride + 1,
        None => 0,
    }
}
/// Derives a `(kernel, stride)` pair for pooling `input_size` elements down to
/// `output_size` elements.
///
/// The stride is `input_size / output_size` (floor division) and the kernel is
/// whatever remains after taking `output_size - 1` strides, so the last window
/// ends exactly at the end of the input.
///
/// # Panics
///
/// Panics with a division-by-zero error if `output_size` is zero.
pub fn adaptive_pool_params(input_size: usize, output_size: usize) -> (usize, usize) {
    let step = input_size / output_size;
    // Distance from the start of the first window to the start of the last.
    let last_window_start = step * (output_size - 1);
    (input_size - last_window_start, step)
}