#![allow(dead_code)]
use crate::{Result, VisionError};
use scirs2_core::legacy::rng;
use scirs2_core::random::Random;
use torsh_tensor::{creation::zeros_mut, Tensor};
use super::common::{utils, InterpolationMode, PaddingMode, VisionOpConfig};
/// Resizes `image` to `size` using bilinear interpolation.
///
/// `size` is `(width, height)` — that is the order in which
/// [`resize_with_mode`] unpacks it. Shape handling and every error condition
/// are those of [`resize_with_mode`], to which this function forwards with
/// `InterpolationMode::Bilinear`.
pub fn resize(image: &Tensor<f32>, size: (usize, usize)) -> Result<Tensor<f32>> {
resize_with_mode(image, size, InterpolationMode::Bilinear)
}
/// Resizes an image to `size` using the requested interpolation `mode`.
///
/// `size` is `(width, height)`. The dimensions of `image` are taken from
/// `utils::validate_image_tensor_flexible`, which reports
/// `(batch, channels, height, width)`.
///
/// * A tensor that is not 4-D is resized directly and the result has shape
///   `[channels, target_height, target_width]`.
/// * A 4-D tensor with a batch of one is resized and returned as
///   `[1, channels, target_height, target_width]`.
///
/// # Errors
///
/// * Whatever `utils::validate_image_tensor_flexible` rejects.
/// * `VisionError::InvalidArgument` for a 4-D tensor whose batch size is
///   greater than one.
/// * `VisionError::TensorError` when reshaping or slicing the tensor fails.
/// * Any error produced by the selected resize kernel.
pub fn resize_with_mode(
    image: &Tensor<f32>,
    size: (usize, usize),
    mode: InterpolationMode,
) -> Result<Tensor<f32>> {
    let (batch_size, channels, height, width) = utils::validate_image_tensor_flexible(image)?;
    let (target_width, target_height) = size;

    // Unbatched input goes straight to the kernel.
    if image.shape().dims().len() != 4 {
        return resize_single_image(
            image,
            mode,
            channels,
            height,
            width,
            target_width,
            target_height,
        );
    }

    if batch_size > 1 {
        return Err(VisionError::InvalidArgument(
            "Batch resize with batch_size > 1 not yet supported. Use single images or loop over batch manually.".to_string(),
        ));
    }

    // Drop the batch axis so the kernels can index with `[c, y, x]`.
    let chw_dims = [channels as i32, height as i32, width as i32];
    let single_image = if batch_size == 1 {
        image.view(&chw_dims)
    } else {
        // Only reachable when the validator reports a batch size of zero; the
        // slice is kept so that such input still goes through the tensor
        // library's own range check before being reshaped.
        image
            .narrow(0, 0, 1)
            .map_err(VisionError::TensorError)?
            .view(&chw_dims)
    };
    let single_image = single_image.map_err(VisionError::TensorError)?;

    let resized = resize_single_image(
        &single_image,
        mode,
        channels,
        height,
        width,
        target_width,
        target_height,
    )?;

    // Restore the leading batch axis of size one.
    resized
        .view(&[
            1i32,
            channels as i32,
            target_height as i32,
            target_width as i32,
        ])
        .map_err(VisionError::TensorError)
}

