use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
use torsh_core::error::Result;
use torsh_tensor::{creation::*, Tensor};
/// Sub-pixel upsampling layer: maps `(N, C * r^2, H, W)` inputs to
/// `(N, C, H * r, W * r)` outputs, where `r` is the upscale factor.
///
/// The layer holds no learnable parameters of its own.
pub struct PixelShuffle {
    /// Shared module state (training flag and parameter map).
    base: ModuleBase,
    /// Factor `r` by which both height and width are enlarged.
    upscale_factor: usize,
}

impl PixelShuffle {
    /// Builds a pixel-shuffle layer that upsamples by `upscale_factor` in each
    /// spatial dimension.
    pub fn new(upscale_factor: usize) -> Self {
        let base = ModuleBase::new();
        Self {
            base,
            upscale_factor,
        }
    }

    /// The configured upscale factor `r`.
    pub fn upscale_factor(&self) -> usize {
        self.upscale_factor
    }
}
impl Module for PixelShuffle {
    /// Rearranges an `(N, C * r^2, H, W)` tensor into `(N, C, H * r, W * r)`, where `r`
    /// is the upscale factor.
    ///
    /// Output element `[b, c, h * r + ry, w * r + rx]` is copied from input element
    /// `[b, c * r^2 + ry * r + rx, h, w]`. The result is created on the input's device.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` if the input is not 4-D, if the upscale
    /// factor is zero, or if the channel count is not divisible by `r^2`. Failures from
    /// `to_vec` / `Tensor::from_data` are propagated unchanged.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let binding = input.shape();
        let input_shape = binding.dims();
        if input_shape.len() != 4 {
            return Err(torsh_core::error::TorshError::InvalidArgument(
                "PixelShuffle expects 4D input (N, C, H, W)".to_string(),
            ));
        }
        let r = self.upscale_factor;
        // `channels_in % r_squared` below would panic on zero; report it as an error.
        if r == 0 {
            return Err(torsh_core::error::TorshError::InvalidArgument(
                "PixelShuffle upscale_factor must be greater than 0".to_string(),
            ));
        }
        let batch_size = input_shape[0];
        let channels_in = input_shape[1];
        let height_in = input_shape[2];
        let width_in = input_shape[3];
        let r_squared = r * r;
        if channels_in % r_squared != 0 {
            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
                "Input channels {} must be divisible by upscale_factor^2 = {}",
                channels_in, r_squared
            )));
        }
        let channels_out = channels_in / r_squared;
        let height_out = height_in * r;
        let width_out = width_in * r;
        let input_data = input.to_vec()?;
        // Gather in output (row-major) order. The operation is a pure permutation, so
        // the output has exactly `input_data.len()` elements and needs no fill value;
        // an empty input yields an empty output instead of panicking on `input_data[0]`.
        let mut output_data = Vec::with_capacity(input_data.len());
        for b in 0..batch_size {
            for c in 0..channels_out {
                for h_out in 0..height_out {
                    let (h, ry) = (h_out / r, h_out % r);
                    for w_out in 0..width_out {
                        let (w, rx) = (w_out / r, w_out % r);
                        let c_in = c * r_squared + ry * r + rx;
                        let in_idx = ((b * channels_in + c_in) * height_in + h) * width_in + w;
                        output_data.push(input_data[in_idx]);
                    }
                }
            }
        }
        let output_shape = [batch_size, channels_out, height_out, width_out];
        Tensor::from_data(output_data, output_shape.to_vec(), input.device())
    }

    /// Returns a clone of the parameter map stored in the module base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn train(&mut self) {
        self.base.set_training(true);
    }

    fn eval(&mut self) {
        self.base.set_training(false);
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// Inverse of pixel shuffle: maps `(N, C, H * r, W * r)` inputs to
/// `(N, C * r^2, H, W)` outputs, where `r` is the downscale factor.
///
/// The layer holds no learnable parameters of its own.
pub struct PixelUnshuffle {
    /// Shared module state (training flag and parameter map).
    base: ModuleBase,
    /// Factor `r` by which both height and width are reduced.
    downscale_factor: usize,
}

impl PixelUnshuffle {
    /// Builds a pixel-unshuffle layer that downsamples by `downscale_factor` in each
    /// spatial dimension.
    pub fn new(downscale_factor: usize) -> Self {
        let base = ModuleBase::new();
        Self {
            base,
            downscale_factor,
        }
    }

    /// The configured downscale factor `r`.
    pub fn downscale_factor(&self) -> usize {
        self.downscale_factor
    }
}
impl Module for PixelUnshuffle {
    /// Rearranges an `(N, C, H, W)` tensor into `(N, C * r^2, H / r, W / r)`, where `r`
    /// is the downscale factor.
    ///
    /// Output element `[b, c * r^2 + ry * r + rx, h, w]` is copied from input element
    /// `[b, c, h * r + ry, w * r + rx]`. The result is created on the input's device.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::InvalidArgument` if the input is not 4-D, if the downscale
    /// factor is zero, or if either spatial dimension is not divisible by it. Failures
    /// from `to_vec` / `Tensor::from_data` are propagated unchanged.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let binding = input.shape();
        let input_shape = binding.dims();
        if input_shape.len() != 4 {
            return Err(torsh_core::error::TorshError::InvalidArgument(
                "PixelUnshuffle expects 4D input (N, C, H, W)".to_string(),
            ));
        }
        let batch_size = input_shape[0];
        let channels_in = input_shape[1];
        let height_in = input_shape[2];
        let width_in = input_shape[3];
        let r = self.downscale_factor;
        // `height_in % r` below would panic on zero; report it as an error.
        if r == 0 {
            return Err(torsh_core::error::TorshError::InvalidArgument(
                "PixelUnshuffle downscale_factor must be greater than 0".to_string(),
            ));
        }
        if height_in % r != 0 || width_in % r != 0 {
            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
                "Input spatial dimensions ({}, {}) must be divisible by downscale_factor {}",
                height_in, width_in, r
            )));
        }
        let r_squared = r * r;
        let channels_out = channels_in * r_squared;
        let height_out = height_in / r;
        let width_out = width_in / r;
        let input_data = input.to_vec()?;
        // Gather in output (row-major) order: a pure permutation, so no fill value is
        // needed and an empty input produces an empty output without panicking.
        let mut output_data = Vec::with_capacity(input_data.len());
        for b in 0..batch_size {
            for c_out in 0..channels_out {
                // Decompose c_out = c * r^2 + ry * r + rx.
                let (c, offset) = (c_out / r_squared, c_out % r_squared);
                let (ry, rx) = (offset / r, offset % r);
                for h in 0..height_out {
                    let h_in = h * r + ry;
                    for w in 0..width_out {
                        let w_in = w * r + rx;
                        let in_idx = ((b * channels_in + c) * height_in + h_in) * width_in + w_in;
                        output_data.push(input_data[in_idx]);
                    }
                }
            }
        }
        let output_shape = [batch_size, channels_out, height_out, width_out];
        Tensor::from_data(output_data, output_shape.to_vec(), input.device())
    }

    /// Returns a clone of the parameter map stored in the module base.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.base.parameters.clone()
    }

    fn training(&self) -> bool {
        self.base.training()
    }

    fn train(&mut self) {
        self.base.set_training(true);
    }

    fn eval(&mut self) {
        self.base.set_training(false);
    }

    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
    }

    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.base.named_parameters()
    }
}
/// One-dimensional pixel shuffle: maps `(N, C * r, L)` inputs to
/// `(N, C, L * r)` outputs, where `r` is the upscale factor.
pub struct PixelShuffle1d {
    /// Shared module state (training flag and parameter map).
    base: ModuleBase,
    /// Factor `r` by which the length dimension is enlarged.
    upscale_factor: usize,
}

