use std::sync::Arc;
use ferrotorch_core::autograd::no_grad::is_grad_enabled;
use ferrotorch_core::ops::linalg::{mm, transpose};
use ferrotorch_core::storage::TensorStorage;
use ferrotorch_core::tensor::{GradFn, Tensor};
use ferrotorch_core::{Float, FerrotorchError, FerrotorchResult};
use crate::init::{kaiming_uniform, zeros as zeros_init, NonLinearity};
use crate::module::Module;
use crate::parameter::Parameter;
/// Unfolds sliding kernel windows of an image batch into a column matrix so a
/// convolution can be evaluated as a plain matrix product.
///
/// `input` is read as a contiguous `[batch, channels, height, width]` buffer.
/// For every batch element the result holds a
/// `[channels * kernel_h * kernel_w, h_out * w_out]` matrix. The row index encodes
/// `(channel, ky, kx)` and the column index encodes the output position `(oy, ox)`.
/// Taps that land in the implicit zero padding are left as zero.
///
/// Returns `(cols, col_rows, col_cols)`.
///
/// The caller must guarantee `height + 2 * pad_h >= kernel_h` and
/// `width + 2 * pad_w >= kernel_w`; otherwise the unsigned output-size arithmetic
/// underflows.
fn im2col<T: Float>(
    input: &[T],
    batch: usize,
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> (Vec<T>, usize, usize) {
    let out_h = (height + 2 * pad_h - kernel_h) / stride_h + 1;
    let out_w = (width + 2 * pad_w - kernel_w) / stride_w + 1;
    let rows = channels * kernel_h * kernel_w;
    let positions = out_h * out_w;
    let plane = height * width;
    let mut buf = vec![<T as num_traits::Zero>::zero(); batch * rows * positions];

    for n in 0..batch {
        let src_batch = n * channels * plane;
        for row in 0..rows {
            // Decode the (channel, kernel-row, kernel-col) triple from the flat row.
            let ch = row / (kernel_h * kernel_w);
            let ky = (row / kernel_w) % kernel_h;
            let kx = row % kernel_w;
            let dst_row = (n * rows + row) * positions;
            for oy in 0..out_h {
                // Source row in the unpadded image, or None inside the padding.
                let iy = (oy * stride_h + ky)
                    .checked_sub(pad_h)
                    .filter(|&y| y < height);
                for ox in 0..out_w {
                    let ix = (ox * stride_w + kx)
                        .checked_sub(pad_w)
                        .filter(|&x| x < width);
                    if let (Some(y), Some(x)) = (iy, ix) {
                        buf[dst_row + oy * out_w + ox] =
                            input[src_batch + ch * plane + y * width + x];
                    }
                }
            }
        }
    }
    (buf, rows, positions)
}
/// Folds a column matrix (the layout produced by `im2col`) back into an image.
/// Every column entry is summed into the pixel it was sampled from.
///
/// `cols` is read as `batch` consecutive
/// `[channels * kernel_h * kernel_w, h_out * w_out]` matrices. The result is a
/// contiguous `[batch, channels, height, width]` buffer. Overlapping windows
/// accumulate, and entries that map into the padding are dropped. This is the
/// adjoint of `im2col`, which is why it is used to push column gradients back
/// onto the input.
fn col2im<T: Float>(
    cols: &[T],
    batch: usize,
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
    h_out: usize,
    w_out: usize,
) -> Vec<T> {
    let patch = kernel_h * kernel_w;
    let rows = channels * patch;
    let positions = h_out * w_out;
    let plane = height * width;
    let mut image = vec![<T as num_traits::Zero>::zero(); batch * channels * plane];

    for n in 0..batch {
        for row in 0..rows {
            // Same (channel, ky, kx) decoding as im2col.
            let ch = row / patch;
            let ky = (row / kernel_w) % kernel_h;
            let kx = row % kernel_w;
            let src_row = (n * rows + row) * positions;
            let dst_plane = (n * channels + ch) * plane;
            for oy in 0..h_out {
                let iy = (oy * stride_h + ky)
                    .checked_sub(pad_h)
                    .filter(|&y| y < height);
                if let Some(y) = iy {
                    for ox in 0..w_out {
                        let ix = (ox * stride_w + kx)
                            .checked_sub(pad_w)
                            .filter(|&x| x < width);
                        if let Some(x) = ix {
                            image[dst_plane + y * width + x] += cols[src_row + oy * w_out + ox];
                        }
                    }
                }
            }
        }
    }
    image
}
/// A 2-D convolution layer over `[batch, channels, height, width]` inputs.
///
/// The forward pass computes a cross-correlation: the kernel is applied as
/// stored, not flipped.
#[derive(Debug)]
pub struct Conv2d<T: Float> {
/// Kernel of shape `[out_channels, in_channels, kernel_h, kernel_w]`.
weight: Parameter<T>,
/// Optional per-output-channel bias of shape `[out_channels]`.
bias: Option<Parameter<T>>,
/// Channel count the input's second dimension must match.
in_channels: usize,
/// Channel count of the output's second dimension.
out_channels: usize,
/// Kernel extent as `(height, width)`.
kernel_size: (usize, usize),
/// Step between kernel applications as `(height, width)`.
stride: (usize, usize),
/// Implicit zero padding added on both sides, as `(height, width)`.
padding: (usize, usize),
/// Train/eval mode flag; `forward` does not read it.
training: bool,
}
impl<T: Float> Conv2d<T> {
    /// Builds a 2-D convolution layer.
    ///
    /// The weight has shape `[out_channels, in_channels, kernel_h, kernel_w]` and is
    /// filled with Kaiming-uniform values (ReLU gain). When `bias` is true, a
    /// zero-initialised `[out_channels]` bias is added.
    ///
    /// # Errors
    /// `InvalidArgument` if a channel count, a kernel dimension, or a stride
    /// dimension is zero. Parameter allocation/initialisation errors are propagated.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
        bias: bool,
    ) -> FerrotorchResult<Self> {
        // Checked in order; the first failing rule determines the error message.
        let rules = [
            (
                in_channels == 0 || out_channels == 0,
                "in_channels and out_channels must be > 0",
            ),
            (
                kernel_size.0 == 0 || kernel_size.1 == 0,
                "kernel_size must be > 0 in both dimensions",
            ),
            (
                stride.0 == 0 || stride.1 == 0,
                "stride must be > 0 in both dimensions",
            ),
        ];
        if let Some(&(_, msg)) = rules.iter().find(|(violated, _)| *violated) {
            return Err(FerrotorchError::InvalidArgument {
                message: msg.into(),
            });
        }

        let (kernel_h, kernel_w) = kernel_size;
        let mut weight = Parameter::zeros(&[out_channels, in_channels, kernel_h, kernel_w])?;
        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;

        let bias = match bias {
            true => {
                let mut b = Parameter::zeros(&[out_channels])?;
                zeros_init(&mut b)?;
                Some(b)
            }
            false => None,
        };

        Ok(Self {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            training: true,
        })
    }

    /// Total number of scalar parameters (weight elements plus bias elements).
    pub fn num_parameters(&self) -> usize {
        let (kernel_h, kernel_w) = self.kernel_size;
        let weight_count = self.out_channels * self.in_channels * kernel_h * kernel_w;
        let bias_count = self.bias.as_ref().map_or(0, |_| self.out_channels);
        weight_count + bias_count
    }
}
impl<T: Float> Module<T> for Conv2d<T> {
    /// Applies the 2-D convolution (cross-correlation) to a `[B, C_in, H, W]` input.
    ///
    /// The result has shape `[B, out_channels, h_out, w_out]` where
    /// `h_out = (H + 2 * pad_h - kernel_h) / stride_h + 1` (likewise for `w_out`).
    /// Each batch element is lowered with `im2col` to a
    /// `[in_channels * kernel_h * kernel_w, h_out * w_out]` matrix and multiplied by
    /// the weight viewed as `[out_channels, in_channels * kernel_h * kernel_w]`. The
    /// bias, if any, is then added at every spatial position.
    ///
    /// When gradient tracking is enabled and the input, weight, or bias requires a
    /// gradient, the output is attached to a `Conv2dBackward` node that keeps the
    /// im2col buffer for the weight gradient.
    ///
    /// # Errors
    /// * `InvalidArgument` if the input is not 4-D or the padded input is smaller
    ///   than the kernel.
    /// * `ShapeMismatch` if the channel dimension is not `in_channels`.
    /// * Errors from reading tensor data or from the matrix multiply are propagated.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        if input.ndim() != 4 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv2d expects 4-D input [B, C, H, W], got {:?}",
                    input.shape()
                ),
            });
        }
        let batch = input.shape()[0];
        let c_in = input.shape()[1];
        let h = input.shape()[2];
        let w = input.shape()[3];
        if c_in != self.in_channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "Conv2d: expected {} input channels, got {}",
                    self.in_channels, c_in
                ),
            });
        }
        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;
        let h_padded = h + 2 * ph;
        let w_padded = w + 2 * pw;
        // Guards the unsigned output-size arithmetic below and inside im2col.
        if h_padded < kh || w_padded < kw {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv2d: padded input ({h_padded}, {w_padded}) is smaller than kernel ({kh}, {kw})"
                ),
            });
        }
        let h_out = (h_padded - kh) / sh + 1;
        let w_out = (w_padded - kw) / sw + 1;
        let out_plane = h_out * w_out;
        let sample_len = self.out_channels * out_plane;

        let input_data = input.data()?;
        let (cols, col_rows, col_cols) =
            im2col(input_data, batch, c_in, h, w, kh, kw, sh, sw, ph, pw);

        // The row-major [out, in, kh, kw] weight is already [out, in * kh * kw],
        // with columns in the same (channel, ky, kx) order as the im2col rows.
        let weight_data = self.weight.data()?;
        let weight_2d = Tensor::from_storage(
            TensorStorage::cpu(weight_data.to_vec()),
            vec![self.out_channels, col_rows],
            false,
        )?;

        let zero = <T as num_traits::Zero>::zero();
        let mut output = vec![zero; batch * sample_len];
        for b in 0..batch {
            let col_start = b * col_rows * col_cols;
            let col_end = col_start + col_rows * col_cols;
            let cols_b = Tensor::from_storage(
                TensorStorage::cpu(cols[col_start..col_end].to_vec()),
                vec![col_rows, col_cols],
                false,
            )?;
            let out_b = mm(&weight_2d, &cols_b)?;
            let out_start = b * sample_len;
            output[out_start..out_start + sample_len].copy_from_slice(out_b.data()?);
        }

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data()?;
            for b in 0..batch {
                for c in 0..self.out_channels {
                    let bval = bias_data[c];
                    let start = b * sample_len + c * out_plane;
                    for v in &mut output[start..start + out_plane] {
                        *v += bval;
                    }
                }
            }
        }

        let shape = vec![batch, self.out_channels, h_out, w_out];
        let needs_grad = is_grad_enabled()
            && (input.requires_grad()
                || self.weight.requires_grad()
                || self.bias.as_ref().is_some_and(|b| b.requires_grad()));
        if !needs_grad {
            return Tensor::from_storage(TensorStorage::cpu(output), shape, false);
        }

        let grad_fn = Arc::new(Conv2dBackward {
            input: input.clone(),
            weight: self.weight.tensor().clone(),
            bias: self.bias.as_ref().map(|b| b.tensor().clone()),
            in_channels: self.in_channels,
            out_channels: self.out_channels,
            kernel_size: self.kernel_size,
            stride: self.stride,
            padding: self.padding,
            cols,
            col_rows,
            col_cols,
            h_out,
            w_out,
        });
        // Move the computed buffer straight into the graph node instead of first
        // building a detached tensor and copying its whole data a second time.
        Tensor::from_operation(TensorStorage::cpu(output), shape, grad_fn)
    }

    /// Trainable parameters: the weight, followed by the bias if present.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params = vec![&self.weight];
        if let Some(ref b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Mutable access to the weight, followed by the bias if present.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params = vec![&mut self.weight];
        if let Some(ref mut b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Parameters keyed as `"weight"` and, if present, `"bias"`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut params = vec![("weight".to_string(), &self.weight)];
        if let Some(ref b) = self.bias {
            params.push(("bias".to_string(), b));
        }
        params
    }

    /// Switches to training mode.
    fn train(&mut self) {
        self.training = true;
    }

    /// Switches to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Whether the layer is in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Autograd node recorded by `Conv2d::forward`; stores everything the backward
/// pass needs.
#[derive(Debug)]
struct Conv2dBackward<T: Float> {
/// Forward input, `[batch, in_channels, h, w]`.
input: Tensor<T>,
/// Layer weight, `[out_channels, in_channels, kh, kw]`.
weight: Tensor<T>,
/// Layer bias, if the layer has one.
bias: Option<Tensor<T>>,
in_channels: usize,
out_channels: usize,
/// Kernel extent `(kh, kw)` used in the forward pass.
kernel_size: (usize, usize),
/// Stride `(sh, sw)` used in the forward pass.
stride: (usize, usize),
/// Padding `(ph, pw)` used in the forward pass.
padding: (usize, usize),
/// Forward im2col buffer: `batch` consecutive `[col_rows, col_cols]` matrices,
/// reused to compute the weight gradient.
cols: Vec<T>,
/// `in_channels * kh * kw`.
col_rows: usize,
/// `h_out * w_out`.
col_cols: usize,
/// Output height.
h_out: usize,
/// Output width.
w_out: usize,
}
impl<T: Float> GradFn<T> for Conv2dBackward<T> {
/// Computes the gradients of the convolution w.r.t. input, weight and bias.
///
/// `grad_output` is indexed as a contiguous `[batch, out_channels, h_out, w_out]`
/// buffer. Returns `[grad_input, grad_weight, grad_bias]`; each entry is `None`
/// when the corresponding tensor does not require a gradient.
///
/// NOTE(review): the vector always has three entries, while `inputs()` returns
/// only two when there is no bias. Confirm the autograd engine tolerates this.
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let go_data = grad_output.data()?;
let batch = self.input.shape()[0];
let h = self.input.shape()[2];
let w = self.input.shape()[3];
let (kh, kw) = self.kernel_size;
let (sh, sw) = self.stride;
let (ph, pw) = self.padding;
// dL/dW = sum over batch of dY_b [out, h_out*w_out] @ cols_b^T [h_out*w_out, in*kh*kw].
let grad_weight = if self.weight.requires_grad() {
let zero = <T as num_traits::Zero>::zero();
let weight_numel = self.out_channels * self.col_rows;
let mut gw_accum = vec![zero; weight_numel];
for b in 0..batch {
let go_start = b * self.out_channels * self.h_out * self.w_out;
let go_end = go_start + self.out_channels * self.h_out * self.w_out;
let go_b = Tensor::from_storage(
TensorStorage::cpu(go_data[go_start..go_end].to_vec()),
vec![self.out_channels, self.h_out * self.w_out],
false,
)?;
let col_start = b * self.col_rows * self.col_cols;
let col_end = col_start + self.col_rows * self.col_cols;
let cols_b = Tensor::from_storage(
TensorStorage::cpu(self.cols[col_start..col_end].to_vec()),
vec![self.col_rows, self.col_cols],
false,
)?;
let cols_bt = transpose(&cols_b)?;
let gw_b = mm(&go_b, &cols_bt)?;
let gw_data = gw_b.data()?;
// Accumulate this batch element's contribution.
for i in 0..weight_numel {
gw_accum[i] += gw_data[i];
}
}
// Flat [out, in*kh*kw] accumulator reinterpreted as the 4-D weight shape.
Some(Tensor::from_storage(
TensorStorage::cpu(gw_accum),
vec![
self.out_channels,
self.in_channels,
kh,
kw,
],
false,
)?)
} else {
None
};
// dL/db = dY summed over batch and all spatial positions.
let grad_bias = match &self.bias {
Some(b) if b.requires_grad() => {
let zero = <T as num_traits::Zero>::zero();
let mut gb = vec![zero; self.out_channels];
for batch_idx in 0..batch {
for c in 0..self.out_channels {
for hw in 0..(self.h_out * self.w_out) {
gb[c] += go_data[batch_idx * self.out_channels * self.h_out * self.w_out
+ c * self.h_out * self.w_out
+ hw];
}
}
}
Some(Tensor::from_storage(
TensorStorage::cpu(gb),
vec![self.out_channels],
false,
)?)
}
_ => None,
};
// dL/dX: W^T [in*kh*kw, out] @ dY_b gives per-column gradients, which col2im
// scatters back onto the input image, summing where windows overlap.
let grad_input = if self.input.requires_grad() {
let weight_data = self.weight.data()?;
let weight_2d = Tensor::from_storage(
TensorStorage::cpu(weight_data.to_vec()),
vec![self.out_channels, self.col_rows],
false,
)?;
let weight_2d_t = transpose(&weight_2d)?;
let zero = <T as num_traits::Zero>::zero();
let mut grad_cols = vec![zero; batch * self.col_rows * self.col_cols];
for b in 0..batch {
let go_start = b * self.out_channels * self.h_out * self.w_out;
let go_end = go_start + self.out_channels * self.h_out * self.w_out;
let go_b = Tensor::from_storage(
TensorStorage::cpu(go_data[go_start..go_end].to_vec()),
vec![self.out_channels, self.h_out * self.w_out],
false,
)?;
let gc_b = mm(&weight_2d_t, &go_b)?;
let gc_data = gc_b.data()?;
let gc_start = b * self.col_rows * self.col_cols;
grad_cols[gc_start..gc_start + self.col_rows * self.col_cols]
.copy_from_slice(gc_data);
}
let gi = col2im(
&grad_cols,
batch,
self.in_channels,
h,
w,
kh,
kw,
sh,
sw,
ph,
pw,
self.h_out,
self.w_out,
);
Some(Tensor::from_storage(
TensorStorage::cpu(gi),
self.input.shape().to_vec(),
false,
)?)
} else {
None
};
Ok(vec![grad_input, grad_weight, grad_bias])
}
/// Tensors this node differentiates: input, weight, then bias (if present).
fn inputs(&self) -> Vec<&Tensor<T>> {
let mut v = vec![&self.input, &self.weight];
if let Some(ref b) = self.bias {
v.push(b);
}
v
}
/// Static identifier of this autograd node.
fn name(&self) -> &'static str {
"Conv2dBackward"
}
}
/// A 1-D convolution layer over `[batch, channels, length]` inputs.
///
/// The forward pass computes a cross-correlation (the kernel is not flipped).
#[derive(Debug)]
pub struct Conv1d<T: Float> {
/// Kernel of shape `[out_channels, in_channels, kernel_size]`.
weight: Parameter<T>,
/// Optional per-output-channel bias of shape `[out_channels]`.
bias: Option<Parameter<T>>,
/// Channel count the input's second dimension must match.
in_channels: usize,
/// Channel count of the output's second dimension.
out_channels: usize,
/// Kernel length.
kernel_size: usize,
/// Step between kernel applications.
stride: usize,
/// Implicit zero padding added on both ends of the sequence.
padding: usize,
/// Train/eval mode flag; `forward` does not read it.
training: bool,
}
impl<T: Float> Conv1d<T> {
    /// Builds a 1-D convolution layer.
    ///
    /// The weight has shape `[out_channels, in_channels, kernel_size]` and is filled
    /// with Kaiming-uniform values (ReLU gain). When `bias` is true, a
    /// zero-initialised `[out_channels]` bias is added.
    ///
    /// # Errors
    /// `InvalidArgument` if a channel count, `kernel_size`, or `stride` is zero.
    /// Parameter allocation/initialisation errors are propagated.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        let failure = if in_channels == 0 || out_channels == 0 {
            Some("in_channels and out_channels must be > 0")
        } else if kernel_size == 0 {
            Some("kernel_size must be > 0")
        } else if stride == 0 {
            Some("stride must be > 0")
        } else {
            None
        };
        if let Some(msg) = failure {
            return Err(FerrotorchError::InvalidArgument {
                message: msg.into(),
            });
        }

        let mut weight = Parameter::zeros(&[out_channels, in_channels, kernel_size])?;
        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;

        let bias = if !bias {
            None
        } else {
            let mut b = Parameter::zeros(&[out_channels])?;
            zeros_init(&mut b)?;
            Some(b)
        };

        Ok(Self {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            training: true,
        })
    }

    /// Total number of scalar parameters (weight elements plus bias elements).
    pub fn num_parameters(&self) -> usize {
        let weight_count = self.out_channels * self.in_channels * self.kernel_size;
        match self.bias {
            Some(_) => weight_count + self.out_channels,
            None => weight_count + 0,
        }
    }
}
impl<T: Float> Module<T> for Conv1d<T> {
    /// Applies the 1-D convolution (cross-correlation) to a `[B, C_in, L]` input.
    ///
    /// The result has shape `[B, out_channels, l_out]` with
    /// `l_out = (L + 2 * padding - kernel_size) / stride + 1`. The sequence is
    /// treated as a `1 x L` image so the 2-D `im2col` lowering can be reused: each
    /// batch element becomes a `[in_channels * kernel_size, l_out]` matrix, which
    /// is multiplied by the weight viewed as `[out_channels, in_channels * kernel_size]`.
    ///
    /// When gradient tracking is enabled and the input, weight, or bias requires a
    /// gradient, the output is attached to a `Conv1dBackward` node.
    ///
    /// # Errors
    /// * `InvalidArgument` if the input is not 3-D or the padded length is shorter
    ///   than the kernel.
    /// * `ShapeMismatch` if the channel dimension is not `in_channels`.
    /// * Errors from reading tensor data or from the matrix multiply are propagated.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        if input.ndim() != 3 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv1d expects 3-D input [B, C, L], got {:?}",
                    input.shape()
                ),
            });
        }
        let batch = input.shape()[0];
        let c_in = input.shape()[1];
        let length = input.shape()[2];
        if c_in != self.in_channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "Conv1d: expected {} input channels, got {}",
                    self.in_channels, c_in
                ),
            });
        }
        let k = self.kernel_size;
        let s = self.stride;
        let p = self.padding;
        let l_padded = length + 2 * p;
        // Guards the unsigned output-size arithmetic below and inside im2col.
        if l_padded < k {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv1d: padded input length ({l_padded}) is smaller than kernel ({k})"
                ),
            });
        }
        let l_out = (l_padded - k) / s + 1;
        let sample_len = self.out_channels * l_out;

        let input_data = input.data()?;
        // Height 1, kernel height 1, vertical stride 1, no vertical padding.
        let (cols, col_rows, col_cols) =
            im2col(input_data, batch, c_in, 1, length, 1, k, 1, s, 0, p);

        let weight_data = self.weight.data()?;
        let weight_2d = Tensor::from_storage(
            TensorStorage::cpu(weight_data.to_vec()),
            vec![self.out_channels, col_rows],
            false,
        )?;

        let zero = <T as num_traits::Zero>::zero();
        let mut output = vec![zero; batch * sample_len];
        for b in 0..batch {
            let col_start = b * col_rows * col_cols;
            let col_end = col_start + col_rows * col_cols;
            let cols_b = Tensor::from_storage(
                TensorStorage::cpu(cols[col_start..col_end].to_vec()),
                vec![col_rows, col_cols],
                false,
            )?;
            let out_b = mm(&weight_2d, &cols_b)?;
            let out_start = b * sample_len;
            output[out_start..out_start + sample_len].copy_from_slice(out_b.data()?);
        }

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data()?;
            for b in 0..batch {
                for c in 0..self.out_channels {
                    let bval = bias_data[c];
                    let start = b * sample_len + c * l_out;
                    for v in &mut output[start..start + l_out] {
                        *v += bval;
                    }
                }
            }
        }

        let shape = vec![batch, self.out_channels, l_out];
        let needs_grad = is_grad_enabled()
            && (input.requires_grad()
                || self.weight.requires_grad()
                || self.bias.as_ref().is_some_and(|b| b.requires_grad()));
        if !needs_grad {
            return Tensor::from_storage(TensorStorage::cpu(output), shape, false);
        }

        let grad_fn = Arc::new(Conv1dBackward {
            input: input.clone(),
            weight: self.weight.tensor().clone(),
            bias: self.bias.as_ref().map(|b| b.tensor().clone()),
            in_channels: self.in_channels,
            out_channels: self.out_channels,
            kernel_size: self.kernel_size,
            stride: self.stride,
            padding: self.padding,
            cols,
            col_rows,
            col_cols,
            l_out,
        });
        // Move the computed buffer straight into the graph node instead of first
        // building a detached tensor and copying its whole data a second time.
        Tensor::from_operation(TensorStorage::cpu(output), shape, grad_fn)
    }

    /// Trainable parameters: the weight, followed by the bias if present.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params = vec![&self.weight];
        if let Some(ref b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Mutable access to the weight, followed by the bias if present.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params = vec![&mut self.weight];
        if let Some(ref mut b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Parameters keyed as `"weight"` and, if present, `"bias"`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut params = vec![("weight".to_string(), &self.weight)];
        if let Some(ref b) = self.bias {
            params.push(("bias".to_string(), b));
        }
        params
    }

    /// Switches to training mode.
    fn train(&mut self) {
        self.training = true;
    }

    /// Switches to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Whether the layer is in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Autograd node recorded by `Conv1d::forward`; stores everything the backward
/// pass needs.
#[derive(Debug)]
struct Conv1dBackward<T: Float> {
/// Forward input, `[batch, in_channels, length]`.
input: Tensor<T>,
/// Layer weight, `[out_channels, in_channels, kernel_size]`.
weight: Tensor<T>,
/// Layer bias, if the layer has one.
bias: Option<Tensor<T>>,
in_channels: usize,
out_channels: usize,
kernel_size: usize,
stride: usize,
padding: usize,
/// Forward im2col buffer: `batch` consecutive `[col_rows, col_cols]` matrices.
cols: Vec<T>,
/// `in_channels * kernel_size`.
col_rows: usize,
/// Equals `l_out` (the forward pass uses a 1-row image).
col_cols: usize,
/// Output length.
l_out: usize,
}
impl<T: Float> GradFn<T> for Conv1dBackward<T> {
/// Computes the gradients of the 1-D convolution w.r.t. input, weight and bias.
///
/// `grad_output` is indexed as a contiguous `[batch, out_channels, l_out]`
/// buffer. Returns `[grad_input, grad_weight, grad_bias]`; each entry is `None`
/// when the corresponding tensor does not require a gradient.
///
/// NOTE(review): the vector always has three entries, while `inputs()` returns
/// only two when there is no bias. Confirm the autograd engine tolerates this.
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let go_data = grad_output.data()?;
let batch = self.input.shape()[0];
let length = self.input.shape()[2];
let k = self.kernel_size;
let s = self.stride;
let p = self.padding;
// dL/dW = sum over batch of dY_b [out, l_out] @ cols_b^T [l_out, in*k].
let grad_weight = if self.weight.requires_grad() {
let zero = <T as num_traits::Zero>::zero();
let weight_numel = self.out_channels * self.col_rows;
let mut gw_accum = vec![zero; weight_numel];
for b in 0..batch {
let go_start = b * self.out_channels * self.l_out;
let go_end = go_start + self.out_channels * self.l_out;
let go_b = Tensor::from_storage(
TensorStorage::cpu(go_data[go_start..go_end].to_vec()),
vec![self.out_channels, self.l_out],
false,
)?;
let col_start = b * self.col_rows * self.col_cols;
let col_end = col_start + self.col_rows * self.col_cols;
let cols_b = Tensor::from_storage(
TensorStorage::cpu(self.cols[col_start..col_end].to_vec()),
vec![self.col_rows, self.col_cols],
false,
)?;
let cols_bt = transpose(&cols_b)?;
let gw_b = mm(&go_b, &cols_bt)?;
let gw_data = gw_b.data()?;
// Accumulate this batch element's contribution.
for i in 0..weight_numel {
gw_accum[i] += gw_data[i];
}
}
Some(Tensor::from_storage(
TensorStorage::cpu(gw_accum),
vec![self.out_channels, self.in_channels, k],
false,
)?)
} else {
None
};
// dL/db = dY summed over batch and sequence positions.
let grad_bias = match &self.bias {
Some(b) if b.requires_grad() => {
let zero = <T as num_traits::Zero>::zero();
let mut gb = vec![zero; self.out_channels];
for batch_idx in 0..batch {
for c in 0..self.out_channels {
for l in 0..self.l_out {
gb[c] += go_data[batch_idx * self.out_channels * self.l_out
+ c * self.l_out
+ l];
}
}
}
Some(Tensor::from_storage(
TensorStorage::cpu(gb),
vec![self.out_channels],
false,
)?)
}
_ => None,
};
// dL/dX: W^T @ dY_b gives per-column gradients; col2im folds them back onto the
// sequence using the same 1-row-image geometry as the forward im2col call.
let grad_input = if self.input.requires_grad() {
let weight_data = self.weight.data()?;
let weight_2d = Tensor::from_storage(
TensorStorage::cpu(weight_data.to_vec()),
vec![self.out_channels, self.col_rows],
false,
)?;
let weight_2d_t = transpose(&weight_2d)?;
let zero = <T as num_traits::Zero>::zero();
let mut grad_cols = vec![zero; batch * self.col_rows * self.col_cols];
for b in 0..batch {
let go_start = b * self.out_channels * self.l_out;
let go_end = go_start + self.out_channels * self.l_out;
let go_b = Tensor::from_storage(
TensorStorage::cpu(go_data[go_start..go_end].to_vec()),
vec![self.out_channels, self.l_out],
false,
)?;
let gc_b = mm(&weight_2d_t, &go_b)?;
let gc_data = gc_b.data()?;
let gc_start = b * self.col_rows * self.col_cols;
grad_cols[gc_start..gc_start + self.col_rows * self.col_cols]
.copy_from_slice(gc_data);
}
// height=1, kernel_h=1, stride_h=1, pad_h=0, h_out=1 mirror the forward call.
let gi = col2im(
&grad_cols,
batch,
self.in_channels,
1,
length,
1,
k,
1,
s,
0,
p,
1,
self.l_out,
);
Some(Tensor::from_storage(
TensorStorage::cpu(gi),
self.input.shape().to_vec(),
false,
)?)
} else {
None
};
Ok(vec![grad_input, grad_weight, grad_bias])
}
/// Tensors this node differentiates: input, weight, then bias (if present).
fn inputs(&self) -> Vec<&Tensor<T>> {
let mut v = vec![&self.input, &self.weight];
if let Some(ref b) = self.bias {
v.push(b);
}
v
}
/// Static identifier of this autograd node.
fn name(&self) -> &'static str {
"Conv1dBackward"
}
}
/// A 2-D transposed convolution layer over `[batch, channels, height, width]`
/// inputs.
#[derive(Debug)]
pub struct ConvTranspose2d<T: Float> {
/// Kernel of shape `[in_channels, out_channels, kernel_h, kernel_w]`. Note that
/// the channel axes are in the opposite order from `Conv2d`.
weight: Parameter<T>,
/// Optional per-output-channel bias of shape `[out_channels]`.
bias: Option<Parameter<T>>,
/// Channel count the input's second dimension must match.
in_channels: usize,
/// Channel count of the output's second dimension.
out_channels: usize,
/// Kernel extent as `(height, width)`.
kernel_size: (usize, usize),
/// Upsampling factor as `(height, width)`.
stride: (usize, usize),
/// Amount trimmed from each side of the output, as `(height, width)`.
padding: (usize, usize),
/// Extra size added to the bottom/right of the output; `new` requires it to be
/// strictly less than `stride`.
output_padding: (usize, usize),
/// Train/eval mode flag; `forward` does not read it.
training: bool,
}
impl<T: Float> ConvTranspose2d<T> {
    /// Builds a 2-D transposed convolution layer.
    ///
    /// The weight has shape `[in_channels, out_channels, kernel_h, kernel_w]` and is
    /// filled with Kaiming-uniform values (ReLU gain). When `bias` is true, a
    /// zero-initialised `[out_channels]` bias is added.
    ///
    /// # Errors
    /// `InvalidArgument` if a channel count, a kernel dimension, or a stride
    /// dimension is zero, or if `output_padding` is not strictly smaller than
    /// `stride` in both dimensions. Parameter allocation/initialisation errors are
    /// propagated.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
        output_padding: (usize, usize),
        bias: bool,
    ) -> FerrotorchResult<Self> {
        // Evaluated in order; the first violated rule is reported.
        let rules: [(bool, &str); 4] = [
            (
                in_channels == 0 || out_channels == 0,
                "in_channels and out_channels must be > 0",
            ),
            (
                kernel_size.0 == 0 || kernel_size.1 == 0,
                "kernel_size must be > 0 in both dimensions",
            ),
            (
                stride.0 == 0 || stride.1 == 0,
                "stride must be > 0 in both dimensions",
            ),
            (
                output_padding.0 >= stride.0 || output_padding.1 >= stride.1,
                "output_padding must be strictly less than stride",
            ),
        ];
        for (violated, msg) in rules {
            if violated {
                return Err(FerrotorchError::InvalidArgument {
                    message: msg.into(),
                });
            }
        }

        let (kernel_h, kernel_w) = kernel_size;
        let mut weight = Parameter::zeros(&[in_channels, out_channels, kernel_h, kernel_w])?;
        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;

        let bias = match bias {
            false => None,
            true => {
                let mut b = Parameter::zeros(&[out_channels])?;
                zeros_init(&mut b)?;
                Some(b)
            }
        };

        Ok(Self {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            training: true,
        })
    }

    /// Total number of scalar parameters (weight elements plus bias elements).
    pub fn num_parameters(&self) -> usize {
        let (kernel_h, kernel_w) = self.kernel_size;
        let weights = self.in_channels * self.out_channels * kernel_h * kernel_w;
        let biases = if self.bias.is_none() { 0 } else { self.out_channels };
        weights + biases
    }
}
/// Zero-upsamples an image batch by the stride: neighbouring pixels are pushed
/// apart with `stride - 1` zeros between them along each spatial axis.
///
/// `input` is read as a contiguous `[batch, channels, h, w]` buffer. Pixel
/// `(ih, iw)` of each plane is written to `(ih * stride_h, iw * stride_w)` of a
/// `[batch, channels, h_up, w_up]` result, where `h_up = (h - 1) * stride_h + 1`
/// (likewise for width).
///
/// Returns `(upsampled, h_up, w_up)`. `h` and `w` must be non-zero, otherwise
/// `h - 1` underflows.
fn stride_insert_zeros<T: Float>(
    input: &[T],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    stride_h: usize,
    stride_w: usize,
) -> (Vec<T>, usize, usize) {
    let h_up = (h - 1) * stride_h + 1;
    let w_up = (w - 1) * stride_w + 1;
    let src_plane = h * w;
    let dst_plane = h_up * w_up;
    let mut upsampled = vec![<T as num_traits::Zero>::zero(); batch * channels * dst_plane];

    // Batch and channel are contiguous, so they can be walked as one plane index.
    for plane in 0..batch * channels {
        let src_base = plane * src_plane;
        let dst_base = plane * dst_plane;
        for ih in 0..h {
            let dst_row = dst_base + ih * stride_h * w_up;
            let src_row = src_base + ih * w;
            for iw in 0..w {
                upsampled[dst_row + iw * stride_w] = input[src_row + iw];
            }
        }
    }
    (upsampled, h_up, w_up)
}
/// Converts a transposed-convolution kernel into an ordinary convolution kernel.
/// It swaps the two channel axes and rotates every spatial slice by 180 degrees.
///
/// `kernel` is read as `[c_in, c_out, kh, kw]`. The result is
/// `[c_out, c_in, kh, kw]` with
/// `out[co][ci][y][x] = kernel[ci][co][kh - 1 - y][kw - 1 - x]`.
fn flip_kernel<T: Float>(
    kernel: &[T],
    c_in: usize,
    c_out: usize,
    kh: usize,
    kw: usize,
) -> Vec<T> {
    let taps = kh * kw;
    let mut flipped = vec![<T as num_traits::Zero>::zero(); c_out * c_in * taps];
    for co in 0..c_out {
        for ci in 0..c_in {
            let dst_base = (co * c_in + ci) * taps;
            let src_base = (ci * c_out + co) * taps;
            for tap in 0..taps {
                // Rotating a row-major kh x kw slice by 180 degrees reverses its
                // flat order.
                flipped[dst_base + tap] = kernel[src_base + (taps - 1 - tap)];
            }
        }
    }
    flipped
}
impl<T: Float> Module<T> for ConvTranspose2d<T> {
    /// Applies the 2-D transposed convolution to a `[B, in_channels, H_in, W_in]`
    /// input.
    ///
    /// Each input pixel `(ih, iw)` scatters `weight[ci, co, :, :]` onto the output
    /// starting at `(ih * stride_h - pad_h, iw * stride_w - pad_w)`. The output
    /// shape is `[B, out_channels, h_out, w_out]` with
    /// `h_out = (H_in - 1) * stride_h - 2 * pad_h + kernel_h + output_padding_h`
    /// (likewise for width), matching PyTorch's `ConvTranspose2d` with dilation 1.
    ///
    /// Implementation:
    /// 1. The input is zero-upsampled by the stride.
    /// 2. A stride-1 "full" convolution (padding `kernel - 1`) with the flipped
    ///    kernel produces the uncropped `h_full x w_full` result.
    /// 3. The output is the window of that result starting at `(pad_h, pad_w)`.
    ///
    /// Output positions that fall past the uncropped result are reachable only
    /// through `output_padding`; they receive no kernel contribution and hold just
    /// the bias.
    ///
    /// # Errors
    /// * `InvalidArgument` if the input is not 4-D, has an empty spatial
    ///   dimension, or the padding would leave an empty output.
    /// * `ShapeMismatch` if the channel dimension is not `in_channels`.
    /// * Errors from reading tensor data or from the matrix multiply are propagated.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        if input.ndim() != 4 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose2d expects 4-D input [B, C, H, W], got {:?}",
                    input.shape()
                ),
            });
        }
        let batch = input.shape()[0];
        let c_in = input.shape()[1];
        let h = input.shape()[2];
        let w = input.shape()[3];
        if c_in != self.in_channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "ConvTranspose2d: expected {} input channels, got {}",
                    self.in_channels, c_in
                ),
            });
        }
        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;
        let (oph, opw) = self.output_padding;
        // stride_insert_zeros computes `h - 1`; reject empty spatial dims up front.
        if h == 0 || w == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose2d: input spatial dimensions must be > 0, got ({h}, {w})"
                ),
            });
        }
        // Extent of the uncropped transposed convolution (nothing trimmed yet).
        let h_full = (h - 1) * sh + kh;
        let w_full = (w - 1) * sw + kw;
        // Checked explicitly: deriving the size via `kernel - 1 - padding`
        // underflowed whenever padding >= kernel.
        if h_full + oph <= 2 * ph || w_full + opw <= 2 * pw {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose2d: padding ({ph}, {pw}) too large for input ({h}, {w}) with kernel ({kh}, {kw}) and stride ({sh}, {sw}); output would be empty"
                ),
            });
        }
        let h_out = h_full + oph - 2 * ph;
        let w_out = w_full + opw - 2 * pw;

        let input_data = input.data()?;
        let (upsampled, h_up, w_up) =
            stride_insert_zeros(input_data, batch, c_in, h, w, sh, sw);
        let weight_data = self.weight.data()?;
        let flipped = flip_kernel(weight_data, self.in_channels, self.out_channels, kh, kw);
        // Stride-1 "full" convolution: padding kernel-1 on every side yields exactly
        // h_full * w_full columns.
        let (cols, col_rows, col_cols) = im2col(
            &upsampled,
            batch,
            c_in,
            h_up,
            w_up,
            kh,
            kw,
            1,
            1,
            kh - 1,
            kw - 1,
        );
        debug_assert_eq!(col_cols, h_full * w_full);
        let flipped_2d = Tensor::from_storage(
            TensorStorage::cpu(flipped),
            vec![self.out_channels, col_rows],
            false,
        )?;

        let full_plane = h_full * w_full;
        let out_plane = h_out * w_out;
        let sample_len = self.out_channels * out_plane;
        // Output rows/cols that overlap the uncropped result. Anything beyond this
        // (output_padding past the full extent) receives no contribution.
        let rows_valid = h_out.min(h_full.saturating_sub(ph));
        let cols_valid = w_out.min(w_full.saturating_sub(pw));

        let zero = <T as num_traits::Zero>::zero();
        let mut output = vec![zero; batch * sample_len];
        for b in 0..batch {
            let col_start = b * col_rows * col_cols;
            let col_end = col_start + col_rows * col_cols;
            let cols_b = Tensor::from_storage(
                TensorStorage::cpu(cols[col_start..col_end].to_vec()),
                vec![col_rows, col_cols],
                false,
            )?;
            let full_b = mm(&flipped_2d, &cols_b)?;
            let full_data = full_b.data()?;
            let out_start = b * sample_len;
            // Crop the window starting at (ph, pw) out of the full result.
            for c in 0..self.out_channels {
                for oh in 0..rows_valid {
                    let dst = out_start + c * out_plane + oh * w_out;
                    let src = c * full_plane + (oh + ph) * w_full + pw;
                    output[dst..dst + cols_valid]
                        .copy_from_slice(&full_data[src..src + cols_valid]);
                }
            }
        }

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data()?;
            for b in 0..batch {
                for c in 0..self.out_channels {
                    let bval = bias_data[c];
                    let start = b * sample_len + c * out_plane;
                    for v in &mut output[start..start + out_plane] {
                        *v += bval;
                    }
                }
            }
        }

        let shape = vec![batch, self.out_channels, h_out, w_out];
        let needs_grad = is_grad_enabled()
            && (input.requires_grad()
                || self.weight.requires_grad()
                || self.bias.as_ref().is_some_and(|b| b.requires_grad()));
        if !needs_grad {
            return Tensor::from_storage(TensorStorage::cpu(output), shape, false);
        }

        let grad_fn = Arc::new(ConvTranspose2dBackward {
            input: input.clone(),
            weight: self.weight.tensor().clone(),
            bias: self.bias.as_ref().map(|b| b.tensor().clone()),
            in_channels: self.in_channels,
            out_channels: self.out_channels,
            kernel_size: self.kernel_size,
            stride: self.stride,
            padding: self.padding,
            _output_padding: self.output_padding,
            h_out,
            w_out,
        });
        // Move the computed buffer straight into the graph node instead of first
        // building a detached tensor and copying its whole data a second time.
        Tensor::from_operation(TensorStorage::cpu(output), shape, grad_fn)
    }

    /// Trainable parameters: the weight, followed by the bias if present.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params = vec![&self.weight];
        if let Some(ref b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Mutable access to the weight, followed by the bias if present.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params = vec![&mut self.weight];
        if let Some(ref mut b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Parameters keyed as `"weight"` and, if present, `"bias"`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut params = vec![("weight".to_string(), &self.weight)];
        if let Some(ref b) = self.bias {
            params.push(("bias".to_string(), b));
        }
        params
    }

    /// Switches to training mode.
    fn train(&mut self) {
        self.training = true;
    }

    /// Switches to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Whether the layer is in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Autograd node recorded by `ConvTranspose2d::forward`.
#[derive(Debug)]
struct ConvTranspose2dBackward<T: Float> {
/// Forward input, `[batch, in_channels, h_in, w_in]`.
input: Tensor<T>,
/// Layer weight, `[in_channels, out_channels, kh, kw]`.
weight: Tensor<T>,
/// Layer bias, if the layer has one.
bias: Option<Tensor<T>>,
in_channels: usize,
out_channels: usize,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
/// Stored for completeness; `backward` does not read it.
_output_padding: (usize, usize),
/// Forward output height.
h_out: usize,
/// Forward output width.
w_out: usize,
}
impl<T: Float> GradFn<T> for ConvTranspose2dBackward<T> {
/// Computes the gradients of the transposed convolution w.r.t. input, weight
/// and bias.
///
/// `grad_output` is indexed as a contiguous `[batch, out_channels, h_out, w_out]`
/// buffer. Returns `[grad_input, grad_weight, grad_bias]`; each entry is `None`
/// when the corresponding tensor does not require a gradient.
///
/// NOTE(review): the vector always has three entries, while `inputs()` returns
/// only two when there is no bias. Confirm the autograd engine tolerates this.
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let go_data = grad_output.data()?;
let batch = self.input.shape()[0];
let h_in = self.input.shape()[2];
let w_in = self.input.shape()[3];
let (kh, kw) = self.kernel_size;
let (sh, sw) = self.stride;
let (ph, pw) = self.padding;
// dL/dX is an ordinary strided convolution of dY with the weight viewed as
// [in, out*kh*kw]: grad_input[ci, ih, iw] = sum W[ci, co, dh, dw] * dY[co, ih*sh+dh-ph, iw*sw+dw-pw].
let grad_input = if self.input.requires_grad() {
let weight_data = self.weight.data()?;
let col_rows = self.out_channels * kh * kw;
let weight_2d = Tensor::from_storage(
TensorStorage::cpu(weight_data.to_vec()),
vec![self.in_channels, col_rows],
false,
)?;
// With output_padding < stride (enforced in `new`), this im2col is expected
// to yield h_in * w_in columns; the copy below relies on that.
let (go_cols, _go_col_rows, go_col_cols) = im2col(
go_data,
batch,
self.out_channels,
self.h_out,
self.w_out,
kh,
kw,
sh,
sw,
ph,
pw,
);
let zero = <T as num_traits::Zero>::zero();
let mut gi = vec![zero; batch * self.in_channels * h_in * w_in];
for b in 0..batch {
let col_start = b * col_rows * go_col_cols;
let col_end = col_start + col_rows * go_col_cols;
let go_cols_b = Tensor::from_storage(
TensorStorage::cpu(go_cols[col_start..col_end].to_vec()),
vec![col_rows, go_col_cols],
false,
)?;
let gi_b = mm(&weight_2d, &go_cols_b)?;
let gi_data = gi_b.data()?;
let out_start = b * self.in_channels * h_in * w_in;
let copy_len = self.in_channels * h_in * w_in;
gi[out_start..out_start + copy_len]
.copy_from_slice(&gi_data[..copy_len]);
}
Some(Tensor::from_storage(
TensorStorage::cpu(gi),
self.input.shape().to_vec(),
false,
)?)
} else {
None
};
// dL/dW[ci, co, dh, dw] = sum over batch and input pixels of
// x[ci, ih, iw] * dY[co, ih*sh+dh-ph, iw*sw+dw-pw], skipping positions outside dY.
let grad_weight = if self.weight.requires_grad() {
let zero = <T as num_traits::Zero>::zero();
let weight_numel = self.in_channels * self.out_channels * kh * kw;
let mut gw = vec![zero; weight_numel];
let input_data = self.input.data()?;
for b in 0..batch {
for ci in 0..self.in_channels {
for co in 0..self.out_channels {
for dh in 0..kh {
for dw in 0..kw {
let mut acc = zero;
for ih in 0..h_in {
for iw in 0..w_in {
// Position in the un-trimmed output; subtract padding to index dY.
let oh = ih * sh + dh;
let ow = iw * sw + dw;
if oh >= ph
&& ow >= pw
&& (oh - ph) < self.h_out
&& (ow - pw) < self.w_out
{
let go_idx = b * self.out_channels * self.h_out * self.w_out
+ co * self.h_out * self.w_out
+ (oh - ph) * self.w_out
+ (ow - pw);
let in_idx = b * self.in_channels * h_in * w_in
+ ci * h_in * w_in
+ ih * w_in
+ iw;
acc += input_data[in_idx] * go_data[go_idx];
}
}
}
gw[ci * self.out_channels * kh * kw
+ co * kh * kw
+ dh * kw
+ dw] += acc;
}
}
}
}
}
Some(Tensor::from_storage(
TensorStorage::cpu(gw),
vec![self.in_channels, self.out_channels, kh, kw],
false,
)?)
} else {
None
};
// dL/db = dY summed over batch and all spatial positions.
let grad_bias = match &self.bias {
Some(b) if b.requires_grad() => {
let zero = <T as num_traits::Zero>::zero();
let mut gb = vec![zero; self.out_channels];
for batch_idx in 0..batch {
for c in 0..self.out_channels {
for hw in 0..(self.h_out * self.w_out) {
gb[c] += go_data[batch_idx * self.out_channels * self.h_out * self.w_out
+ c * self.h_out * self.w_out
+ hw];
}
}
}
Some(Tensor::from_storage(
TensorStorage::cpu(gb),
vec![self.out_channels],
false,
)?)
}
_ => None,
};
Ok(vec![grad_input, grad_weight, grad_bias])
}
/// Tensors this node differentiates: input, weight, then bias (if present).
fn inputs(&self) -> Vec<&Tensor<T>> {
let mut v = vec![&self.input, &self.weight];
if let Some(ref b) = self.bias {
v.push(b);
}
v
}
/// Static identifier of this autograd node.
fn name(&self) -> &'static str {
"ConvTranspose2dBackward"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::module::Module;
fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
}
fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
}
fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
assert_eq!(
actual.len(),
expected.len(),
"length mismatch: {} vs {}",
actual.len(),
expected.len()
);
for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
assert!(
(a - e).abs() < tol,
"index {i}: actual={a} expected={e} (diff {})",
(a - e).abs()
);
}
}
#[test]
fn test_output_shape_no_padding() {
let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
let input = t(&vec![0.0; 25], &[1, 1, 5, 5]);
let output = conv.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 1, 3, 3]);
}
#[test]
fn test_output_shape_with_padding() {
let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (1, 1), true).unwrap();
let input = t(&vec![0.0; 2 * 3 * 8 * 8], &[2, 3, 8, 8]);
let output = conv.forward(&input).unwrap();
assert_eq!(output.shape(), &[2, 16, 8, 8]);
}
#[test]
fn test_output_shape_with_stride() {
let conv = Conv2d::<f32>::new(1, 4, (3, 3), (2, 2), (0, 0), false).unwrap();
let input = t(&vec![0.0; 36], &[1, 1, 6, 6]);
let output = conv.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 4, 2, 2]);
}
#[test]
fn test_1x1_conv_equals_linear() {
let weight_data: Vec<f32> = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, ];
let input_data: Vec<f32> = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ];
let weight_param =
Parameter::from_slice(&weight_data, &[3, 2, 1, 1]).unwrap();
let conv = Conv2d {
weight: weight_param,
bias: None,
in_channels: 2,
out_channels: 3,
kernel_size: (1, 1),
stride: (1, 1),
padding: (0, 0),
training: false,
};
let input = t(&input_data, &[1, 2, 2, 2]);
let output = conv.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 3, 2, 2]);
let out = output.data().unwrap();
let expected = [
11.0, 14.0, 17.0, 20.0, 23.0, 30.0, 37.0, 44.0, 35.0, 46.0, 57.0, 68.0, ];
assert_close(out, &expected, 1e-5);
}
#[test]
fn test_bias_addition() {
let weight_data = vec![1.0f32]; let bias_data = vec![10.0f32];
let conv = Conv2d {
weight: Parameter::from_slice(&weight_data, &[1, 1, 1, 1]).unwrap(),
bias: Some(Parameter::from_slice(&bias_data, &[1]).unwrap()),
in_channels: 1,
out_channels: 1,
kernel_size: (1, 1),
stride: (1, 1),
padding: (0, 0),
training: false,
};
let input = t(&[2.0, 3.0, 4.0, 5.0], &[1, 1, 2, 2]);
let output = conv.forward(&input).unwrap();
assert_close(output.data().unwrap(), &[12.0, 13.0, 14.0, 15.0], 1e-5);
}
#[test]
fn test_backward_produces_correct_shapes() {
    // 2 output channels, 1 input channel, 3x3 kernel over a 5x5 input -> 3x3 output map.
    // Element counts are written without `* 1` factors (clippy::identity_op) and the small
    // buffers are arrays rather than heap `Vec`s (clippy::useless_vec).
    let weight_data = [1.0f32; 2 * 3 * 3];
    let input_data = [1.0f32; 5 * 5];
    let bias_data = [0.0f32; 2];
    let conv = Conv2d {
        weight: Parameter::from_slice(&weight_data, &[2, 1, 3, 3]).unwrap(),
        bias: Some(Parameter::from_slice(&bias_data, &[2]).unwrap()),
        in_channels: 1,
        out_channels: 2,
        kernel_size: (3, 3),
        stride: (1, 1),
        padding: (0, 0),
        training: false,
    };
    let input = leaf(&input_data, &[1, 1, 5, 5]);
    let output = conv.forward(&input).unwrap();
    assert_eq!(output.shape(), &[1, 2, 3, 3]);

    // Fetch the recorded backward node once instead of re-unwrapping it for every check.
    let grad_fn = output
        .grad_fn()
        .expect("forward on a leaf input should record a grad_fn");
    assert_eq!(grad_fn.name(), "Conv2dBackward");

    let grad_output = t(&[1.0; 2 * 3 * 3], &[1, 2, 3, 3]);
    let grads = grad_fn.backward(&grad_output).unwrap();

    // Gradient slots line up with the forward inputs: [input, weight, bias].
    let expected_shapes: [&[usize]; 3] = [&[1, 1, 5, 5], &[2, 1, 3, 3], &[2]];
    for (slot, expected) in expected_shapes.iter().enumerate() {
        let grad = grads[slot]
            .as_ref()
            .unwrap_or_else(|| panic!("gradient {} should be present", slot));
        assert_eq!(grad.shape(), *expected);
    }
}
#[test]
fn test_parameter_count_with_bias() {
    let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (0, 0), true).unwrap();
    // weight: out * in * kh * kw = 16 * 3 * 3 * 3 = 432; bias: one scalar per output channel.
    let (weight_elems, bias_elems) = (16 * 3 * 3 * 3, 16);
    assert_eq!(conv.num_parameters(), weight_elems + bias_elems);
    assert_eq!(conv.parameters().len(), 2, "expected weight and bias tensors");
}
#[test]
fn test_parameter_count_without_bias() {
    let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (0, 0), false).unwrap();
    // Without a bias only the out * in * kh * kw weight tensor remains.
    let weight_elems = 16 * 3 * 3 * 3;
    assert_eq!(conv.num_parameters(), weight_elems);
    assert_eq!(conv.parameters().len(), 1, "expected only the weight tensor");
}
#[test]
fn test_named_parameters() {
    let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), true).unwrap();
    let named = conv.named_parameters();
    assert_eq!(named.len(), 2);
    // Weight is reported before bias.
    for (idx, expected) in ["weight", "bias"].iter().enumerate() {
        assert_eq!(named[idx].0, *expected);
    }
}
#[test]
fn test_train_eval() {
    let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
    // A freshly constructed layer starts in training mode.
    assert!(conv.is_training());
    conv.eval();
    assert!(!conv.is_training(), "eval() should leave training mode");
    conv.train();
    assert!(conv.is_training(), "train() should re-enter training mode");
}
#[test]
fn test_invalid_input_ndim() {
    // Conv2d consumes 4-D (N, C, H, W) tensors; a rank-1 tensor must be rejected.
    let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
    let flat = t(&[1.0, 2.0, 3.0], &[3]);
    let outcome = conv.forward(&flat);
    assert!(outcome.is_err());
}
#[test]
fn test_channel_mismatch() {
    // The layer is built for 3 input channels but receives a single-channel tensor.
    // The buffer length is written as `5 * 5` (no `1 *` factors, clippy::identity_op) and
    // passed as an array slice (clippy::useless_vec).
    let conv = Conv2d::<f32>::new(3, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
    let input = t(&[0.0; 5 * 5], &[1, 1, 5, 5]);
    assert!(conv.forward(&input).is_err());
}
#[test]
fn test_zero_channels_rejected() {
    // Both the input and the output channel counts must be non-zero.
    for (in_ch, out_ch) in [(0, 16), (3, 0)] {
        let built = Conv2d::<f32>::new(in_ch, out_ch, (3, 3), (1, 1), (0, 0), false);
        assert!(built.is_err(), "in={} out={} should be rejected", in_ch, out_ch);
    }
}
#[test]
fn test_zero_kernel_rejected() {
    // A zero extent on either kernel axis is invalid. Check height and width separately so a
    // constructor that only validates one axis is caught.
    assert!(Conv2d::<f32>::new(1, 1, (0, 3), (1, 1), (0, 0), false).is_err());
    assert!(Conv2d::<f32>::new(1, 1, (3, 0), (1, 1), (0, 0), false).is_err());
}
#[test]
fn test_zero_stride_rejected() {
    // A zero stride on either axis would divide by zero when im2col computes the output
    // size, so both axes must be refused at construction time.
    assert!(Conv2d::<f32>::new(1, 1, (3, 3), (0, 1), (0, 0), false).is_err());
    assert!(Conv2d::<f32>::new(1, 1, (3, 3), (1, 0), (0, 0), false).is_err());
}
#[test]
fn test_im2col_basic() {
    // A 1x1x3x3 image unrolled with a 2x2 window, stride 1 and no padding. Each kernel tap
    // becomes a row and each of the 2x2 output positions becomes a column.
    #[rustfmt::skip]
    let image: [f32; 9] = [
        1.0, 2.0, 3.0,
        4.0, 5.0, 6.0,
        7.0, 8.0, 9.0,
    ];
    let (batch, channels, height, width) = (1, 1, 3, 3);
    let (kernel, stride, pad) = (2, 1, 0);
    let (cols, rows, n_cols) = im2col(
        &image, batch, channels, height, width, kernel, kernel, stride, stride, pad, pad,
    );
    assert_eq!(rows, 4);
    assert_eq!(n_cols, 4);
    #[rustfmt::skip]
    let expected = [
        1.0, 2.0, 4.0, 5.0, // tap (0, 0)
        2.0, 3.0, 5.0, 6.0, // tap (0, 1)
        4.0, 5.0, 7.0, 8.0, // tap (1, 0)
        5.0, 6.0, 8.0, 9.0, // tap (1, 1)
    ];
    assert_close(&cols, &expected, 1e-7);
}
#[test]
fn test_col2im_roundtrip_no_overlap() {
    // With stride == kernel size, the 2x2 windows tile the 4x4 image so that every pixel
    // appears in exactly one column entry. col2im should therefore recover the image exactly.
    #[rustfmt::skip]
    let image: [f32; 16] = [
         1.0,  2.0,  3.0,  4.0,
         5.0,  6.0,  7.0,  8.0,
         9.0, 10.0, 11.0, 12.0,
        13.0, 14.0, 15.0, 16.0,
    ];
    let (side, kernel, stride, out_side) = (4, 2, 2, 2);
    let (cols, _, _) = im2col(&image, 1, 1, side, side, kernel, kernel, stride, stride, 0, 0);
    let recovered = col2im(
        &cols, 1, 1, side, side, kernel, kernel, stride, stride, 0, 0, out_side, out_side,
    );
    assert_close(&recovered, &image, 1e-7);
}
#[test]
fn test_3x3_conv_forward() {
    #[rustfmt::skip]
    let image: [f32; 9] = [
        1.0, 2.0, 3.0,
        4.0, 5.0, 6.0,
        7.0, 8.0, 9.0,
    ];
    // Left column +1, right column -1. The expected value below implies the layer applies
    // the kernel without flipping it (cross-correlation): (1 + 4 + 7) - (3 + 6 + 9) = -6.
    #[rustfmt::skip]
    let kernel: [f32; 9] = [
        1.0, 0.0, -1.0,
        1.0, 0.0, -1.0,
        1.0, 0.0, -1.0,
    ];
    let conv = Conv2d {
        weight: Parameter::from_slice(&kernel, &[1, 1, 3, 3]).unwrap(),
        bias: None,
        in_channels: 1,
        out_channels: 1,
        kernel_size: (3, 3),
        stride: (1, 1),
        padding: (0, 0),
        training: false,
    };
    let input = t(&image, &[1, 1, 3, 3]);
    let response = conv.forward(&input).unwrap();
    assert_eq!(response.shape(), &[1, 1, 1, 1]);
    assert_close(response.data().unwrap(), &[-6.0], 1e-5);
}
#[test]
fn test_padding_preserves_spatial_size() {
    // A 3x3 kernel that is 1 at the centre and 0 elsewhere acts as the identity on a
    // zero-padded input. With padding 1 the output must reproduce the input exactly.
    let mut identity_kernel = [0.0f32; 9];
    identity_kernel[4] = 1.0; // centre tap (row 1, col 1)
    let conv = Conv2d {
        weight: Parameter::from_slice(&identity_kernel, &[1, 1, 3, 3]).unwrap(),
        bias: None,
        in_channels: 1,
        out_channels: 1,
        kernel_size: (3, 3),
        stride: (1, 1),
        padding: (1, 1),
        training: false,
    };
    let image: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let input = t(&image, &[1, 1, 3, 3]);
    let output = conv.forward(&input).unwrap();
    assert_eq!(output.shape(), &[1, 1, 3, 3]);
    assert_close(output.data().unwrap(), &image, 1e-5);
}
#[test]
fn test_conv1d_output_shape_no_padding() {
    // Length 10, kernel 3, stride 1, no padding: 10 - 3 + 1 = 8 output positions.
    let layer = Conv1d::<f32>::new(1, 4, 3, 1, 0, false).expect("valid Conv1d configuration");
    let signal = t(&[0.0; 10], &[1, 1, 10]);
    let result = layer.forward(&signal).expect("forward should succeed");
    assert_eq!(result.shape(), &[1, 4, 8]);
}
#[test]
fn test_conv1d_output_shape_with_padding() {
    // Kernel 3 with padding 1 and stride 1 keeps the length: (16 + 2 - 3) / 1 + 1 = 16.
    let layer = Conv1d::<f32>::new(3, 8, 3, 1, 1, true).expect("valid Conv1d configuration");
    let zeros = vec![0.0; 2 * 3 * 16];
    let signal = t(&zeros, &[2, 3, 16]);
    let result = layer.forward(&signal).expect("forward should succeed");
    assert_eq!(result.shape(), &[2, 8, 16]);
}
#[test]
fn test_conv1d_output_shape_with_stride() {
    // Length 10, kernel 3, stride 2: (10 - 3) / 2 + 1 = 4 output positions.
    let layer = Conv1d::<f32>::new(1, 2, 3, 2, 0, false).expect("valid Conv1d configuration");
    let signal = t(&[0.0; 10], &[1, 1, 10]);
    let result = layer.forward(&signal).expect("forward should succeed");
    assert_eq!(result.shape(), &[1, 2, 4]);
}
#[test]
fn test_conv1d_1x1_kernel_correctness() {
    // Two output channels that scale the single input channel by 3 and by 5.
    let scales = Parameter::from_slice(&[3.0f32, 5.0], &[2, 1, 1]).unwrap();
    let conv = Conv1d {
        weight: scales,
        bias: None,
        in_channels: 1,
        out_channels: 2,
        kernel_size: 1,
        stride: 1,
        padding: 0,
        training: false,
    };
    let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4]);
    let output = conv.forward(&input).unwrap();
    assert_eq!(output.shape(), &[1, 2, 4]);
    #[rustfmt::skip]
    let expected = [
        3.0, 6.0, 9.0, 12.0,   // 3 * x
        5.0, 10.0, 15.0, 20.0, // 5 * x
    ];
    assert_close(output.data().unwrap(), &expected, 1e-5);
}
#[test]
fn test_conv1d_3_kernel_forward() {
    // Kernel [1, 0, -1] computes x[i] - x[i + 2]; on the ramp 1..=5 that is always -2.
    let difference = Parameter::from_slice(&[1.0f32, 0.0, -1.0], &[1, 1, 3]).unwrap();
    let conv = Conv1d {
        weight: difference,
        bias: None,
        in_channels: 1,
        out_channels: 1,
        kernel_size: 3,
        stride: 1,
        padding: 0,
        training: false,
    };
    let ramp = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
    let output = conv.forward(&ramp).unwrap();
    assert_eq!(output.shape(), &[1, 1, 3]);
    assert_close(output.data().unwrap(), &[-2.0, -2.0, -2.0], 1e-5);
}
#[test]
fn test_conv1d_bias() {
    // An identity kernel plus a bias of 10 should add 10 to every sample.
    let identity = Parameter::from_slice(&[1.0f32], &[1, 1, 1]).unwrap();
    let shift = Parameter::from_slice(&[10.0f32], &[1]).unwrap();
    let conv = Conv1d {
        weight: identity,
        bias: Some(shift),
        in_channels: 1,
        out_channels: 1,
        kernel_size: 1,
        stride: 1,
        padding: 0,
        training: false,
    };
    let signal = t(&[2.0, 3.0, 4.0], &[1, 1, 3]);
    let shifted = conv.forward(&signal).unwrap();
    assert_close(shifted.data().unwrap(), &[12.0, 13.0, 14.0], 1e-5);
}
#[test]
fn test_conv1d_invalid_ndim() {
    // Conv1d consumes 3-D (N, C, L) tensors; a 4-D tensor must be rejected.
    let conv = Conv1d::<f32>::new(1, 1, 3, 1, 0, false).unwrap();
    let four_d = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
    let outcome = conv.forward(&four_d);
    assert!(outcome.is_err());
}
#[test]
fn test_conv1d_channel_mismatch() {
    // The layer expects 3 input channels; feeding a single channel must fail.
    let conv = Conv1d::<f32>::new(3, 1, 3, 1, 0, false).unwrap();
    let signal = t(&[0.0; 10], &[1, 1, 10]);
    let outcome = conv.forward(&signal);
    assert!(outcome.is_err());
}
#[test]
fn test_conv1d_zero_channels_rejected() {
    // Neither the input nor the output channel count may be zero.
    for (in_ch, out_ch) in [(0, 4), (1, 0)] {
        let built = Conv1d::<f32>::new(in_ch, out_ch, 3, 1, 0, false);
        assert!(built.is_err(), "in={} out={} should be rejected", in_ch, out_ch);
    }
}
#[test]
fn test_conv1d_zero_kernel_rejected() {
    // A zero-length kernel is refused at construction time.
    let built = Conv1d::<f32>::new(1, 1, 0, 1, 0, false);
    assert!(built.is_err());
}
#[test]
fn test_conv1d_zero_stride_rejected() {
    // A zero stride is refused at construction time.
    let built = Conv1d::<f32>::new(1, 1, 3, 0, 0, false);
    assert!(built.is_err());
}
#[test]
fn test_conv1d_parameter_count() {
    let conv = Conv1d::<f32>::new(3, 8, 5, 1, 0, true).unwrap();
    // weight: out * in * k = 8 * 3 * 5 = 120; bias: 8 -> 128 scalars in 2 tensors.
    let (weight_elems, bias_elems) = (8 * 3 * 5, 8);
    assert_eq!(conv.num_parameters(), weight_elems + bias_elems);
    assert_eq!(conv.parameters().len(), 2);
}
#[test]
fn test_conv_transpose2d_output_shape_basic() {
    // (3 - 1) * 1 - 0 + 3 + 0 = 5: a stride-1 transposed conv grows each axis by kernel - 1.
    let layer = ConvTranspose2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), (0, 0), false)
        .expect("valid ConvTranspose2d configuration");
    let image = t(&[0.0; 9], &[1, 1, 3, 3]);
    let result = layer.forward(&image).expect("forward should succeed");
    assert_eq!(result.shape(), &[1, 1, 5, 5]);
}
#[test]
fn test_conv_transpose2d_output_shape_stride2() {
    // (2 - 1) * 2 - 0 + 3 + 0 = 5 along each spatial axis.
    let layer = ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (0, 0), (0, 0), false)
        .expect("valid ConvTranspose2d configuration");
    let image = t(&[0.0; 4], &[1, 1, 2, 2]);
    let result = layer.forward(&image).expect("forward should succeed");
    assert_eq!(result.shape(), &[1, 1, 5, 5]);
}
#[test]
fn test_conv_transpose2d_output_shape_with_padding() {
    // (3 - 1) * 2 - 2 * 1 + 3 + 0 = 5: padding trims one row/column from each border.
    let layer = ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (1, 1), (0, 0), false)
        .expect("valid ConvTranspose2d configuration");
    let image = t(&[0.0; 9], &[1, 1, 3, 3]);
    let result = layer.forward(&image).expect("forward should succeed");
    assert_eq!(result.shape(), &[1, 1, 5, 5]);
}
#[test]
fn test_conv_transpose2d_output_shape_with_output_padding() {
    // Same geometry as the padded case, but output_padding = 1 adds one extra row and
    // column: 5 + 1 = 6.
    let layer = ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false)
        .expect("valid ConvTranspose2d configuration");
    let image = t(&[0.0; 9], &[1, 1, 3, 3]);
    let result = layer.forward(&image).expect("forward should succeed");
    assert_eq!(result.shape(), &[1, 1, 6, 6]);
}
#[test]
fn test_conv_transpose2d_stride2_upsamples() {
    // A 2x2 kernel with stride 2 doubles each spatial axis: (4 - 1) * 2 + 2 = 8.
    // The buffer length is written as `4 * 4` (no `1 *` factors, clippy::identity_op) and
    // passed as an array slice (clippy::useless_vec).
    let conv = ConvTranspose2d::<f32>::new(1, 1, (2, 2), (2, 2), (0, 0), (0, 0), false).unwrap();
    let input = t(&[0.0; 4 * 4], &[1, 1, 4, 4]);
    let output = conv.forward(&input).unwrap();
    assert_eq!(output.shape(), &[1, 1, 8, 8]);
}
#[test]
fn test_conv_transpose2d_stride2_upsamples_multichannel() {
    // Batch of 2, 8 -> 16 channels; spatial size doubles from 4x4 to 8x8.
    let layer = ConvTranspose2d::<f32>::new(8, 16, (2, 2), (2, 2), (0, 0), (0, 0), true)
        .expect("valid ConvTranspose2d configuration");
    let zeros = vec![0.0; 2 * 8 * 4 * 4];
    let batch = t(&zeros, &[2, 8, 4, 4]);
    let result = layer.forward(&batch).expect("forward should succeed");
    assert_eq!(result.shape(), &[2, 16, 8, 8]);
}
#[test]
fn test_conv_transpose2d_1x1_kernel() {
    // Transposed-conv weights are laid out as [in, out, kh, kw]. One input channel feeds
    // two outputs, scaled by 3 and by 7.
    let scales = Parameter::from_slice(&[3.0f32, 7.0], &[1, 2, 1, 1]).unwrap();
    let conv = ConvTranspose2d {
        weight: scales,
        bias: None,
        in_channels: 1,
        out_channels: 2,
        kernel_size: (1, 1),
        stride: (1, 1),
        padding: (0, 0),
        output_padding: (0, 0),
        training: false,
    };
    let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
    let output = conv.forward(&input).unwrap();
    assert_eq!(output.shape(), &[1, 2, 2, 2]);
    #[rustfmt::skip]
    let expected = [
        3.0, 6.0, 9.0, 12.0,   // 3 * x
        7.0, 14.0, 21.0, 28.0, // 7 * x
    ];
    assert_close(output.data().unwrap(), &expected, 1e-5);
}
#[test]
fn test_conv_transpose2d_stride2_correctness() {
    // An all-ones 2x2 kernel with stride 2 stamps each input value into its own
    // non-overlapping 2x2 block of the output (nearest-neighbour style upsampling).
    let ones = Parameter::from_slice(&[1.0f32; 4], &[1, 1, 2, 2]).unwrap();
    let conv = ConvTranspose2d {
        weight: ones,
        bias: None,
        in_channels: 1,
        out_channels: 1,
        kernel_size: (2, 2),
        stride: (2, 2),
        padding: (0, 0),
        output_padding: (0, 0),
        training: false,
    };
    let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
    let upsampled = conv.forward(&input).unwrap();
    assert_eq!(upsampled.shape(), &[1, 1, 4, 4]);
    #[rustfmt::skip]
    let expected = [
        1.0, 1.0, 2.0, 2.0,
        1.0, 1.0, 2.0, 2.0,
        3.0, 3.0, 4.0, 4.0,
        3.0, 3.0, 4.0, 4.0,
    ];
    assert_close(upsampled.data().unwrap(), &expected, 1e-5);
}
#[test]
fn test_conv_transpose2d_bias() {
    // An identity 1x1 kernel plus a bias of 5 shifts every pixel by 5.
    let identity = Parameter::from_slice(&[1.0f32], &[1, 1, 1, 1]).unwrap();
    let shift = Parameter::from_slice(&[5.0f32], &[1]).unwrap();
    let conv = ConvTranspose2d {
        weight: identity,
        bias: Some(shift),
        in_channels: 1,
        out_channels: 1,
        kernel_size: (1, 1),
        stride: (1, 1),
        padding: (0, 0),
        output_padding: (0, 0),
        training: false,
    };
    let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
    let shifted = conv.forward(&input).unwrap();
    assert_close(shifted.data().unwrap(), &[6.0, 7.0, 8.0, 9.0], 1e-5);
}
#[test]
fn test_conv_transpose2d_invalid_ndim() {
    // ConvTranspose2d consumes 4-D tensors; a 3-D tensor must be rejected.
    let conv = ConvTranspose2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
    let three_d = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
    let outcome = conv.forward(&three_d);
    assert!(outcome.is_err());
}
#[test]
fn test_conv_transpose2d_channel_mismatch() {
    // The layer is built for 3 input channels but receives a single-channel tensor.
    // The buffer length is written as `5 * 5` (no `1 *` factors, clippy::identity_op) and
    // passed as an array slice (clippy::useless_vec).
    let conv = ConvTranspose2d::<f32>::new(3, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
    let input = t(&[0.0; 5 * 5], &[1, 1, 5, 5]);
    assert!(conv.forward(&input).is_err());
}
#[test]
fn test_conv_transpose2d_zero_channels_rejected() {
    // Neither the input nor the output channel count may be zero.
    for (in_ch, out_ch) in [(0, 1), (1, 0)] {
        let built =
            ConvTranspose2d::<f32>::new(in_ch, out_ch, (3, 3), (1, 1), (0, 0), (0, 0), false);
        assert!(built.is_err(), "in={} out={} should be rejected", in_ch, out_ch);
    }
}
#[test]
fn test_conv_transpose2d_output_padding_too_large() {
    // output_padding equal to the stride (2 == 2) is out of range and must be refused.
    let built = ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (0, 0), (2, 2), false);
    assert!(built.is_err());
}
#[test]
fn test_conv_transpose2d_parameter_count() {
    let conv = ConvTranspose2d::<f32>::new(8, 16, (3, 3), (2, 2), (1, 1), (0, 0), true).unwrap();
    // weight: in * out * kh * kw = 8 * 16 * 3 * 3 = 1152; bias: 16 -> 1168 in 2 tensors.
    let (weight_elems, bias_elems) = (8 * 16 * 3 * 3, 16);
    assert_eq!(conv.num_parameters(), weight_elems + bias_elems);
    assert_eq!(conv.parameters().len(), 2);
}
}