/// Runs the resize kernel selected by `mode` on one `[c, y, x]`-indexed image.
fn resize_single_image(
    image: &Tensor<f32>,
    mode: InterpolationMode,
    channels: usize,
    height: usize,
    width: usize,
    target_width: usize,
    target_height: usize,
) -> Result<Tensor<f32>> {
    match mode {
        InterpolationMode::Bilinear => {
            resize_bilinear(image, channels, height, width, target_width, target_height)
        }
        InterpolationMode::Nearest => {
            resize_nearest(image, channels, height, width, target_width, target_height)
        }
        InterpolationMode::Bicubic => {
            resize_bicubic(image, channels, height, width, target_width, target_height)
        }
    }
}
/// Resizes an image indexed as `[channel, y, x]` with bilinear interpolation.
///
/// Each output pixel centre is mapped back into the source with the
/// half-pixel formula `(dst + 0.5) * scale - 0.5`. When upsampling, that
/// formula yields coordinates slightly below zero (and slightly above the
/// last index) for border pixels. The coordinate is therefore clamped to
/// `[0, extent - 1]` before the neighbours and weights are derived, so border
/// pixels reuse the edge value instead of being extrapolated from a
/// coordinate that lies outside the image.
///
/// Returns a new tensor of shape `[channels, target_height, target_width]`.
///
/// # Errors
///
/// Propagates any error from reading `image` or writing the output tensor.
fn resize_bilinear(
    image: &Tensor<f32>,
    channels: usize,
    height: usize,
    width: usize,
    target_width: usize,
    target_height: usize,
) -> Result<Tensor<f32>> {
    let output = zeros_mut(&[channels, target_height, target_width]);
    let scale_x = width as f32 / target_width as f32;
    let scale_y = height as f32 / target_height as f32;

    // Highest valid source index per axis; saturating so an empty axis cannot
    // underflow here (reads from such an image then fail inside `get`).
    let max_x = width.saturating_sub(1);
    let max_y = height.saturating_sub(1);

    for c in 0..channels {
        for y in 0..target_height {
            // The vertical source position only depends on the output row.
            let src_y = ((y as f32 + 0.5) * scale_y - 0.5).clamp(0.0, max_y as f32);
            let y1 = src_y.floor() as usize;
            let y2 = (y1 + 1).min(max_y);

            for x in 0..target_width {
                let src_x = ((x as f32 + 0.5) * scale_x - 0.5).clamp(0.0, max_x as f32);
                let x1 = src_x.floor() as usize;
                let x2 = (x1 + 1).min(max_x);

                let (w11, w21, w12, w22) =
                    utils::bilinear_interpolation(src_x, src_y, x1, y1, x2, y2);

                let val11 = image.get(&[c, y1, x1])?;
                let val12 = image.get(&[c, y2, x1])?;
                let val21 = image.get(&[c, y1, x2])?;
                let val22 = image.get(&[c, y2, x2])?;

                let interpolated = val11 * w11 + val21 * w21 + val12 * w12 + val22 * w22;
                output.set(&[c, y, x], interpolated)?;
            }
        }
    }
    Ok(output)
}
/// Resizes an image indexed as `[channel, y, x]` by nearest-neighbour lookup.
///
/// Every output pixel copies the source pixel found at
/// `floor((dst + 0.5) * scale)`, limited to the last valid index of the axis.
/// Returns a new tensor of shape `[channels, target_height, target_width]`.
///
/// # Errors
///
/// Propagates any error from reading `image` or writing the output tensor.
fn resize_nearest(
    image: &Tensor<f32>,
    channels: usize,
    height: usize,
    width: usize,
    target_width: usize,
    target_height: usize,
) -> Result<Tensor<f32>> {
    let resized = zeros_mut(&[channels, target_height, target_width]);
    let x_ratio = width as f32 / target_width as f32;
    let y_ratio = height as f32 / target_height as f32;

    for channel in 0..channels {
        for row in 0..target_height {
            for col in 0..target_width {
                // Map the centre of the destination pixel into the source and
                // keep the index inside the image.
                let nearest_col = (((col as f32 + 0.5) * x_ratio).floor() as usize).min(width - 1);
                let nearest_row = (((row as f32 + 0.5) * y_ratio).floor() as usize).min(height - 1);

                let sample = image.get(&[channel, nearest_row, nearest_col])?;
                resized.set(&[channel, row, col], sample)?;
            }
        }
    }

    Ok(resized)
}
/// Resize entry point used for `InterpolationMode::Bicubic`.
///
/// NOTE: no bicubic kernel is implemented here. The call is forwarded
/// unchanged to [`resize_bilinear`], so the result is a bilinear resize.
fn resize_bicubic(
image: &Tensor<f32>,
channels: usize,
height: usize,
width: usize,
target_width: usize,
target_height: usize,
) -> Result<Tensor<f32>> {
// Placeholder: falls back to bilinear interpolation.
resize_bilinear(image, channels, height, width, target_width, target_height)
}
/// Crops a `size = (width, height)` window from the middle of `image`.
///
/// The image dimensions come from `utils::validate_image_tensor_3d` and the
/// requested size is checked with `utils::validate_crop_size`. The leftover
/// margin on each axis is split with integer division, so an odd margin
/// leaves the extra pixel on the right/bottom side.
///
/// # Errors
///
/// Returns the validation error when either check fails, or any error from
/// [`crop_region`].
pub fn center_crop(image: &Tensor<f32>, size: (usize, usize)) -> Result<Tensor<f32>> {
    let (_, image_height, image_width) = utils::validate_image_tensor_3d(image)?;
    let (crop_width, crop_height) = size;
    utils::validate_crop_size(image_width, image_height, crop_width, crop_height)?;

    crop_region(
        image,
        (image_width - crop_width) / 2,
        (image_height - crop_height) / 2,
        crop_width,
        crop_height,
    )
}
/// Crops a `size = (width, height)` window at a random position of `image`.
///
/// The image dimensions come from `utils::validate_image_tensor_3d` and the
/// requested size is checked with `utils::validate_crop_size`. The top-left
/// corner is drawn independently per axis from `0..=(extent - crop_extent)`,
/// so every placement that keeps the window inside the image can occur,
/// including the one flush with the right/bottom edge.
///
/// # Errors
///
/// Returns the validation error when either check fails, or any error from
/// [`crop_region`].
pub fn random_crop(image: &Tensor<f32>, size: (usize, usize)) -> Result<Tensor<f32>> {
    let (_channels, height, width) = utils::validate_image_tensor_3d(image)?;
    let (target_width, target_height) = size;
    utils::validate_crop_size(width, height, target_width, target_height)?;

    // Largest offsets that still keep the crop window inside the image.
    let max_start_x = width - target_width;
    let max_start_y = height - target_height;

    // `gen_range` is given a half-open range, so the upper bound has to be
    // `max + 1` for `max` itself to be selectable.
    let start_x = if max_start_x > 0 {
        rng().gen_range(0..max_start_x + 1)
    } else {
        0
    };
    let start_y = if max_start_y > 0 {
        rng().gen_range(0..max_start_y + 1)
    } else {
        0
    };

    crop_region(image, start_x, start_y, target_width, target_height)
}
/// Extracts the `width` x `height` window whose top-left corner is at
/// (`start_x`, `start_y`).
///
/// The window is taken by narrowing axis 1 to `height` entries starting at
/// `start_y`, then axis 2 to `width` entries starting at `start_x`.
///
/// # Errors
///
/// Propagates any error returned by the tensor's `narrow` operation.
pub fn crop_region(
    image: &Tensor<f32>,
    start_x: usize,
    start_y: usize,
    width: usize,
    height: usize,
) -> Result<Tensor<f32>> {
    // Select the rows first, then the columns within them.
    let row_band = image.narrow(1, start_y as i64, height)?;
    let window = row_band.narrow(2, start_x as i64, width)?;
    Ok(window)
}
/// Mirrors `image` left-to-right.
///
/// Returns a new tensor of the same `[channels, height, width]` shape in
/// which column `x` holds the source column `width - 1 - x`.
///
/// # Errors
///
/// Returns the error from `utils::validate_image_tensor_3d` for an invalid
/// input, or any error from reading/writing tensor elements.
pub fn horizontal_flip(image: &Tensor<f32>) -> Result<Tensor<f32>> {
    let (channels, height, width) = utils::validate_image_tensor_3d(image)?;
    let mirrored = zeros_mut(&[channels, height, width]);

    for channel in 0..channels {
        for row in 0..height {
            // Walk the source columns right-to-left while the destination
            // column counts up from zero.
            for (dst_col, src_col) in (0..width).rev().enumerate() {
                let pixel = image.get(&[channel, row, src_col])?;
                mirrored.set(&[channel, row, dst_col], pixel)?;
            }
        }
    }

    Ok(mirrored)
}
/// Mirrors `image` top-to-bottom.
///
/// Returns a new tensor of the same `[channels, height, width]` shape in
/// which row `y` holds the source row `height - 1 - y`.
///
/// # Errors
///
/// Returns the error from `utils::validate_image_tensor_3d` for an invalid
/// input, or any error from reading/writing tensor elements.
pub fn vertical_flip(image: &Tensor<f32>) -> Result<Tensor<f32>> {
    let (channels, height, width) = utils::validate_image_tensor_3d(image)?;
    let mirrored = zeros_mut(&[channels, height, width]);

    for channel in 0..channels {
        // Walk the source rows bottom-to-top while the destination row counts
        // up from zero.
        for (dst_row, src_row) in (0..height).rev().enumerate() {
            for col in 0..width {
                let pixel = image.get(&[channel, src_row, col])?;
                mirrored.set(&[channel, dst_row, col], pixel)?;
            }
        }
    }

    Ok(mirrored)
}
/// Rotates `image` by `angle` radians about its centre, keeping its size.
///
/// For every output pixel `(x, y)` the source position is
///
/// ```text
/// src_x = cx + (x - cx) * cos(angle) - (y - cy) * sin(angle)
/// src_y = cy + (x - cx) * sin(angle) + (y - cy) * cos(angle)
/// ```
///
/// and the value is bilinearly interpolated there. The pivot `(cx, cy)` is
/// the midpoint of the pixel index range, `((width - 1) / 2, (height - 1) / 2)`,
/// so the rotation is symmetric (a half turn maps index `i` to
/// `extent - 1 - i` instead of losing a row and a column).
///
/// Output pixels whose source falls outside the image footprint — more than
/// half a pixel beyond the first/last index — keep the zero the output tensor
/// was created with. Positions within that half-pixel rim are clamped onto the
/// border pixel.
///
/// # Errors
///
/// Returns the error from `utils::validate_image_tensor_3d` for an invalid
/// input, or any error from reading/writing tensor elements.
pub fn rotate(image: &Tensor<f32>, angle: f32) -> Result<Tensor<f32>> {
    let (channels, height, width) = utils::validate_image_tensor_3d(image)?;
    let output = zeros_mut(&[channels, height, width]);
    if height == 0 || width == 0 {
        // Nothing to sample from; also keeps `width - 1` below from underflowing.
        return Ok(output);
    }

    let last_x = width - 1;
    let last_y = height - 1;
    let max_x = last_x as f32;
    let max_y = last_y as f32;

    // Pixel `i` is located at coordinate `i`, so the image spans
    // `0..=extent - 1` and its midpoint is `(extent - 1) / 2`.
    let center_x = max_x / 2.0;
    let center_y = max_y / 2.0;
    let cos_angle = angle.cos();
    let sin_angle = angle.sin();

    for c in 0..channels {
        for y in 0..height {
            let dy = y as f32 - center_y;
            for x in 0..width {
                let dx = x as f32 - center_x;
                let src_x = center_x + dx * cos_angle - dy * sin_angle;
                let src_y = center_y + dx * sin_angle + dy * cos_angle;

                // Footprint of the image: half a pixel around every index.
                // This is the same window, relative to the pivot, as
                // `0 <= src < extent` was for a pivot at `extent / 2`.
                let inside = src_x >= -0.5
                    && src_x < max_x + 0.5
                    && src_y >= -0.5
                    && src_y < max_y + 0.5;
                if !inside {
                    continue;
                }

                // Snap rim positions (and float round-off such as -1e-7)
                // onto the valid index range before deriving neighbours.
                let src_x = src_x.clamp(0.0, max_x);
                let src_y = src_y.clamp(0.0, max_y);

                let x1 = src_x.floor() as usize;
                let y1 = src_y.floor() as usize;
                let x2 = (x1 + 1).min(last_x);
                let y2 = (y1 + 1).min(last_y);

                let (w11, w21, w12, w22) =
                    utils::bilinear_interpolation(src_x, src_y, x1, y1, x2, y2);

                let val11 = image.get(&[c, y1, x1])?;
                let val12 = image.get(&[c, y2, x1])?;
                let val21 = image.get(&[c, y1, x2])?;
                let val22 = image.get(&[c, y2, x2])?;

                let interpolated = val11 * w11 + val21 * w21 + val12 * w12 + val22 * w22;
                output.set(&[c, y, x], interpolated)?;
            }
        }
    }
    Ok(output)
}
pub fn pad(
image: &Tensor<f32>,
padding: (usize, usize, usize, usize), mode: PaddingMode,
fill_value: f32,
) -> Result<Tensor<f32>> {
let (channels, height, width) = utils::validate_image_tensor_3d(image)?;
let (pad_left, pad_top, pad_right, pad_bottom) = padding;
let new_width = width + pad_left + pad_right;
let new_height = height + pad_top + pad_bottom;
let mut output = zeros_mut(&[channels, new_height, new_width]);
if mode == PaddingMode::Zero && fill_value != 0.0 {
for c in 0..channels {
for y in 0..new_height {
for x in 0..new_width {
output.set(&[c, y, x], fill_value)?;
}
}
}
}
for c in 0..channels {
for y in 0..height {
for x in 0..width {
let value = image.get(&[c, y, x])?;
output.set(&[c, y + pad_top, x + pad_left], value)?;
}
}
}
match mode {
PaddingMode::Zero => {
}
PaddingMode::Reflect => {
apply_reflect_padding(
&mut output,
channels,
height,
width,
pad_left,
pad_top,
pad_right,
pad_bottom,
)?;
}
PaddingMode::Replicate => {
apply_replicate_padding(
&mut output,
channels,
height,
width,
pad_left,
pad_top,
pad_right,
pad_bottom,
)?;
}
PaddingMode::Circular => {
apply_circular_padding(
&mut output,
channels,
height,
width,
pad_left,
pad_top,
pad_right,
pad_bottom,
)?;
}
}
Ok(output)
}
fn apply_reflect_padding(
output: &mut Tensor<f32>,
channels: usize,
height: usize,
width: usize,
pad_left: usize,
pad_top: usize,
pad_right: usize,
pad_bottom: usize,
) -> Result<()> {
for c in 0..channels {
for y in 0..pad_top {
let src_y = pad_top - y;
for x in pad_left..(pad_left + width) {
let value = output.get(&[c, src_y, x])?;
output.set(&[c, y, x], value)?;
}
}
for y in (pad_top + height)..(pad_top + height + pad_bottom) {
let src_y = 2 * (pad_top + height) - y - 1;
for x in pad_left..(pad_left + width) {
let value = output.get(&[c, src_y, x])?;
output.set(&[c, y, x], value)?;
}
}
for x in 0..pad_left {
let src_x = pad_left - x;
for y in 0..(pad_top + height + pad_bottom) {
let value = output.get(&[c, y, src_x])?;
output.set(&[c, y, x], value)?;
}
}
for x in (pad_left + width)..(pad_left + width + pad_right) {
let src_x = 2 * (pad_left + width) - x - 1;
for y in 0..(pad_top + height + pad_bottom) {
let value = output.get(&[c, y, src_x])?;
output.set(&[c, y, x], value)?;
}
}
}
Ok(())
}
/// Fills the border of `output` by repeating the outermost image pixels.
///
/// `output` must already contain the `height` x `width` image at row offset
/// `pad_top` and column offset `pad_left`. Padding rows copy the first/last
/// image row; padding columns then copy the first/last image column over the
/// full padded height, which also fills the corners.
///
/// # Errors
///
/// Propagates any error from reading/writing tensor elements.
fn apply_replicate_padding(
    output: &mut Tensor<f32>,
    channels: usize,
    height: usize,
    width: usize,
    pad_left: usize,
    pad_top: usize,
    pad_right: usize,
    pad_bottom: usize,
) -> Result<()> {
    // One past the last image row/column, and the full padded height.
    let image_row_end = pad_top + height;
    let image_col_end = pad_left + width;
    let padded_height = image_row_end + pad_bottom;

    for channel in 0..channels {
        // Above the image: repeat the first image row.
        for row in 0..pad_top {
            for col in pad_left..image_col_end {
                let edge = output.get(&[channel, pad_top, col])?;
                output.set(&[channel, row, col], edge)?;
            }
        }

        // Below the image: repeat the last image row.
        for row in image_row_end..padded_height {
            for col in pad_left..image_col_end {
                let edge = output.get(&[channel, image_row_end - 1, col])?;
                output.set(&[channel, row, col], edge)?;
            }
        }

        // Left of the image: repeat the first image column on every row.
        for col in 0..pad_left {
            for row in 0..padded_height {
                let edge = output.get(&[channel, row, pad_left])?;
                output.set(&[channel, row, col], edge)?;
            }
        }

        // Right of the image: repeat the last image column on every row.
        for col in image_col_end..(image_col_end + pad_right) {
            for row in 0..padded_height {
                let edge = output.get(&[channel, row, image_col_end - 1])?;
                output.set(&[channel, row, col], edge)?;
            }
        }
    }

    Ok(())
}
/// Fills the border of `output` by wrapping the image around (tiling).
///
/// `output` must already contain the `height` x `width` image at row offset
/// `pad_top` and column offset `pad_left`. Every padding pixel copies the
/// pixel exactly one image period (`height` rows or `width` columns) further
/// towards the image, so the top border continues with the bottom image rows
/// and the left border with the right image columns, and vice versa.
///
/// Each side is filled starting next to the image and moving outwards. A
/// padding larger than the image therefore reads padding pixels that were
/// written earlier in the same pass, and keeps tiling correctly.
///
/// Rows are filled first; columns are then wrapped over the full padded
/// height, which also fills the four corners.
///
/// # Errors
///
/// Propagates any error from reading/writing tensor elements.
fn apply_circular_padding(
    output: &mut Tensor<f32>,
    channels: usize,
    height: usize,
    width: usize,
    pad_left: usize,
    pad_top: usize,
    pad_right: usize,
    pad_bottom: usize,
) -> Result<()> {
    // First row/column past the image, and the full padded height.
    let row_end = pad_top + height;
    let col_end = pad_left + width;
    let total_rows = row_end + pad_bottom;

    for c in 0..channels {
        // Top: output row y is one period above row y + height. In output
        // coordinates the image starts at pad_top, so the row just above it
        // (pad_top - 1) maps to the last image row (pad_top + height - 1).
        for y in (0..pad_top).rev() {
            let src_y = y + height;
            for x in pad_left..col_end {
                let value = output.get(&[c, src_y, x])?;
                output.set(&[c, y, x], value)?;
            }
        }
        // Bottom: output row y is one period below row y - height.
        for y in row_end..total_rows {
            let src_y = y - height;
            for x in pad_left..col_end {
                let value = output.get(&[c, src_y, x])?;
                output.set(&[c, y, x], value)?;
            }
        }
        // Left: same wrap along x, over every row so corners are covered.
        for x in (0..pad_left).rev() {
            let src_x = x + width;
            for y in 0..total_rows {
                let value = output.get(&[c, y, src_x])?;
                output.set(&[c, y, x], value)?;
            }
        }
        // Right.
        for x in col_end..(col_end + pad_right) {
            let src_x = x - width;
            for y in 0..total_rows {
                let value = output.get(&[c, y, src_x])?;
                output.set(&[c, y, x], value)?;
            }
        }
    }
    Ok(())
}
/// Shape-level smoke tests: each operation runs on an all-ones image and the
/// dimensions of the result are checked.
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::ones;

    /// Upscaling 4x4 to 8x8 keeps the channel count and doubles both axes.
    #[test]
    fn test_resize() {
        let source = ones(&[3, 4, 4]).expect("ones should succeed");
        let upscaled = resize(&source, (8, 8)).expect("operation should succeed");
        assert_eq!(upscaled.shape().dims(), &[3, 8, 8]);
    }

    /// A 6x6 centre crop of a 10x10 image has the requested size.
    #[test]
    fn test_center_crop() {
        let source = ones(&[3, 10, 10]).expect("ones should succeed");
        let window = center_crop(&source, (6, 6)).expect("operation should succeed");
        assert_eq!(window.shape().dims(), &[3, 6, 6]);
    }

    /// Flipping does not change the shape.
    #[test]
    fn test_horizontal_flip() {
        let source = ones(&[3, 4, 4]).expect("ones should succeed");
        let mirrored = horizontal_flip(&source).expect("horizontal flip should succeed");
        assert_eq!(mirrored.shape().dims(), &[3, 4, 4]);
    }

    /// One pixel of padding per side grows each spatial axis by two.
    #[test]
    fn test_padding() {
        let source = ones(&[3, 4, 4]).expect("ones should succeed");
        let framed =
            pad(&source, (1, 1, 1, 1), PaddingMode::Zero, 0.0).expect("operation should succeed");
        assert_eq!(framed.shape().dims(), &[3, 6, 6]);
    }

    /// Rotating by 45 degrees (PI / 4 radians) keeps the original shape.
    #[test]
    fn test_rotation() {
        let source = ones(&[3, 4, 4]).expect("ones should succeed");
        let turned = rotate(&source, std::f32::consts::PI / 4.0).expect("rotate should succeed");
        assert_eq!(turned.shape().dims(), &[3, 4, 4]);
    }
}