impl PixelShuffle1d {
    /// Builds a 1-D pixel-shuffle layer with the given upscale factor.
    pub fn new(upscale_factor: usize) -> Self {
        let base = ModuleBase::new();
        Self {
            base,
            upscale_factor,
        }
    }

    /// The configured upscale factor `r`.
    pub fn upscale_factor(&self) -> usize {
        self.upscale_factor
    }
}
impl Module for PixelShuffle1d {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let binding = input.shape();
let input_shape = binding.dims();
if input_shape.len() != 3 {
return Err(torsh_core::error::TorshError::InvalidArgument(
"PixelShuffle1d expects 3D input (N, C, L)".to_string(),
));
}
let batch_size = input_shape[0];
let channels_in = input_shape[1];
let length_in = input_shape[2];
let r = self.upscale_factor;
if channels_in % r != 0 {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Input channels {} must be divisible by upscale_factor {}",
channels_in, r
)));
}
let channels_out = channels_in / r;
let length_out = length_in * r;
let output_shape = [batch_size, channels_out, length_out];
let output = zeros(&output_shape)?;
Ok(output)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
/// One-dimensional pixel unshuffle: maps `(N, C, L * r)` inputs to
/// `(N, C * r, L)` outputs, where `r` is the downscale factor.
pub struct PixelUnshuffle1d {
    /// Shared module state (training flag and parameter map).
    base: ModuleBase,
    /// Factor `r` by which the length dimension is reduced.
    downscale_factor: usize,
}

impl PixelUnshuffle1d {
    /// Builds a 1-D pixel-unshuffle layer with the given downscale factor.
    pub fn new(downscale_factor: usize) -> Self {
        let base = ModuleBase::new();
        Self {
            base,
            downscale_factor,
        }
    }

    /// The configured downscale factor `r`.
    pub fn downscale_factor(&self) -> usize {
        self.downscale_factor
    }
}
impl Module for PixelUnshuffle1d {
fn forward(&self, input: &Tensor) -> Result<Tensor> {
let binding = input.shape();
let input_shape = binding.dims();
if input_shape.len() != 3 {
return Err(torsh_core::error::TorshError::InvalidArgument(
"PixelUnshuffle1d expects 3D input (N, C, L)".to_string(),
));
}
let batch_size = input_shape[0];
let channels_in = input_shape[1];
let length_in = input_shape[2];
let r = self.downscale_factor;
if length_in % r != 0 {
return Err(torsh_core::error::TorshError::InvalidArgument(format!(
"Input length {} must be divisible by downscale_factor {}",
length_in, r
)));
}
let channels_out = channels_in * r;
let length_out = length_in / r;
let output_shape = [batch_size, channels_out, length_out];
let output = zeros(&output_shape)?;
Ok(output)
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.base.parameters.clone()
}
fn training(&self) -> bool {
self.base.training()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.base.named_parameters()
}
}
// Debug output shows only the scale factor; the module base is intentionally omitted.
impl std::fmt::Debug for PixelShuffle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("PixelShuffle");
        builder.field("upscale_factor", &self.upscale_factor);
        builder.finish()
    }
}

impl std::fmt::Debug for PixelUnshuffle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("PixelUnshuffle");
        builder.field("downscale_factor", &self.downscale_factor);
        builder.finish()
    }
}

impl std::fmt::Debug for PixelShuffle1d {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("PixelShuffle1d");
        builder.field("upscale_factor", &self.upscale_factor);
        builder.finish()
    }
}

impl std::fmt::Debug for PixelUnshuffle1d {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("PixelUnshuffle1d");
        builder.field("downscale_factor", &self.downscale_factor);
        builder.finish()
    }
}
/// Shape helpers for planning pixel (un)shuffle layers without building tensors.
pub mod utils {
    /// Spatial size `(H * r, W * r)` produced by a pixel shuffle with upscale factor
    /// `r` applied to an `(H, W)` input.
    pub fn pixel_shuffle_output_size(
        input_size: (usize, usize),
        upscale_factor: usize,
    ) -> (usize, usize) {
        (input_size.0 * upscale_factor, input_size.1 * upscale_factor)
    }

    /// Spatial size `(H / r, W / r)` produced by a pixel unshuffle with downscale
    /// factor `r` applied to an `(H, W)` input.
    ///
    /// # Errors
    ///
    /// Returns an error message if `downscale_factor` is zero (which would otherwise
    /// panic on the modulo below) or does not evenly divide both dimensions.
    pub fn pixel_unshuffle_output_size(
        input_size: (usize, usize),
        downscale_factor: usize,
    ) -> Result<(usize, usize), String> {
        if downscale_factor == 0 {
            return Err("Downscale factor must be greater than 0".to_string());
        }
        let (height, width) = input_size;
        if height % downscale_factor != 0 || width % downscale_factor != 0 {
            return Err(format!(
                "Input size {:?} must be divisible by downscale factor {}",
                input_size, downscale_factor
            ));
        }
        Ok((height / downscale_factor, width / downscale_factor))
    }

    /// Number of input channels a pixel shuffle needs to produce `output_channels`
    /// channels with upscale factor `r` (i.e. `output_channels * r^2`).
    pub fn pixel_shuffle_input_channels(output_channels: usize, upscale_factor: usize) -> usize {
        output_channels * upscale_factor * upscale_factor
    }

    /// Number of output channels a pixel unshuffle produces from `input_channels`
    /// channels with downscale factor `r` (i.e. `input_channels * r^2`).
    pub fn pixel_unshuffle_output_channels(
        input_channels: usize,
        downscale_factor: usize,
    ) -> usize {
        input_channels * downscale_factor * downscale_factor
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::{ones, zeros};

    #[test]
    fn test_pixel_shuffle_creation() {
        let layer = PixelShuffle::new(2);
        assert_eq!(layer.upscale_factor(), 2);
    }

    #[test]
    fn test_pixel_unshuffle_creation() {
        let layer = PixelUnshuffle::new(2);
        assert_eq!(layer.downscale_factor(), 2);
    }

    #[test]
    fn test_pixel_shuffle_output_shape() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let layer = PixelShuffle::new(2);
        let input = zeros(&[1, 12, 16, 16])?;
        let result = layer.forward(&input);
        assert!(result.is_ok());
        let output = result?;
        let binding = output.shape();
        let output_shape = binding.dims();
        assert_eq!(output_shape, &[1, 3, 32, 32]);
        Ok(())
    }

    #[test]
    fn test_pixel_unshuffle_output_shape() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let layer = PixelUnshuffle::new(2);
        let input = zeros(&[1, 3, 32, 32])?;
        let result = layer.forward(&input);
        assert!(result.is_ok());
        let output = result?;
        let binding = output.shape();
        let output_shape = binding.dims();
        assert_eq!(output_shape, &[1, 12, 16, 16]);
        Ok(())
    }

    #[test]
    fn test_pixel_shuffle_invalid_channels() -> std::result::Result<(), Box<dyn std::error::Error>>
    {
        let layer = PixelShuffle::new(2);
        let input = zeros(&[1, 11, 16, 16])?;
        let result = layer.forward(&input);
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn test_pixel_unshuffle_invalid_dimensions(
    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
        let layer = PixelUnshuffle::new(3);
        let input = zeros(&[1, 3, 16, 17])?;
        let result = layer.forward(&input);
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn test_pixel_shuffle_upscale_factor_3() -> std::result::Result<(), Box<dyn std::error::Error>>
    {
        let layer = PixelShuffle::new(3);
        let input = zeros(&[2, 27, 8, 8])?;
        let output = layer.forward(&input)?;
        let binding = output.shape();
        let output_shape = binding.dims();
        assert_eq!(output_shape, &[2, 3, 24, 24]);
        Ok(())
    }

    #[test]
    fn test_pixel_shuffle_upscale_factor_4() -> std::result::Result<(), Box<dyn std::error::Error>>
    {
        let layer = PixelShuffle::new(4);
        let input = zeros(&[1, 64, 4, 4])?;
        let output = layer.forward(&input)?;
        let binding = output.shape();
        let output_shape = binding.dims();
        assert_eq!(output_shape, &[1, 4, 16, 16]);
        Ok(())
    }

    #[test]
    fn test_pixel_unshuffle_downscale_factor_3(
    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
        let layer = PixelUnshuffle::new(3);
        let input = zeros(&[2, 3, 24, 24])?;
        let output = layer.forward(&input)?;
        let binding = output.shape();
        let output_shape = binding.dims();
        assert_eq!(output_shape, &[2, 27, 8, 8]);
        Ok(())
    }

    #[test]
    fn test_pixel_shuffle_round_trip() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let shuffle = PixelShuffle::new(2);
        let unshuffle = PixelUnshuffle::new(2);
        let input = ones(&[1, 12, 8, 8])?;
        let shuffled = shuffle.forward(&input)?;
        let unshuffled = unshuffle.forward(&shuffled)?;
        let input_binding = input.shape();
        let input_shape = input_binding.dims();
        let output_binding = unshuffled.shape();
        let output_shape = output_binding.dims();
        assert_eq!(input_shape, output_shape);
        Ok(())
    }

    #[test]
    fn test_pixel_shuffle_invalid_spatial_dim(
    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
        let layer = PixelShuffle::new(2);
        let input = zeros(&[1, 12, 16])?;
        let result = layer.forward(&input);
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn test_pixel_unshuffle_invalid_spatial_dim(
    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
        let layer = PixelUnshuffle::new(2);
        let input = zeros(&[1, 3, 32])?;
        let result = layer.forward(&input);
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn test_pixel_shuffle_shape_preservation() -> std::result::Result<(), Box<dyn std::error::Error>>
    {
        let layer = PixelShuffle::new(2);
        let input = zeros(&[1, 12, 16, 16])?;
        let output = layer.forward(&input)?;
        assert_eq!(input.numel(), output.numel());
        Ok(())
    }

    #[test]
    fn test_pixel_unshuffle_shape_preservation(
    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
        let layer = PixelUnshuffle::new(2);
        let input = zeros(&[1, 3, 32, 32])?;
        let output = layer.forward(&input)?;
        assert_eq!(input.numel(), output.numel());
        Ok(())
    }

    #[test]
    fn test_pixel_shuffle_module_trait() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let mut layer = PixelShuffle::new(2);
        assert!(layer.training());
        layer.eval();
        assert!(!layer.training());
        layer.train();
        assert!(layer.training());
        let params = layer.parameters();
        assert!(params.is_empty());
        Ok(())
    }

    #[test]
    fn test_pixel_unshuffle_module_trait() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let mut layer = PixelUnshuffle::new(3);
        assert!(layer.training());
        layer.eval();
        assert!(!layer.training());
        layer.train();
        assert!(layer.training());
        let params = layer.parameters();
        assert!(params.is_empty());
        Ok(())
    }

    #[test]
    fn test_pixel_shuffle_data_correctness() -> std::result::Result<(), Box<dyn std::error::Error>>
    {
        use torsh_core::device::DeviceType;
        use torsh_tensor::Tensor;
        // Channel c holds the 2x2 block [4c, 4c + 1, 4c + 2, 4c + 3].
        let input_data: Vec<f32> = (0..16).map(|x| x as f32).collect();
        let input = Tensor::from_data(input_data, vec![1, 4, 2, 2], DeviceType::Cpu)?;
        let layer = PixelShuffle::new(2);
        let output = layer.forward(&input)?;
        let binding = output.shape();
        let output_shape = binding.dims();
        assert_eq!(output_shape, &[1, 1, 4, 4]);
        // out[0, 0, h * 2 + ry, w * 2 + rx] == in[0, ry * 2 + rx, h, w]
        let expected: Vec<f32> = vec![
            0.0, 4.0, 1.0, 5.0, //
            8.0, 12.0, 9.0, 13.0, //
            2.0, 6.0, 3.0, 7.0, //
            10.0, 14.0, 11.0, 15.0,
        ];
        let output_data = output.to_vec()?;
        assert_eq!(output_data, expected);
        Ok(())
    }

    #[test]
    fn test_pixel_unshuffle_data_correctness() -> std::result::Result<(), Box<dyn std::error::Error>>
    {
        use torsh_core::device::DeviceType;
        use torsh_tensor::Tensor;
        // A single 4x4 plane holding 0..16 in row-major order.
        let input_data: Vec<f32> = (0..16).map(|x| x as f32).collect();
        let input = Tensor::from_data(input_data, vec![1, 1, 4, 4], DeviceType::Cpu)?;
        let layer = PixelUnshuffle::new(2);
        let output = layer.forward(&input)?;
        let binding = output.shape();
        let output_shape = binding.dims();
        assert_eq!(output_shape, &[1, 4, 2, 2]);
        // out[0, ry * 2 + rx, h, w] == in[0, 0, h * 2 + ry, w * 2 + rx]
        let expected: Vec<f32> = vec![
            0.0, 2.0, 8.0, 10.0, //
            1.0, 3.0, 9.0, 11.0, //
            4.0, 6.0, 12.0, 14.0, //
            5.0, 7.0, 13.0, 15.0,
        ];
        let output_data = output.to_vec()?;
        assert_eq!(output_data, expected);
        Ok(())
    }

    #[test]
    fn test_pixel_shuffle_unshuffle_inverse() -> std::result::Result<(), Box<dyn std::error::Error>>
    {
        use torsh_core::device::DeviceType;
        use torsh_tensor::Tensor;
        let input_data: Vec<f32> = (0..48).map(|x| x as f32).collect();
        let input = Tensor::from_data(input_data.clone(), vec![1, 12, 2, 2], DeviceType::Cpu)?;
        let shuffle = PixelShuffle::new(2);
        let unshuffle = PixelUnshuffle::new(2);
        let shuffled = shuffle.forward(&input)?;
        let result = unshuffle.forward(&shuffled)?;
        let result_data = result.to_vec()?;
        let binding = result.shape();
        let result_shape = binding.dims();
        let binding2 = input.shape();
        let input_shape = binding2.dims();
        assert_eq!(result_shape, input_shape);
        assert_eq!(result_data.len(), input_data.len());
        for (a, b) in result_data.iter().zip(input_data.iter()) {
            assert!((a - b).abs() < 1e-6);
        }
        Ok(())
    }

    #[test]
    fn test_utils_functions() {
        use super::utils::*;
        assert_eq!(pixel_shuffle_output_size((16, 16), 2), (32, 32));
        assert_eq!(pixel_unshuffle_output_size((32, 32), 2), Ok((16, 16)));
        assert!(pixel_unshuffle_output_size((17, 16), 2).is_err());
        assert_eq!(pixel_shuffle_input_channels(3, 2), 12);
        assert_eq!(pixel_unshuffle_output_channels(3, 2), 12);
    }
}