use std::sync::Arc;
use ferrotorch_core::autograd::autocast_ops::autocast_guard;
use ferrotorch_core::autograd::no_grad::is_grad_enabled;
use ferrotorch_core::ops::linalg::{mm, transpose};
use ferrotorch_core::storage::TensorStorage;
use ferrotorch_core::tensor::{GradFn, Tensor};
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
use crate::init::{NonLinearity, kaiming_uniform, zeros as zeros_init};
use crate::module::Module;
use crate::parameter::Parameter;
/// Unfolds a row-major `[batch, channels, height, width]` buffer into column
/// form so that a convolution can be evaluated as a matrix product.
///
/// For every sample the result holds a `col_rows x col_cols` row-major matrix,
/// where `col_rows = channels * kernel_h * kernel_w` (one row per kernel tap)
/// and `col_cols = h_out * w_out` (one column per output position). Taps that
/// land in the zero padding contribute `0`.
///
/// Returns `(cols, col_rows, col_cols)`; `cols.len()` is
/// `batch * col_rows * col_cols`.
///
/// # Panics
///
/// Panics if a stride is zero, if `input` is shorter than the declared shape
/// requires, or (with overflow checks enabled) if the padded extent is
/// smaller than the kernel. Callers are expected to validate these first.
// The argument list mirrors the full geometry of the convolution; bundling it
// into a struct would only move the noise to the call sites.
#[allow(clippy::too_many_arguments)]
fn im2col<T: Float>(
    input: &[T],
    batch: usize,
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> (Vec<T>, usize, usize) {
    let h_out = (height + 2 * pad_h - kernel_h) / stride_h + 1;
    let w_out = (width + 2 * pad_w - kernel_w) / stride_w + 1;
    let col_rows = channels * kernel_h * kernel_w;
    let col_cols = h_out * w_out;

    // Pre-filled with zeros, so padded taps simply need to be skipped below.
    let mut cols = vec![<T as num_traits::Zero>::zero(); batch * col_rows * col_cols];
    let plane_len = height * width;

    for b in 0..batch {
        for c in 0..channels {
            let src_plane = (b * channels + c) * plane_len;
            for kh in 0..kernel_h {
                for kw in 0..kernel_w {
                    let row = (c * kernel_h + kh) * kernel_w + kw;
                    let dst_row = (b * col_rows + row) * col_cols;
                    for oh in 0..h_out {
                        // Source row in unpadded coordinates, or skip if it is padding.
                        let src_h = match (oh * stride_h + kh).checked_sub(pad_h) {
                            Some(r) if r < height => r,
                            _ => continue,
                        };
                        for ow in 0..w_out {
                            let src_w = match (ow * stride_w + kw).checked_sub(pad_w) {
                                Some(x) if x < width => x,
                                _ => continue,
                            };
                            cols[dst_row + oh * w_out + ow] =
                                input[src_plane + src_h * width + src_w];
                        }
                    }
                }
            }
        }
    }
    (cols, col_rows, col_cols)
}
/// Folds a column buffer (laid out as produced by `im2col`) back into a
/// row-major `[batch, channels, height, width]` buffer.
///
/// Every column entry is *added* to the input position it was gathered from,
/// so positions covered by several kernel windows accumulate the sum of all
/// their contributions. Entries that correspond to zero padding are dropped.
///
/// `h_out` / `w_out` are the spatial output sizes of the forward convolution;
/// `cols` must hold `batch * (channels * kernel_h * kernel_w) * (h_out * w_out)`
/// values.
///
/// # Panics
///
/// Panics if `cols` is shorter than the geometry requires.
// The argument list mirrors the full geometry of the convolution.
#[allow(clippy::too_many_arguments)]
fn col2im<T: Float>(
    cols: &[T],
    batch: usize,
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
    h_out: usize,
    w_out: usize,
) -> Vec<T> {
    let plane_len = height * width;
    let mut output = vec![<T as num_traits::Zero>::zero(); batch * channels * plane_len];
    let col_rows = channels * kernel_h * kernel_w;
    let col_cols = h_out * w_out;

    for b in 0..batch {
        for c in 0..channels {
            let dst_plane = (b * channels + c) * plane_len;
            for kh in 0..kernel_h {
                for kw in 0..kernel_w {
                    let row = (c * kernel_h + kh) * kernel_w + kw;
                    let src_row = (b * col_rows + row) * col_cols;
                    for oh in 0..h_out {
                        // Target row in unpadded coordinates; padding rows are discarded.
                        let dst_h = match (oh * stride_h + kh).checked_sub(pad_h) {
                            Some(r) if r < height => r,
                            _ => continue,
                        };
                        for ow in 0..w_out {
                            let dst_w = match (ow * stride_w + kw).checked_sub(pad_w) {
                                Some(x) if x < width => x,
                                _ => continue,
                            };
                            output[dst_plane + dst_h * width + dst_w] +=
                                cols[src_row + oh * w_out + ow];
                        }
                    }
                }
            }
        }
    }
    output
}
/// 2-D convolution layer over `[B, C, H, W]` inputs.
///
/// Instances are built with `Conv2d::new`, which validates the geometry and
/// initialises the parameters.
#[derive(Debug)]
pub struct Conv2d<T: Float> {
/// Learnable kernel, allocated by `new` with shape
/// `[out_channels, in_channels, kernel_h, kernel_w]`.
weight: Parameter<T>,
/// Optional learnable bias of shape `[out_channels]`; `None` when the layer
/// was constructed with `bias = false`.
bias: Option<Parameter<T>>,
/// Number of channels expected in the input (checked in `forward`).
in_channels: usize,
/// Number of channels produced in the output.
out_channels: usize,
/// Kernel extent as `(height, width)`.
kernel_size: (usize, usize),
/// Step between neighbouring kernel windows as `(vertical, horizontal)`.
stride: (usize, usize),
/// Zero padding added to both sides of each spatial axis, as `(height, width)`.
padding: (usize, usize),
/// Training-mode flag toggled by `train` / `eval`; `forward` does not read it.
training: bool,
}
impl<T: Float> Conv2d<T> {
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
bias: bool,
) -> FerrotorchResult<Self> {
if in_channels == 0 || out_channels == 0 {
return Err(FerrotorchError::InvalidArgument {
message: "in_channels and out_channels must be > 0".into(),
});
}
if kernel_size.0 == 0 || kernel_size.1 == 0 {
return Err(FerrotorchError::InvalidArgument {
message: "kernel_size must be > 0 in both dimensions".into(),
});
}
if stride.0 == 0 || stride.1 == 0 {
return Err(FerrotorchError::InvalidArgument {
message: "stride must be > 0 in both dimensions".into(),
});
}
let (kh, kw) = kernel_size;
let mut weight = Parameter::zeros(&[out_channels, in_channels, kh, kw])?;
kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
let bias_param = if bias {
let mut b = Parameter::zeros(&[out_channels])?;
zeros_init(&mut b)?;
Some(b)
} else {
None
};
Ok(Self {
weight,
bias: bias_param,
in_channels,
out_channels,
kernel_size,
stride,
padding,
training: true,
})
}
pub fn num_parameters(&self) -> usize {
let w = self.out_channels * self.in_channels * self.kernel_size.0 * self.kernel_size.1;
let b = if self.bias.is_some() {
self.out_channels
} else {
0
};
w + b
}
}
impl<T: Float> Module<T> for Conv2d<T> {
    /// Applies the convolution to a `[B, C, H, W]` input and returns a
    /// `[B, out_channels, h_out, w_out]` tensor on the input's device, where
    /// `h_out = (H + 2 * pad_h - kh) / stride_h + 1` (and likewise for width).
    ///
    /// Two execution paths exist:
    ///
    /// * when `T` is 4 bytes wide, the input is a CUDA tensor and a GPU
    ///   backend is registered, the backend's `conv2d_f32` kernel is used;
    /// * otherwise the input is unfolded with `im2col` and each sample is
    ///   computed as `weight_2d x cols` with `mm`.
    ///
    /// If gradients are enabled and any of input / weight / bias requires a
    /// gradient, the result is attached to a `Conv2dBackward` node. That node
    /// always stores the host-side `im2col` buffer, so the GPU path performs
    /// an additional host unfold in that case.
    ///
    /// # Errors
    ///
    /// * `InvalidArgument` if the input is not 4-D or the padded input is
    ///   smaller than the kernel.
    /// * `ShapeMismatch` if the channel count differs from `in_channels`.
    /// * On the GPU path, any failure to obtain a GPU handle for the input,
    ///   the weight **or the bias**, plus whatever the backend reports.
    /// * Any error from tensor construction, `mm`, or the device transfer.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        // Held for the whole call; released when it goes out of scope.
        let _autocast_cat = autocast_guard("conv2d");

        if input.ndim() != 4 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv2d expects 4-D input [B, C, H, W], got {:?}",
                    input.shape()
                ),
            });
        }
        let batch = input.shape()[0];
        let c_in = input.shape()[1];
        let h = input.shape()[2];
        let w = input.shape()[3];
        if c_in != self.in_channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "Conv2d: expected {} input channels, got {}",
                    self.in_channels, c_in
                ),
            });
        }

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;
        let h_padded = h + 2 * ph;
        let w_padded = w + 2 * pw;
        if h_padded < kh || w_padded < kw {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv2d: padded input ({h_padded}, {w_padded}) is smaller than kernel ({kh}, {kw})"
                ),
            });
        }
        let h_out = (h_padded - kh) / sh + 1;
        let w_out = (w_padded - kw) / sw + 1;
        let input_device = input.device();

        let needs_grad = is_grad_enabled()
            && (input.requires_grad()
                || self.weight.requires_grad()
                || self.bias.as_ref().is_some_and(|b| b.requires_grad()));

        // Builds the autograd node shared by both execution paths.
        let make_grad_fn = |cols: Vec<T>, col_rows: usize, col_cols: usize| {
            Arc::new(Conv2dBackward {
                input: input.clone(),
                weight: self.weight.tensor().clone(),
                bias: self.bias.as_ref().map(|b| b.tensor().clone()),
                in_channels: self.in_channels,
                out_channels: self.out_channels,
                kernel_size: self.kernel_size,
                stride: self.stride,
                padding: self.padding,
                cols,
                col_rows,
                col_cols,
                h_out,
                w_out,
            })
        };

        // The backend kernel is f32-only; element width is used as the type test.
        let is_f32 = std::mem::size_of::<T>() == 4;
        if is_f32 && input.is_cuda() {
            if let Some(backend) = ferrotorch_core::gpu_dispatch::gpu_backend() {
                // Propagate a failed handle lookup instead of discarding it:
                // dropping the bias here would silently compute a bias-free
                // convolution. The weight lookup below already fails loudly.
                let bias_handle = self
                    .bias
                    .as_ref()
                    .map(|b| b.tensor().gpu_handle())
                    .transpose()?;
                let (out_handle, out_shape) = backend.conv2d_f32(
                    input.gpu_handle()?,
                    self.weight.tensor().gpu_handle()?,
                    bias_handle,
                    [batch, c_in, h, w],
                    [self.out_channels, self.in_channels, kh, kw],
                    self.stride,
                    self.padding,
                )?;
                let result = Tensor::from_storage(
                    TensorStorage::gpu(out_handle),
                    out_shape.to_vec(),
                    false,
                )?;
                if needs_grad {
                    // The backward pass works from the host-side column buffer.
                    let input_data = input.data_vec()?;
                    let (cols, col_rows, col_cols) =
                        im2col(&input_data, batch, c_in, h, w, kh, kw, sh, sw, ph, pw);
                    return Tensor::from_operation(
                        result.into_storage_and_shape()?.0,
                        out_shape.to_vec(),
                        make_grad_fn(cols, col_rows, col_cols),
                    );
                }
                return Ok(result);
            }
        }

        // Host path: unfold, then one matrix product per sample.
        let input_data = input.data_vec()?;
        let (cols, col_rows, col_cols) =
            im2col(&input_data, batch, c_in, h, w, kh, kw, sh, sw, ph, pw);
        let weight_2d = Tensor::from_storage(
            TensorStorage::cpu(self.weight.data_vec()?),
            vec![self.out_channels, col_rows],
            false,
        )?;

        let spatial = h_out * w_out;
        let sample_len = self.out_channels * spatial;
        let cols_len = col_rows * col_cols;
        let mut output = vec![<T as num_traits::Zero>::zero(); batch * sample_len];
        for b in 0..batch {
            let cols_b = Tensor::from_storage(
                TensorStorage::cpu(cols[b * cols_len..(b + 1) * cols_len].to_vec()),
                vec![col_rows, col_cols],
                false,
            )?;
            let out_b = mm(&weight_2d, &cols_b)?;
            output[b * sample_len..(b + 1) * sample_len].copy_from_slice(out_b.data()?);
        }

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data_vec()?;
            for b in 0..batch {
                for c in 0..self.out_channels {
                    let bval = bias_data[c];
                    let start = b * sample_len + c * spatial;
                    for v in &mut output[start..start + spatial] {
                        *v += bval;
                    }
                }
            }
        }

        let out_shape = vec![batch, self.out_channels, h_out, w_out];
        let result = if needs_grad {
            // Hand the buffer straight to the graph node; no intermediate copy.
            Tensor::from_operation(
                TensorStorage::cpu(output),
                out_shape,
                make_grad_fn(cols, col_rows, col_cols),
            )?
        } else {
            Tensor::from_storage(TensorStorage::cpu(output), out_shape, false)?
        };
        result.to(input_device)
    }

    /// Borrows the learnable parameters: the weight, then the bias if present.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params = vec![&self.weight];
        if let Some(ref b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Mutably borrows the learnable parameters, in the same order as `parameters`.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params = vec![&mut self.weight];
        if let Some(ref mut b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Parameters paired with their names: `"weight"` and, if present, `"bias"`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut params = vec![("weight".to_string(), &self.weight)];
        if let Some(ref b) = self.bias {
            params.push(("bias".to_string(), b));
        }
        params
    }

    /// Switches the layer to training mode.
    fn train(&mut self) {
        self.training = true;
    }

    /// Switches the layer to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Reports whether the layer is in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Autograd node recorded by `Conv2d::forward`; holds everything the backward
/// pass needs to compute gradients for the input, the weight and the bias.
#[derive(Debug)]
struct Conv2dBackward<T: Float> {
/// The tensor that was convolved (a clone of the forward input).
input: Tensor<T>,
/// The layer's weight tensor at the time of the forward call.
weight: Tensor<T>,
/// The layer's bias tensor, if the layer has one.
bias: Option<Tensor<T>>,
/// Input channel count of the layer.
in_channels: usize,
/// Output channel count of the layer.
out_channels: usize,
/// Kernel extent as `(height, width)`.
kernel_size: (usize, usize),
/// Stride as `(vertical, horizontal)`.
stride: (usize, usize),
/// Zero padding as `(height, width)`.
padding: (usize, usize),
/// Host-side `im2col` buffer of the forward input, one
/// `col_rows x col_cols` matrix per sample.
cols: Vec<T>,
/// Rows of each per-sample column matrix.
col_rows: usize,
/// Columns of each per-sample column matrix.
col_cols: usize,
/// Spatial output height of the forward pass.
h_out: usize,
/// Spatial output width of the forward pass.
w_out: usize,
}
impl<T: Float> GradFn<T> for Conv2dBackward<T> {
    /// Computes the gradients of a 2-D convolution.
    ///
    /// The returned vector is ordered like `inputs()`: input gradient, weight
    /// gradient and, only if the layer has a bias, bias gradient. An entry is
    /// `None` when the corresponding tensor does not require a gradient.
    ///
    /// * weight: sum over samples of `grad_out_b x cols_b^T`;
    /// * bias: sum of `grad_output` over batch and spatial positions;
    /// * input: `weight_2d^T x grad_out_b`, folded back with `col2im`.
    ///
    /// All gradients are built as host (`TensorStorage::cpu`) tensors.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let go_data = grad_output.data_vec()?;
        let batch = self.input.shape()[0];
        let h = self.input.shape()[2];
        let w = self.input.shape()[3];
        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        let zero = <T as num_traits::Zero>::zero();
        let spatial = self.h_out * self.w_out;
        let go_len = self.out_channels * spatial;
        let cols_len = self.col_rows * self.col_cols;

        // Wraps sample `b` of the incoming gradient as an
        // `[out_channels, h_out * w_out]` matrix.
        let grad_out_matrix = |b: usize| {
            let start = b * go_len;
            Tensor::from_storage(
                TensorStorage::cpu(go_data[start..start + go_len].to_vec()),
                vec![self.out_channels, spatial],
                false,
            )
        };

        let grad_weight = if self.weight.requires_grad() {
            let weight_numel = self.out_channels * self.col_rows;
            let mut gw_accum = vec![zero; weight_numel];
            for b in 0..batch {
                let go_b = grad_out_matrix(b)?;
                let start = b * cols_len;
                let cols_b = Tensor::from_storage(
                    TensorStorage::cpu(self.cols[start..start + cols_len].to_vec()),
                    vec![self.col_rows, self.col_cols],
                    false,
                )?;
                let gw_b = mm(&go_b, &transpose(&cols_b)?)?;
                let gw_data = gw_b.data()?;
                for (acc, &g) in gw_accum.iter_mut().zip(&gw_data[..weight_numel]) {
                    *acc += g;
                }
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gw_accum),
                vec![self.out_channels, self.in_channels, kh, kw],
                false,
            )?)
        } else {
            None
        };

        let grad_bias = match &self.bias {
            Some(b) if b.requires_grad() => {
                let mut gb = vec![zero; self.out_channels];
                for sample in 0..batch {
                    for c in 0..self.out_channels {
                        let start = sample * go_len + c * spatial;
                        for &g in &go_data[start..start + spatial] {
                            gb[c] += g;
                        }
                    }
                }
                Some(Tensor::from_storage(
                    TensorStorage::cpu(gb),
                    vec![self.out_channels],
                    false,
                )?)
            }
            _ => None,
        };

        let grad_input = if self.input.requires_grad() {
            let weight_2d = Tensor::from_storage(
                TensorStorage::cpu(self.weight.data_vec()?),
                vec![self.out_channels, self.col_rows],
                false,
            )?;
            let weight_2d_t = transpose(&weight_2d)?;
            let mut grad_cols = vec![zero; batch * cols_len];
            for b in 0..batch {
                let go_b = grad_out_matrix(b)?;
                let gc_b = mm(&weight_2d_t, &go_b)?;
                let start = b * cols_len;
                grad_cols[start..start + cols_len].copy_from_slice(gc_b.data()?);
            }
            let folded = col2im(
                &grad_cols,
                batch,
                self.in_channels,
                h,
                w,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
                self.h_out,
                self.w_out,
            );
            Some(Tensor::from_storage(
                TensorStorage::cpu(folded),
                self.input.shape().to_vec(),
                false,
            )?)
        } else {
            None
        };

        let mut grads = vec![grad_input, grad_weight];
        if self.bias.is_some() {
            grads.push(grad_bias);
        }
        Ok(grads)
    }

    /// The tensors this node differentiates with respect to: input, weight
    /// and, if present, bias.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let mut tensors = vec![&self.input, &self.weight];
        if let Some(ref b) = self.bias {
            tensors.push(b);
        }
        tensors
    }

    /// Name of this node.
    fn name(&self) -> &'static str {
        "Conv2dBackward"
    }
}
/// 1-D convolution layer over `[B, C, L]` inputs.
///
/// Instances are built with `Conv1d::new`, which validates the geometry and
/// initialises the parameters.
#[derive(Debug)]
pub struct Conv1d<T: Float> {
/// Learnable kernel, allocated by `new` with shape
/// `[out_channels, in_channels, kernel_size]`.
weight: Parameter<T>,
/// Optional learnable bias of shape `[out_channels]`; `None` when the layer
/// was constructed with `bias = false`.
bias: Option<Parameter<T>>,
/// Number of channels expected in the input (checked in `forward`).
in_channels: usize,
/// Number of channels produced in the output.
out_channels: usize,
/// Kernel length.
kernel_size: usize,
/// Step between neighbouring kernel windows.
stride: usize,
/// Zero padding added to both ends of the length axis.
padding: usize,
/// Training-mode flag toggled by `train` / `eval`; `forward` does not read it.
training: bool,
}
impl<T: Float> Conv1d<T> {
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
stride: usize,
padding: usize,
bias: bool,
) -> FerrotorchResult<Self> {
if in_channels == 0 || out_channels == 0 {
return Err(FerrotorchError::InvalidArgument {
message: "in_channels and out_channels must be > 0".into(),
});
}
if kernel_size == 0 {
return Err(FerrotorchError::InvalidArgument {
message: "kernel_size must be > 0".into(),
});
}
if stride == 0 {
return Err(FerrotorchError::InvalidArgument {
message: "stride must be > 0".into(),
});
}
let mut weight = Parameter::zeros(&[out_channels, in_channels, kernel_size])?;
kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
let bias_param = if bias {
let mut b = Parameter::zeros(&[out_channels])?;
zeros_init(&mut b)?;
Some(b)
} else {
None
};
Ok(Self {
weight,
bias: bias_param,
in_channels,
out_channels,
kernel_size,
stride,
padding,
training: true,
})
}
pub fn num_parameters(&self) -> usize {
let w = self.out_channels * self.in_channels * self.kernel_size;
let b = if self.bias.is_some() {
self.out_channels
} else {
0
};
w + b
}
}
impl<T: Float> Module<T> for Conv1d<T> {
    /// Applies the convolution to a `[B, C, L]` input and returns a
    /// `[B, out_channels, l_out]` tensor on the input's device, where
    /// `l_out = (L + 2 * padding - kernel_size) / stride + 1`.
    ///
    /// The sequence is treated as an image of height 1, unfolded with the 2-D
    /// `im2col`, and each sample is computed as `weight_2d x cols` with `mm`.
    /// If gradients are enabled and any of input / weight / bias requires a
    /// gradient, the result is attached to a `Conv1dBackward` node.
    ///
    /// # Errors
    ///
    /// * `InvalidArgument` if the input is not 3-D or the padded length is
    ///   smaller than the kernel.
    /// * `ShapeMismatch` if the channel count differs from `in_channels`.
    /// * Any error from tensor construction, `mm`, or the device transfer.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        // Held for the whole call; released when it goes out of scope.
        let _autocast_cat = autocast_guard("conv1d");

        if input.ndim() != 3 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv1d expects 3-D input [B, C, L], got {:?}",
                    input.shape()
                ),
            });
        }
        let batch = input.shape()[0];
        let c_in = input.shape()[1];
        let length = input.shape()[2];
        if c_in != self.in_channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "Conv1d: expected {} input channels, got {}",
                    self.in_channels, c_in
                ),
            });
        }

        let (k, s, p) = (self.kernel_size, self.stride, self.padding);
        let l_padded = length + 2 * p;
        if l_padded < k {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv1d: padded input length ({l_padded}) is smaller than kernel ({k})"
                ),
            });
        }
        let l_out = (l_padded - k) / s + 1;
        let input_device = input.device();

        // Height-1 image: the vertical kernel, stride and padding are trivial.
        let input_data = input.data_vec()?;
        let (cols, col_rows, col_cols) =
            im2col(&input_data, batch, c_in, 1, length, 1, k, 1, s, 0, p);
        let weight_2d = Tensor::from_storage(
            TensorStorage::cpu(self.weight.data_vec()?),
            vec![self.out_channels, col_rows],
            false,
        )?;

        let sample_len = self.out_channels * l_out;
        let cols_len = col_rows * col_cols;
        let mut output = vec![<T as num_traits::Zero>::zero(); batch * sample_len];
        for b in 0..batch {
            let cols_b = Tensor::from_storage(
                TensorStorage::cpu(cols[b * cols_len..(b + 1) * cols_len].to_vec()),
                vec![col_rows, col_cols],
                false,
            )?;
            let out_b = mm(&weight_2d, &cols_b)?;
            output[b * sample_len..(b + 1) * sample_len].copy_from_slice(out_b.data()?);
        }

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data_vec()?;
            for b in 0..batch {
                for c in 0..self.out_channels {
                    let bval = bias_data[c];
                    let start = b * sample_len + c * l_out;
                    for v in &mut output[start..start + l_out] {
                        *v += bval;
                    }
                }
            }
        }

        let result = Tensor::from_storage(
            TensorStorage::cpu(output),
            vec![batch, self.out_channels, l_out],
            false,
        )?;

        let needs_grad = is_grad_enabled()
            && (input.requires_grad()
                || self.weight.requires_grad()
                || self.bias.as_ref().is_some_and(|b| b.requires_grad()));
        if !needs_grad {
            return result.to(input_device);
        }

        let grad_fn = Arc::new(Conv1dBackward {
            input: input.clone(),
            weight: self.weight.tensor().clone(),
            bias: self.bias.as_ref().map(|b| b.tensor().clone()),
            in_channels: self.in_channels,
            out_channels: self.out_channels,
            kernel_size: self.kernel_size,
            stride: self.stride,
            padding: self.padding,
            cols,
            col_rows,
            col_cols,
            l_out,
        });
        Tensor::from_operation(
            TensorStorage::cpu(result.data()?.to_vec()),
            result.shape().to_vec(),
            grad_fn,
        )?
        .to(input_device)
    }

    /// Borrows the learnable parameters: the weight, then the bias if present.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params = vec![&self.weight];
        if let Some(ref b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Mutably borrows the learnable parameters, in the same order as `parameters`.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params = vec![&mut self.weight];
        if let Some(ref mut b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Parameters paired with their names: `"weight"` and, if present, `"bias"`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut params = vec![("weight".to_string(), &self.weight)];
        if let Some(ref b) = self.bias {
            params.push(("bias".to_string(), b));
        }
        params
    }

    /// Switches the layer to training mode.
    fn train(&mut self) {
        self.training = true;
    }

    /// Switches the layer to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Reports whether the layer is in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Autograd node recorded by `Conv1d::forward`; holds everything the backward
/// pass needs to compute gradients for the input, the weight and the bias.
#[derive(Debug)]
struct Conv1dBackward<T: Float> {
/// The tensor that was convolved (a clone of the forward input).
input: Tensor<T>,
/// The layer's weight tensor at the time of the forward call.
weight: Tensor<T>,
/// The layer's bias tensor, if the layer has one.
bias: Option<Tensor<T>>,
/// Input channel count of the layer.
in_channels: usize,
/// Output channel count of the layer.
out_channels: usize,
/// Kernel length.
kernel_size: usize,
/// Step between neighbouring kernel windows.
stride: usize,
/// Zero padding on both ends of the length axis.
padding: usize,
/// Host-side `im2col` buffer of the forward input, one
/// `col_rows x col_cols` matrix per sample.
cols: Vec<T>,
/// Rows of each per-sample column matrix.
col_rows: usize,
/// Columns of each per-sample column matrix.
col_cols: usize,
/// Output length of the forward pass.
l_out: usize,
}
impl<T: Float> GradFn<T> for Conv1dBackward<T> {
    /// Computes the gradients of a 1-D convolution.
    ///
    /// The returned vector is ordered like `inputs()`: input gradient, weight
    /// gradient and, only if the layer has a bias, bias gradient. An entry is
    /// `None` when the corresponding tensor does not require a gradient.
    ///
    /// * weight: sum over samples of `grad_out_b x cols_b^T`;
    /// * bias: sum of `grad_output` over batch and positions;
    /// * input: `weight_2d^T x grad_out_b`, folded back with `col2im` using a
    ///   height-1 geometry.
    ///
    /// All gradients are built as host (`TensorStorage::cpu`) tensors.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let go_data = grad_output.data_vec()?;
        let batch = self.input.shape()[0];
        let length = self.input.shape()[2];
        let (k, s, p) = (self.kernel_size, self.stride, self.padding);

        let zero = <T as num_traits::Zero>::zero();
        let go_len = self.out_channels * self.l_out;
        let cols_len = self.col_rows * self.col_cols;

        // Wraps sample `b` of the incoming gradient as an
        // `[out_channels, l_out]` matrix.
        let grad_out_matrix = |b: usize| {
            let start = b * go_len;
            Tensor::from_storage(
                TensorStorage::cpu(go_data[start..start + go_len].to_vec()),
                vec![self.out_channels, self.l_out],
                false,
            )
        };

        let grad_weight = if self.weight.requires_grad() {
            let weight_numel = self.out_channels * self.col_rows;
            let mut gw_accum = vec![zero; weight_numel];
            for b in 0..batch {
                let go_b = grad_out_matrix(b)?;
                let start = b * cols_len;
                let cols_b = Tensor::from_storage(
                    TensorStorage::cpu(self.cols[start..start + cols_len].to_vec()),
                    vec![self.col_rows, self.col_cols],
                    false,
                )?;
                let gw_b = mm(&go_b, &transpose(&cols_b)?)?;
                let gw_data = gw_b.data()?;
                for (acc, &g) in gw_accum.iter_mut().zip(&gw_data[..weight_numel]) {
                    *acc += g;
                }
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gw_accum),
                vec![self.out_channels, self.in_channels, k],
                false,
            )?)
        } else {
            None
        };

        let grad_bias = match &self.bias {
            Some(b) if b.requires_grad() => {
                let mut gb = vec![zero; self.out_channels];
                for sample in 0..batch {
                    for c in 0..self.out_channels {
                        let start = sample * go_len + c * self.l_out;
                        for &g in &go_data[start..start + self.l_out] {
                            gb[c] += g;
                        }
                    }
                }
                Some(Tensor::from_storage(
                    TensorStorage::cpu(gb),
                    vec![self.out_channels],
                    false,
                )?)
            }
            _ => None,
        };

        let grad_input = if self.input.requires_grad() {
            let weight_2d = Tensor::from_storage(
                TensorStorage::cpu(self.weight.data_vec()?),
                vec![self.out_channels, self.col_rows],
                false,
            )?;
            let weight_2d_t = transpose(&weight_2d)?;
            let mut grad_cols = vec![zero; batch * cols_len];
            for b in 0..batch {
                let go_b = grad_out_matrix(b)?;
                let gc_b = mm(&weight_2d_t, &go_b)?;
                let start = b * cols_len;
                grad_cols[start..start + cols_len].copy_from_slice(gc_b.data()?);
            }
            // Same height-1 geometry that the forward pass used for `im2col`.
            let folded = col2im(
                &grad_cols,
                batch,
                self.in_channels,
                1,
                length,
                1,
                k,
                1,
                s,
                0,
                p,
                1,
                self.l_out,
            );
            Some(Tensor::from_storage(
                TensorStorage::cpu(folded),
                self.input.shape().to_vec(),
                false,
            )?)
        } else {
            None
        };

        let mut grads = vec![grad_input, grad_weight];
        if self.bias.is_some() {
            grads.push(grad_bias);
        }
        Ok(grads)
    }

    /// The tensors this node differentiates with respect to: input, weight
    /// and, if present, bias.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let mut tensors = vec![&self.input, &self.weight];
        if let Some(ref b) = self.bias {
            tensors.push(b);
        }
        tensors
    }

    /// Name of this node.
    fn name(&self) -> &'static str {
        "Conv1dBackward"
    }
}
/// 2-D transposed convolution layer over `[B, C, H, W]` inputs.
///
/// Instances are built with `ConvTranspose2d::new`, which validates the
/// geometry and initialises the parameters.
#[derive(Debug)]
pub struct ConvTranspose2d<T: Float> {
/// Learnable kernel, allocated by `new` with shape
/// `[in_channels, out_channels, kernel_h, kernel_w]` (input channels first).
weight: Parameter<T>,
/// Optional learnable bias of shape `[out_channels]`; `None` when the layer
/// was constructed with `bias = false`.
bias: Option<Parameter<T>>,
/// Number of channels expected in the input (checked in `forward`).
in_channels: usize,
/// Number of channels produced in the output.
out_channels: usize,
/// Kernel extent as `(height, width)`.
kernel_size: (usize, usize),
/// Upsampling stride as `(vertical, horizontal)`.
stride: (usize, usize),
/// Padding as `(height, width)`; `forward` subtracts it from `kernel - 1`
/// to obtain the padding of the equivalent stride-1 convolution.
padding: (usize, usize),
/// Extra output size as `(height, width)`; `new` requires each component
/// to be strictly smaller than the matching stride.
output_padding: (usize, usize),
/// Training-mode flag toggled by `train` / `eval`; `forward` does not read it.
training: bool,
}
impl<T: Float> ConvTranspose2d<T> {
    /// Creates a 2-D transposed convolution layer.
    ///
    /// The weight has shape `[in_channels, out_channels, kh, kw]` and is
    /// initialised with `kaiming_uniform` (ReLU gain). When `bias` is `true`
    /// a zero-initialised bias of shape `[out_channels]` is attached. The new
    /// layer starts in training mode.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::InvalidArgument` when
    ///
    /// * a channel count, a kernel dimension or a stride is zero;
    /// * `output_padding` is not strictly smaller than `stride`;
    /// * `padding` is not strictly smaller than `kernel_size`. The forward
    ///   pass pads its internal convolution by `kernel - 1 - padding`, which
    ///   has no valid value otherwise and would underflow.
    ///
    /// Errors raised while allocating or initialising the parameters are
    /// forwarded unchanged.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
        output_padding: (usize, usize),
        bias: bool,
    ) -> FerrotorchResult<Self> {
        if in_channels == 0 || out_channels == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "in_channels and out_channels must be > 0".into(),
            });
        }
        if kernel_size.0 == 0 || kernel_size.1 == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "kernel_size must be > 0 in both dimensions".into(),
            });
        }
        if stride.0 == 0 || stride.1 == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "stride must be > 0 in both dimensions".into(),
            });
        }
        if output_padding.0 >= stride.0 || output_padding.1 >= stride.1 {
            return Err(FerrotorchError::InvalidArgument {
                message: "output_padding must be strictly less than stride".into(),
            });
        }
        // Fail fast: `forward` computes `kernel - 1 - padding` in `usize`.
        if padding.0 >= kernel_size.0 || padding.1 >= kernel_size.1 {
            return Err(FerrotorchError::InvalidArgument {
                message: "padding must be strictly less than kernel_size".into(),
            });
        }

        let (kh, kw) = kernel_size;
        let mut weight = Parameter::zeros(&[in_channels, out_channels, kh, kw])?;
        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
        let bias_param = if bias {
            let mut b = Parameter::zeros(&[out_channels])?;
            zeros_init(&mut b)?;
            Some(b)
        } else {
            None
        };

        Ok(Self {
            weight,
            bias: bias_param,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            training: true,
        })
    }

    /// Total number of scalar parameters: every weight element plus, when a
    /// bias is present, one value per output channel.
    pub fn num_parameters(&self) -> usize {
        let (kh, kw) = self.kernel_size;
        let weight_count = self.in_channels * self.out_channels * kh * kw;
        let bias_count = self.bias.as_ref().map_or(0, |_| self.out_channels);
        weight_count + bias_count
    }
}
/// Upsamples a row-major `[batch, channels, h, w]` buffer by inserting
/// `stride - 1` zeros between neighbouring elements along each spatial axis.
///
/// Returns `(buffer, h_up, w_up)` where `h_up = (h - 1) * stride_h + 1` and
/// `w_up = (w - 1) * stride_w + 1`. An empty spatial axis (`h == 0` or
/// `w == 0`) stays empty: the corresponding upsampled size is `0` and the
/// returned buffer is empty, instead of underflowing on `h - 1`.
///
/// # Panics
///
/// Panics if `input` is shorter than `batch * channels * h * w`.
fn stride_insert_zeros<T: Float>(
    input: &[T],
    batch: usize,
    channels: usize,
    h: usize,
    w: usize,
    stride_h: usize,
    stride_w: usize,
) -> (Vec<T>, usize, usize) {
    // `n` samples spaced `stride` apart span `(n - 1) * stride + 1` cells;
    // zero samples span nothing.
    let upsampled_len = |n: usize, stride: usize| match n {
        0 => 0,
        _ => (n - 1) * stride + 1,
    };
    let h_up = upsampled_len(h, stride_h);
    let w_up = upsampled_len(w, stride_w);

    let mut out = vec![<T as num_traits::Zero>::zero(); batch * channels * h_up * w_up];
    for plane in 0..batch * channels {
        for ih in 0..h {
            let src_row = (plane * h + ih) * w;
            let dst_row = (plane * h_up + ih * stride_h) * w_up;
            for iw in 0..w {
                out[dst_row + iw * stride_w] = input[src_row + iw];
            }
        }
    }
    (out, h_up, w_up)
}
fn flip_kernel<T: Float>(kernel: &[T], c_in: usize, c_out: usize, kh: usize, kw: usize) -> Vec<T> {
let zero = <T as num_traits::Zero>::zero();
let mut flipped = vec![zero; c_out * c_in * kh * kw];
for ci in 0..c_in {
for co in 0..c_out {
for h in 0..kh {
for w in 0..kw {
let src = ci * c_out * kh * kw + co * kh * kw + h * kw + w;
let dst = co * c_in * kh * kw + ci * kh * kw + (kh - 1 - h) * kw + (kw - 1 - w);
flipped[dst] = kernel[src];
}
}
}
}
flipped
}
impl<T: Float> Module<T> for ConvTranspose2d<T> {
    /// Applies the transposed convolution to a `[B, C, H, W]` input and
    /// returns a `[B, out_channels, h_out, w_out]` tensor on the input's
    /// device, where
    /// `h_out = (H - 1) * stride_h - 2 * pad_h + kh + output_pad_h`
    /// (and likewise for width).
    ///
    /// The operation is evaluated as an ordinary stride-1 convolution:
    ///
    /// 1. the input is upsampled by inserting zeros (`stride_insert_zeros`);
    /// 2. the kernel is channel-swapped and rotated (`flip_kernel`);
    /// 3. the upsampled input is unfolded with `im2col` using a padding of
    ///    `kernel - 1 - padding` and multiplied by the kernel with `mm`.
    ///
    /// `output_padding` only enlarges the output on the bottom/right side; the
    /// additional rows and columns are real convolution results, not zeros.
    /// This is achieved by appending `output_padding` zero rows/columns to the
    /// upsampled input before unfolding, which extends the trailing padding of
    /// the stride-1 convolution by that amount.
    ///
    /// If gradients are enabled and any of input / weight / bias requires a
    /// gradient, the result is attached to a `ConvTranspose2dBackward` node.
    ///
    /// # Errors
    ///
    /// * `InvalidArgument` if the input is not 4-D, has an empty spatial
    ///   axis, the padding is not smaller than the kernel, or the resulting
    ///   output size would not be positive.
    /// * `ShapeMismatch` if the channel count differs from `in_channels`.
    /// * Any error from tensor construction, `mm`, or the device transfer.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        // Held for the whole call; released when it goes out of scope.
        let _autocast_cat = autocast_guard("conv_transpose2d");

        if input.ndim() != 4 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose2d expects 4-D input [B, C, H, W], got {:?}",
                    input.shape()
                ),
            });
        }
        let batch = input.shape()[0];
        let c_in = input.shape()[1];
        let h = input.shape()[2];
        let w = input.shape()[3];
        if c_in != self.in_channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "ConvTranspose2d: expected {} input channels, got {}",
                    self.in_channels, c_in
                ),
            });
        }

        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;
        let (oph, opw) = self.output_padding;

        // The upsampled extent is `(n - 1) * stride + 1`, undefined for `n == 0`.
        if h == 0 || w == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose2d: input spatial dimensions must be > 0, got ({h}, {w})"
                ),
            });
        }
        // Padding of the equivalent stride-1 convolution: `kernel - 1 - padding`.
        let (Some(internal_pad_h), Some(internal_pad_w)) =
            (kh.checked_sub(ph + 1), kw.checked_sub(pw + 1))
        else {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose2d: padding ({ph}, {pw}) must be smaller than kernel ({kh}, {kw})"
                ),
            });
        };

        let input_device = input.device();
        let input_data = input.data_vec()?;
        let (upsampled, h_up, w_up) = stride_insert_zeros(&input_data, batch, c_in, h, w, sh, sw);

        // Extend the bottom/right edge by `output_padding` zero rows/columns so
        // the convolution below yields the full `h_out x w_out` result.
        let h_ext = h_up + oph;
        let w_ext = w_up + opw;
        let zero = <T as num_traits::Zero>::zero();
        let upsampled = if oph == 0 && opw == 0 {
            upsampled
        } else {
            let mut extended = vec![zero; batch * c_in * h_ext * w_ext];
            for plane in 0..batch * c_in {
                for row in 0..h_up {
                    let src = (plane * h_up + row) * w_up;
                    let dst = (plane * h_ext + row) * w_ext;
                    extended[dst..dst + w_up].copy_from_slice(&upsampled[src..src + w_up]);
                }
            }
            extended
        };

        let padded_h = h_ext + 2 * internal_pad_h;
        let padded_w = w_ext + 2 * internal_pad_w;
        if padded_h < kh || padded_w < kw {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose2d: output size would not be positive for input ({h}, {w})"
                ),
            });
        }
        let h_out = padded_h - kh + 1;
        let w_out = padded_w - kw + 1;

        let weight_data = self.weight.data_vec()?;
        let flipped = flip_kernel(&weight_data, self.in_channels, self.out_channels, kh, kw);
        let (cols, col_rows, col_cols) = im2col(
            &upsampled,
            batch,
            c_in,
            h_ext,
            w_ext,
            kh,
            kw,
            1,
            1,
            internal_pad_h,
            internal_pad_w,
        );
        let flipped_2d = Tensor::from_storage(
            TensorStorage::cpu(flipped),
            vec![self.out_channels, col_rows],
            false,
        )?;

        let spatial = h_out * w_out;
        let sample_len = self.out_channels * spatial;
        let cols_len = col_rows * col_cols;
        let mut output = vec![zero; batch * sample_len];
        for b in 0..batch {
            let cols_b = Tensor::from_storage(
                TensorStorage::cpu(cols[b * cols_len..(b + 1) * cols_len].to_vec()),
                vec![col_rows, col_cols],
                false,
            )?;
            let out_b = mm(&flipped_2d, &cols_b)?;
            // `col_cols == h_out * w_out`, so the product already has the
            // final per-sample layout.
            output[b * sample_len..(b + 1) * sample_len].copy_from_slice(out_b.data()?);
        }

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data_vec()?;
            for b in 0..batch {
                for c in 0..self.out_channels {
                    let bval = bias_data[c];
                    let start = b * sample_len + c * spatial;
                    for v in &mut output[start..start + spatial] {
                        *v += bval;
                    }
                }
            }
        }

        let result = Tensor::from_storage(
            TensorStorage::cpu(output),
            vec![batch, self.out_channels, h_out, w_out],
            false,
        )?;
        if is_grad_enabled()
            && (input.requires_grad()
                || self.weight.requires_grad()
                || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
        {
            let grad_fn = Arc::new(ConvTranspose2dBackward {
                input: input.clone(),
                weight: self.weight.tensor().clone(),
                bias: self.bias.as_ref().map(|b| b.tensor().clone()),
                in_channels: self.in_channels,
                out_channels: self.out_channels,
                kernel_size: self.kernel_size,
                stride: self.stride,
                padding: self.padding,
                _output_padding: self.output_padding,
                h_out,
                w_out,
            });
            Tensor::from_operation(
                TensorStorage::cpu(result.data()?.to_vec()),
                result.shape().to_vec(),
                grad_fn,
            )?
            .to(input_device)
        } else {
            result.to(input_device)
        }
    }

    /// Borrows the learnable parameters: the weight, then the bias if present.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params = vec![&self.weight];
        if let Some(ref b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Mutably borrows the learnable parameters, in the same order as `parameters`.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params = vec![&mut self.weight];
        if let Some(ref mut b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Parameters paired with their names: `"weight"` and, if present, `"bias"`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut params = vec![("weight".to_string(), &self.weight)];
        if let Some(ref b) = self.bias {
            params.push(("bias".to_string(), b));
        }
        params
    }

    /// Switches the layer to training mode.
    fn train(&mut self) {
        self.training = true;
    }

    /// Switches the layer to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Reports whether the layer is in training mode.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Autograd node recorded by `ConvTranspose2d::forward`; holds everything the
/// backward pass needs to compute gradients for the input, the weight and
/// the bias.
#[derive(Debug)]
struct ConvTranspose2dBackward<T: Float> {
/// The tensor that was fed to the layer (a clone of the forward input).
input: Tensor<T>,
/// The layer's weight tensor at the time of the forward call.
weight: Tensor<T>,
/// The layer's bias tensor, if the layer has one.
bias: Option<Tensor<T>>,
/// Input channel count of the layer.
in_channels: usize,
/// Output channel count of the layer.
out_channels: usize,
/// Kernel extent as `(height, width)`.
kernel_size: (usize, usize),
/// Stride as `(vertical, horizontal)`.
stride: (usize, usize),
/// Padding as `(height, width)`.
padding: (usize, usize),
/// Output padding of the layer; stored but not read by `backward`, which
/// works from `h_out` / `w_out` instead.
_output_padding: (usize, usize),
/// Spatial output height of the forward pass.
h_out: usize,
/// Spatial output width of the forward pass.
w_out: usize,
}
impl<T: Float> GradFn<T> for ConvTranspose2dBackward<T> {
    /// Computes the gradients of a 2-D transposed convolution.
    ///
    /// The returned vector is ordered like `inputs()`: input gradient, weight
    /// gradient and, only if the layer has a bias, bias gradient. An entry is
    /// `None` when the corresponding tensor does not require a gradient.
    ///
    /// * input: an ordinary convolution of `grad_output` with the (unflipped)
    ///   weight, using the layer's stride and padding;
    /// * weight: for every kernel tap, the correlation between the input and
    ///   the matching `grad_output` positions, summed over the batch;
    /// * bias: sum of `grad_output` over batch and spatial positions.
    ///
    /// All gradients are built as host (`TensorStorage::cpu`) tensors.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let go_data = grad_output.data_vec()?;
        let batch = self.input.shape()[0];
        let h_in = self.input.shape()[2];
        let w_in = self.input.shape()[3];
        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;

        let zero = <T as num_traits::Zero>::zero();
        let in_plane_len = h_in * w_in;
        let out_plane_len = self.h_out * self.w_out;

        let grad_input = if self.input.requires_grad() {
            let col_rows = self.out_channels * kh * kw;
            // `[in_channels, out_channels, kh, kw]` viewed as a 2-D matrix.
            let weight_2d = Tensor::from_storage(
                TensorStorage::cpu(self.weight.data_vec()?),
                vec![self.in_channels, col_rows],
                false,
            )?;
            let (go_cols, _go_col_rows, go_col_cols) = im2col(
                &go_data,
                batch,
                self.out_channels,
                self.h_out,
                self.w_out,
                kh,
                kw,
                sh,
                sw,
                ph,
                pw,
            );
            let cols_len = col_rows * go_col_cols;
            let sample_len = self.in_channels * in_plane_len;
            let mut gi = vec![zero; batch * sample_len];
            for b in 0..batch {
                let start = b * cols_len;
                let go_cols_b = Tensor::from_storage(
                    TensorStorage::cpu(go_cols[start..start + cols_len].to_vec()),
                    vec![col_rows, go_col_cols],
                    false,
                )?;
                let gi_b = mm(&weight_2d, &go_cols_b)?;
                let gi_data = gi_b.data()?;
                gi[b * sample_len..(b + 1) * sample_len].copy_from_slice(&gi_data[..sample_len]);
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gi),
                self.input.shape().to_vec(),
                false,
            )?)
        } else {
            None
        };

        let grad_weight = if self.weight.requires_grad() {
            let mut gw = vec![zero; self.in_channels * self.out_channels * kh * kw];
            let input_data = self.input.data_vec()?;
            for b in 0..batch {
                for ci in 0..self.in_channels {
                    let in_plane = (b * self.in_channels + ci) * in_plane_len;
                    for co in 0..self.out_channels {
                        let go_plane = (b * self.out_channels + co) * out_plane_len;
                        for dh in 0..kh {
                            for dw in 0..kw {
                                let mut acc = zero;
                                for ih in 0..h_in {
                                    // Output row fed by input row `ih` through tap `dh`;
                                    // skipped when it falls in the cropped border.
                                    let oh = match (ih * sh + dh).checked_sub(ph) {
                                        Some(r) if r < self.h_out => r,
                                        _ => continue,
                                    };
                                    for iw in 0..w_in {
                                        let ow = match (iw * sw + dw).checked_sub(pw) {
                                            Some(x) if x < self.w_out => x,
                                            _ => continue,
                                        };
                                        acc += input_data[in_plane + ih * w_in + iw]
                                            * go_data[go_plane + oh * self.w_out + ow];
                                    }
                                }
                                gw[((ci * self.out_channels + co) * kh + dh) * kw + dw] += acc;
                            }
                        }
                    }
                }
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gw),
                vec![self.in_channels, self.out_channels, kh, kw],
                false,
            )?)
        } else {
            None
        };

        let grad_bias = match &self.bias {
            Some(b) if b.requires_grad() => {
                let mut gb = vec![zero; self.out_channels];
                for sample in 0..batch {
                    for c in 0..self.out_channels {
                        let start = (sample * self.out_channels + c) * out_plane_len;
                        for &g in &go_data[start..start + out_plane_len] {
                            gb[c] += g;
                        }
                    }
                }
                Some(Tensor::from_storage(
                    TensorStorage::cpu(gb),
                    vec![self.out_channels],
                    false,
                )?)
            }
            _ => None,
        };

        let mut grads = vec![grad_input, grad_weight];
        if self.bias.is_some() {
            grads.push(grad_bias);
        }
        Ok(grads)
    }

    /// The tensors this node differentiates with respect to: input, weight
    /// and, if present, bias.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let mut tensors = vec![&self.input, &self.weight];
        if let Some(ref b) = self.bias {
            tensors.push(b);
        }
        tensors
    }

    /// Name of this node.
    fn name(&self) -> &'static str {
        "ConvTranspose2dBackward"
    }
}
/// Unfolds a row-major `[batch, channels, depth, height, width]` buffer into
/// column form so that a 3-D convolution can be evaluated as a matrix product.
///
/// For every sample the result holds a `col_rows x col_cols` row-major matrix,
/// where `col_rows = channels * kernel_d * kernel_h * kernel_w` (one row per
/// kernel tap) and `col_cols = d_out * h_out * w_out` (one column per output
/// position). Taps that land in the zero padding contribute `0`.
///
/// Returns `(cols, col_rows, col_cols)`.
///
/// # Panics
///
/// Panics if a stride is zero, if `input` is shorter than the declared shape
/// requires, or (with overflow checks enabled) if a padded extent is smaller
/// than the kernel along that axis.
// The argument list mirrors the full geometry of the convolution.
#[allow(clippy::too_many_arguments)]
fn im2col_3d<T: Float>(
    input: &[T],
    batch: usize,
    channels: usize,
    depth: usize,
    height: usize,
    width: usize,
    kernel_d: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_d: usize,
    stride_h: usize,
    stride_w: usize,
    pad_d: usize,
    pad_h: usize,
    pad_w: usize,
) -> (Vec<T>, usize, usize) {
    let d_out = (depth + 2 * pad_d - kernel_d) / stride_d + 1;
    let h_out = (height + 2 * pad_h - kernel_h) / stride_h + 1;
    let w_out = (width + 2 * pad_w - kernel_w) / stride_w + 1;
    let col_rows = channels * kernel_d * kernel_h * kernel_w;
    let col_cols = d_out * h_out * w_out;

    // Pre-filled with zeros, so padded taps simply need to be skipped below.
    let mut cols = vec![<T as num_traits::Zero>::zero(); batch * col_rows * col_cols];
    let volume_len = depth * height * width;

    for b in 0..batch {
        for c in 0..channels {
            let src_volume = (b * channels + c) * volume_len;
            for kd in 0..kernel_d {
                for kh in 0..kernel_h {
                    for kw in 0..kernel_w {
                        let row = ((c * kernel_d + kd) * kernel_h + kh) * kernel_w + kw;
                        let dst_row = (b * col_rows + row) * col_cols;
                        for od in 0..d_out {
                            let src_d = match (od * stride_d + kd).checked_sub(pad_d) {
                                Some(z) if z < depth => z,
                                _ => continue,
                            };
                            for oh in 0..h_out {
                                let src_h = match (oh * stride_h + kh).checked_sub(pad_h) {
                                    Some(y) if y < height => y,
                                    _ => continue,
                                };
                                for ow in 0..w_out {
                                    let src_w = match (ow * stride_w + kw).checked_sub(pad_w) {
                                        Some(x) if x < width => x,
                                        _ => continue,
                                    };
                                    cols[dst_row + (od * h_out + oh) * w_out + ow] = input
                                        [src_volume + (src_d * height + src_h) * width + src_w];
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    (cols, col_rows, col_cols)
}
/// Folds per-sample column matrices back into a contiguous
/// `[batch, channels, depth, height, width]` buffer, summing every column entry
/// into the input position it was unfolded from.
///
/// `cols` is expected to hold `batch` row-major matrices of shape
/// `[channels * kernel_d * kernel_h * kernel_w, d_out * h_out * w_out]`.
/// Entries that correspond to the padding region are dropped.
#[allow(clippy::too_many_arguments)]
fn col2im_3d<T: Float>(
    cols: &[T],
    batch: usize,
    channels: usize,
    depth: usize,
    height: usize,
    width: usize,
    kernel_d: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_d: usize,
    stride_h: usize,
    stride_w: usize,
    pad_d: usize,
    pad_h: usize,
    pad_w: usize,
    d_out: usize,
    h_out: usize,
    w_out: usize,
) -> Vec<T> {
    let plane = height * width;
    let volume = depth * plane;
    let kernel_volume = kernel_d * kernel_h * kernel_w;
    let col_rows = channels * kernel_volume;
    let col_cols = d_out * h_out * w_out;

    let mut output = vec![<T as num_traits::Zero>::zero(); batch * channels * volume];

    // Maps an output position and kernel offset to the unpadded destination
    // coordinate, or `None` when the tap lands in the padding.
    let target = |out_pos: usize, stride: usize, tap: usize, pad: usize, extent: usize| {
        (out_pos * stride + tap)
            .checked_sub(pad)
            .filter(|&coord| coord < extent)
    };

    for sample in 0..batch {
        for channel in 0..channels {
            let dst_base = (sample * channels + channel) * volume;
            for kd in 0..kernel_d {
                for kh in 0..kernel_h {
                    for kw in 0..kernel_w {
                        let row = channel * kernel_volume + (kd * kernel_h + kh) * kernel_w + kw;
                        let src_base = (sample * col_rows + row) * col_cols;
                        for od in 0..d_out {
                            let Some(z) = target(od, stride_d, kd, pad_d, depth) else {
                                continue;
                            };
                            for oh in 0..h_out {
                                let Some(y) = target(oh, stride_h, kh, pad_h, height) else {
                                    continue;
                                };
                                for ow in 0..w_out {
                                    let Some(x) = target(ow, stride_w, kw, pad_w, width) else {
                                        continue;
                                    };
                                    output[dst_base + z * plane + y * width + x] +=
                                        cols[src_base + (od * h_out + oh) * w_out + ow];
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    output
}
/// 3-D convolution layer operating on `[B, C, D, H, W]` inputs.
#[derive(Debug)]
pub struct Conv3d<T: Float> {
/// Kernel weights; `new` allocates them as `[out_channels, in_channels, kd, kh, kw]`.
weight: Parameter<T>,
/// Optional per-output-channel bias; `new` allocates it as `[out_channels]`.
bias: Option<Parameter<T>>,
/// Number of channels expected in the input (checked in `forward`).
in_channels: usize,
/// Number of channels produced in the output.
out_channels: usize,
/// Kernel extent as `(depth, height, width)`.
kernel_size: (usize, usize, usize),
/// Stride as `(depth, height, width)`.
stride: (usize, usize, usize),
/// Zero padding applied on both sides as `(depth, height, width)`.
padding: (usize, usize, usize),
/// Training-mode flag toggled by `train` / `eval`.
training: bool,
}
impl<T: Float> Conv3d<T> {
    /// Builds a 3-D convolution layer.
    ///
    /// The weight has shape `[out_channels, in_channels, kd, kh, kw]` and is
    /// initialised with `kaiming_uniform` (ReLU gain); the optional bias has
    /// shape `[out_channels]` and is zero-initialised.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::InvalidArgument` when a channel count, a
    /// kernel extent, or a stride is zero, and propagates any error raised
    /// while allocating or initialising the parameters.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize, usize),
        stride: (usize, usize, usize),
        padding: (usize, usize, usize),
        bias: bool,
    ) -> FerrotorchResult<Self> {
        let invalid = |message: &str| FerrotorchError::InvalidArgument {
            message: message.into(),
        };

        if in_channels == 0 || out_channels == 0 {
            return Err(invalid("in_channels and out_channels must be > 0"));
        }
        let (kd, kh, kw) = kernel_size;
        if [kd, kh, kw].contains(&0) {
            return Err(invalid("kernel_size must be > 0 in all dimensions"));
        }
        let (sd, sh, sw) = stride;
        if [sd, sh, sw].contains(&0) {
            return Err(invalid("stride must be > 0 in all dimensions"));
        }

        let mut weight = Parameter::zeros(&[out_channels, in_channels, kd, kh, kw])?;
        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;

        let bias = if bias {
            let mut bias_param = Parameter::zeros(&[out_channels])?;
            zeros_init(&mut bias_param)?;
            Some(bias_param)
        } else {
            None
        };

        Ok(Self {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            training: true,
        })
    }

    /// Total number of scalar parameters: all weight entries plus one bias
    /// value per output channel when a bias is present.
    pub fn num_parameters(&self) -> usize {
        let (kd, kh, kw) = self.kernel_size;
        let weight_count = self.out_channels * self.in_channels * kd * kh * kw;
        let bias_count = self.bias.as_ref().map_or(0, |_| self.out_channels);
        weight_count + bias_count
    }
}
impl<T: Float> Module<T> for Conv3d<T> {
    /// Applies the convolution to a `[B, C, D, H, W]` input and returns a
    /// `[B, out_channels, d_out, h_out, w_out]` tensor moved to the input's
    /// device.
    ///
    /// The input is unfolded with `im2col_3d` and each sample is multiplied by
    /// the weight viewed as `[out_channels, C * kd * kh * kw]`. When gradient
    /// recording is enabled and the input, weight, or bias requires grad, the
    /// result is attached to a `Conv3dBackward` node that keeps the unfolded
    /// columns for the backward pass.
    ///
    /// # Errors
    ///
    /// * `InvalidArgument` if the input is not 5-D or the padded input is
    ///   smaller than the kernel in any dimension.
    /// * `ShapeMismatch` if the channel count differs from `in_channels`.
    /// * Any error propagated from tensor construction, `mm`, or the device
    ///   transfer.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let _autocast_cat = autocast_guard("conv3d");

        if input.ndim() != 5 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv3d expects 5-D input [B, C, D, H, W], got {:?}",
                    input.shape()
                ),
            });
        }
        let batch = input.shape()[0];
        let c_in = input.shape()[1];
        let d = input.shape()[2];
        let h = input.shape()[3];
        let w = input.shape()[4];
        if c_in != self.in_channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "Conv3d: expected {} input channels, got {}",
                    self.in_channels, c_in
                ),
            });
        }

        let (kd, kh, kw) = self.kernel_size;
        let (sd, sh, sw) = self.stride;
        let (pd, ph, pw) = self.padding;
        let d_padded = d + 2 * pd;
        let h_padded = h + 2 * ph;
        let w_padded = w + 2 * pw;
        // Must be rejected before `im2col_3d`, which subtracts the kernel
        // extent from the padded extent without checking.
        if d_padded < kd || h_padded < kh || w_padded < kw {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv3d: padded input ({d_padded}, {h_padded}, {w_padded}) is smaller than kernel ({kd}, {kh}, {kw})"
                ),
            });
        }
        let d_out = (d_padded - kd) / sd + 1;
        let h_out = (h_padded - kh) / sh + 1;
        let w_out = (w_padded - kw) / sw + 1;

        let input_device = input.device();
        let input_data = input.data_vec()?;
        let (cols, col_rows, col_cols) = im2col_3d(
            &input_data, batch, c_in, d, h, w, kd, kh, kw, sd, sh, sw, pd, ph, pw,
        );

        let weight_data = self.weight.data_vec()?;
        let weight_2d = Tensor::from_storage(
            TensorStorage::cpu(weight_data),
            vec![self.out_channels, col_rows],
            false,
        )?;

        let zero = <T as num_traits::Zero>::zero();
        let spatial_out = d_out * h_out * w_out;
        let sample_len = self.out_channels * spatial_out;
        let cols_len = col_rows * col_cols;
        let mut output = vec![zero; batch * sample_len];

        // One [out_channels, col_rows] x [col_rows, col_cols] product per sample.
        for b in 0..batch {
            let cols_b = Tensor::from_storage(
                TensorStorage::cpu(cols[b * cols_len..(b + 1) * cols_len].to_vec()),
                vec![col_rows, col_cols],
                false,
            )?;
            let out_b = mm(&weight_2d, &cols_b)?;
            let out_data = out_b.data()?;
            let out_start = b * sample_len;
            output[out_start..out_start + sample_len].copy_from_slice(out_data);
        }

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data_vec()?;
            for b in 0..batch {
                for c in 0..self.out_channels {
                    let bval = bias_data[c];
                    let start = b * sample_len + c * spatial_out;
                    for value in &mut output[start..start + spatial_out] {
                        *value += bval;
                    }
                }
            }
        }

        let out_shape = vec![batch, self.out_channels, d_out, h_out, w_out];
        let needs_grad = is_grad_enabled()
            && (input.requires_grad()
                || self.weight.requires_grad()
                || self.bias.as_ref().is_some_and(|b| b.requires_grad()));

        // The output buffer is moved straight into whichever tensor is needed;
        // no intermediate tensor or extra copy of the data is created.
        let result = if needs_grad {
            let grad_fn = Arc::new(Conv3dBackward {
                input: input.clone(),
                weight: self.weight.tensor().clone(),
                bias: self.bias.as_ref().map(|b| b.tensor().clone()),
                in_channels: self.in_channels,
                out_channels: self.out_channels,
                kernel_size: self.kernel_size,
                stride: self.stride,
                padding: self.padding,
                cols,
                col_rows,
                col_cols,
                d_out,
                h_out,
                w_out,
            });
            Tensor::from_operation(TensorStorage::cpu(output), out_shape, grad_fn)?
        } else {
            Tensor::from_storage(TensorStorage::cpu(output), out_shape, false)?
        };
        result.to(input_device)
    }

    /// Weight first, then bias when present.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params = vec![&self.weight];
        if let Some(ref b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Mutable view of the parameters, in the same order as `parameters`.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params = vec![&mut self.weight];
        if let Some(ref mut b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Parameters paired with the names `"weight"` and `"bias"`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut params = vec![("weight".to_string(), &self.weight)];
        if let Some(ref b) = self.bias {
            params.push(("bias".to_string(), b));
        }
        params
    }

    /// Sets the training flag.
    fn train(&mut self) {
        self.training = true;
    }

    /// Clears the training flag.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Reports the current training flag.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Autograd node recorded by `Conv3d::forward`; holds everything `backward`
/// needs to compute gradients for the input, weight, and optional bias.
#[derive(Debug)]
struct Conv3dBackward<T: Float> {
/// Input tensor the forward pass was applied to.
input: Tensor<T>,
/// Weight tensor used in the forward pass.
weight: Tensor<T>,
/// Bias tensor used in the forward pass, if the layer has one.
bias: Option<Tensor<T>>,
/// Number of input channels of the layer.
in_channels: usize,
/// Number of output channels of the layer.
out_channels: usize,
/// Kernel extent as `(depth, height, width)`.
kernel_size: (usize, usize, usize),
/// Stride as `(depth, height, width)`.
stride: (usize, usize, usize),
/// Zero padding as `(depth, height, width)`.
padding: (usize, usize, usize),
/// Unfolded input produced by `im2col_3d` in the forward pass, reused for the weight gradient.
cols: Vec<T>,
/// Rows of each per-sample column matrix in `cols`.
col_rows: usize,
/// Columns of each per-sample column matrix in `cols`.
col_cols: usize,
/// Output depth computed in the forward pass.
d_out: usize,
/// Output height computed in the forward pass.
h_out: usize,
/// Output width computed in the forward pass.
w_out: usize,
}
impl<T: Float> Conv3dBackward<T> {
    /// Number of output positions per channel.
    fn spatial_out(&self) -> usize {
        self.d_out * self.h_out * self.w_out
    }

    /// Wraps sample `b` of the flat upstream gradient as an
    /// `[out_channels, spatial_out]` tensor.
    fn grad_output_sample(&self, go_data: &[T], b: usize) -> FerrotorchResult<Tensor<T>> {
        let spatial_out = self.spatial_out();
        let sample_len = self.out_channels * spatial_out;
        Tensor::from_storage(
            TensorStorage::cpu(go_data[b * sample_len..(b + 1) * sample_len].to_vec()),
            vec![self.out_channels, spatial_out],
            false,
        )
    }

    /// Weight gradient: sum over the batch of `grad_out_b x cols_b^T`,
    /// reshaped to the weight's 5-D shape.
    fn weight_grad(&self, go_data: &[T], batch: usize) -> FerrotorchResult<Tensor<T>> {
        let (kd, kh, kw) = self.kernel_size;
        let weight_numel = self.out_channels * self.col_rows;
        let cols_len = self.col_rows * self.col_cols;
        let mut accum = vec![<T as num_traits::Zero>::zero(); weight_numel];

        for b in 0..batch {
            let go_b = self.grad_output_sample(go_data, b)?;
            let cols_b = Tensor::from_storage(
                TensorStorage::cpu(self.cols[b * cols_len..(b + 1) * cols_len].to_vec()),
                vec![self.col_rows, self.col_cols],
                false,
            )?;
            let cols_bt = transpose(&cols_b)?;
            let gw_b = mm(&go_b, &cols_bt)?;
            let gw_data = gw_b.data()?;
            for (slot, &g) in accum.iter_mut().zip(&gw_data[..weight_numel]) {
                *slot += g;
            }
        }

        Tensor::from_storage(
            TensorStorage::cpu(accum),
            vec![self.out_channels, self.in_channels, kd, kh, kw],
            false,
        )
    }

    /// Bias gradient: upstream gradient summed over batch and spatial
    /// positions for each output channel.
    fn bias_grad(&self, go_data: &[T], batch: usize) -> FerrotorchResult<Tensor<T>> {
        let spatial_out = self.spatial_out();
        let mut sums = vec![<T as num_traits::Zero>::zero(); self.out_channels];
        for sample in 0..batch {
            for (c, slot) in sums.iter_mut().enumerate() {
                let start = (sample * self.out_channels + c) * spatial_out;
                for &g in &go_data[start..start + spatial_out] {
                    *slot += g;
                }
            }
        }
        Tensor::from_storage(TensorStorage::cpu(sums), vec![self.out_channels], false)
    }

    /// Input gradient: `W^T x grad_out_b` per sample, folded back onto the
    /// input volume with `col2im_3d`.
    fn input_grad(
        &self,
        go_data: &[T],
        batch: usize,
        dims: (usize, usize, usize),
    ) -> FerrotorchResult<Tensor<T>> {
        let (d, h, w) = dims;
        let (kd, kh, kw) = self.kernel_size;
        let (sd, sh, sw) = self.stride;
        let (pd, ph, pw) = self.padding;

        let weight_2d = Tensor::from_storage(
            TensorStorage::cpu(self.weight.data_vec()?),
            vec![self.out_channels, self.col_rows],
            false,
        )?;
        let weight_2d_t = transpose(&weight_2d)?;

        let cols_len = self.col_rows * self.col_cols;
        let mut grad_cols = vec![<T as num_traits::Zero>::zero(); batch * cols_len];
        for b in 0..batch {
            let go_b = self.grad_output_sample(go_data, b)?;
            let gc_b = mm(&weight_2d_t, &go_b)?;
            let gc_data = gc_b.data()?;
            grad_cols[b * cols_len..(b + 1) * cols_len].copy_from_slice(gc_data);
        }

        let folded = col2im_3d(
            &grad_cols,
            batch,
            self.in_channels,
            d,
            h,
            w,
            kd,
            kh,
            kw,
            sd,
            sh,
            sw,
            pd,
            ph,
            pw,
            self.d_out,
            self.h_out,
            self.w_out,
        );
        Tensor::from_storage(
            TensorStorage::cpu(folded),
            self.input.shape().to_vec(),
            false,
        )
    }
}

impl<T: Float> GradFn<T> for Conv3dBackward<T> {
    /// Computes the gradients for `[input, weight]` and, when the layer has a
    /// bias, `bias`. An entry is `None` when the matching tensor does not
    /// require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let go_data = grad_output.data_vec()?;
        let batch = self.input.shape()[0];
        let dims = (
            self.input.shape()[2],
            self.input.shape()[3],
            self.input.shape()[4],
        );

        let grad_weight = if self.weight.requires_grad() {
            Some(self.weight_grad(&go_data, batch)?)
        } else {
            None
        };
        let grad_bias = match &self.bias {
            Some(b) if b.requires_grad() => Some(self.bias_grad(&go_data, batch)?),
            _ => None,
        };
        let grad_input = if self.input.requires_grad() {
            Some(self.input_grad(&go_data, batch, dims)?)
        } else {
            None
        };

        let mut grads = Vec::with_capacity(3);
        grads.push(grad_input);
        grads.push(grad_weight);
        if self.bias.is_some() {
            grads.push(grad_bias);
        }
        Ok(grads)
    }

    /// Input, weight, then bias when present; matches the order of `backward`.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let mut tensors = Vec::with_capacity(3);
        tensors.push(&self.input);
        tensors.push(&self.weight);
        tensors.extend(self.bias.as_ref());
        tensors
    }

    /// Static identifier reported for this backward node.
    fn name(&self) -> &'static str {
        "Conv3dBackward"
    }
}
/// 1-D transposed convolution layer operating on `[B, C, L]` inputs.
#[derive(Debug)]
pub struct ConvTranspose1d<T: Float> {
/// Kernel weights; `new` allocates them as `[in_channels, out_channels, kernel_size]`.
weight: Parameter<T>,
/// Optional per-output-channel bias; `new` allocates it as `[out_channels]`.
bias: Option<Parameter<T>>,
/// Number of channels expected in the input (checked in `forward`).
in_channels: usize,
/// Number of channels produced in the output.
out_channels: usize,
/// Kernel length.
kernel_size: usize,
/// Stride; `forward` uses it as the zero-insertion factor when upsampling the input.
stride: usize,
/// Padding; `forward` pads the upsampled input by `kernel_size - 1 - padding`.
padding: usize,
/// Extra positions appended to the end of the output; `new` requires it to be less than `stride`.
output_padding: usize,
/// Training-mode flag toggled by `train` / `eval`.
training: bool,
}
impl<T: Float> ConvTranspose1d<T> {
    /// Builds a 1-D transposed convolution layer.
    ///
    /// The weight has shape `[in_channels, out_channels, kernel_size]` and is
    /// initialised with `kaiming_uniform` (ReLU gain); the optional bias has
    /// shape `[out_channels]` and is zero-initialised.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::InvalidArgument` when a channel count, the
    /// kernel size, or the stride is zero, or when `output_padding` is not
    /// strictly less than `stride`. Errors from parameter allocation or
    /// initialisation are propagated.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        output_padding: usize,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        let invalid = |message: &str| FerrotorchError::InvalidArgument {
            message: message.into(),
        };

        if in_channels == 0 || out_channels == 0 {
            return Err(invalid("in_channels and out_channels must be > 0"));
        }
        if kernel_size == 0 {
            return Err(invalid("kernel_size must be > 0"));
        }
        if stride == 0 {
            return Err(invalid("stride must be > 0"));
        }
        if output_padding >= stride {
            return Err(invalid(
                "output_padding must be strictly less than stride",
            ));
        }

        let mut weight = Parameter::zeros(&[in_channels, out_channels, kernel_size])?;
        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;

        let bias = if bias {
            let mut bias_param = Parameter::zeros(&[out_channels])?;
            zeros_init(&mut bias_param)?;
            Some(bias_param)
        } else {
            None
        };

        Ok(Self {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            training: true,
        })
    }

    /// Total number of scalar parameters: all weight entries plus one bias
    /// value per output channel when a bias is present.
    pub fn num_parameters(&self) -> usize {
        let weight_count = self.in_channels * self.out_channels * self.kernel_size;
        let bias_count = self.bias.as_ref().map_or(0, |_| self.out_channels);
        weight_count + bias_count
    }
}
impl<T: Float> Module<T> for ConvTranspose1d<T> {
    /// Applies the transposed convolution to a `[B, C, L]` input and returns a
    /// `[B, out_channels, l_out]` tensor moved to the input's device.
    ///
    /// The input is upsampled by inserting zeros according to the stride, then
    /// convolved (stride 1, padding `kernel_size - 1 - padding`) with the
    /// flipped kernel. The `output_padding` trailing positions receive only
    /// the bias. When gradient recording is enabled and the input, weight, or
    /// bias requires grad, the result is attached to a
    /// `ConvTranspose1dBackward` node.
    ///
    /// # Errors
    ///
    /// * `InvalidArgument` if the input is not 3-D, has zero length, if
    ///   `padding >= kernel_size`, or if the resulting output length would
    ///   not be positive.
    /// * `ShapeMismatch` if the channel count differs from `in_channels`.
    /// * Any error propagated from tensor construction, `mm`, or the device
    ///   transfer.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let _autocast_cat = autocast_guard("conv_transpose1d");

        if input.ndim() != 3 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose1d expects 3-D input [B, C, L], got {:?}",
                    input.shape()
                ),
            });
        }
        let batch = input.shape()[0];
        let c_in = input.shape()[1];
        let length = input.shape()[2];
        if c_in != self.in_channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "ConvTranspose1d: expected {} input channels, got {}",
                    self.in_channels, c_in
                ),
            });
        }

        let k = self.kernel_size;
        let s = self.stride;
        let p = self.padding;
        let op = self.output_padding;

        // The output length is derived from `(length - 1) * stride`, which has
        // no meaning for an empty signal.
        if length == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "ConvTranspose1d: input length must be > 0".into(),
            });
        }
        // `k - 1 - p` below is unsigned arithmetic and would underflow.
        if p >= k {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose1d: padding ({p}) must be smaller than kernel_size ({k})"
                ),
            });
        }

        let input_device = input.device();
        let input_data = input.data_vec()?;
        let (upsampled, _h_up, w_up) =
            stride_insert_zeros(&input_data, batch, c_in, 1, length, 1, s);
        let weight_data = self.weight.data_vec()?;
        let flipped = flip_kernel(&weight_data, self.in_channels, self.out_channels, 1, k);

        let internal_pad_w = k - 1 - p;
        // `im2col` subtracts the kernel width from the padded width without
        // checking, so a non-positive output length must be rejected first.
        if w_up + 2 * internal_pad_w < k {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose1d: input length {length} with kernel_size {k}, stride {s}, padding {p} yields a non-positive output length"
                ),
            });
        }

        let (cols, col_rows, col_cols) = im2col(
            &upsampled,
            batch,
            c_in,
            1,
            w_up,
            1,
            k,
            1,
            1,
            0,
            internal_pad_w,
        );
        let w_out_base = (w_up + 2 * internal_pad_w - k) + 1;
        let l_out = w_out_base + op;

        let flipped_2d = Tensor::from_storage(
            TensorStorage::cpu(flipped),
            vec![self.out_channels, col_rows],
            false,
        )?;

        let zero = <T as num_traits::Zero>::zero();
        let sample_len = self.out_channels * l_out;
        let cols_len = col_rows * col_cols;
        let mut output = vec![zero; batch * sample_len];

        for b in 0..batch {
            let cols_b = Tensor::from_storage(
                TensorStorage::cpu(cols[b * cols_len..(b + 1) * cols_len].to_vec()),
                vec![col_rows, col_cols],
                false,
            )?;
            let out_b = mm(&flipped_2d, &cols_b)?;
            let out_data = out_b.data()?;
            // Each channel row is `w_out_base` wide in the product but `l_out`
            // wide in the output; the trailing `op` positions stay zero.
            for c in 0..self.out_channels {
                let dst = b * sample_len + c * l_out;
                let src = c * w_out_base;
                output[dst..dst + w_out_base].copy_from_slice(&out_data[src..src + w_out_base]);
            }
        }

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data_vec()?;
            for b in 0..batch {
                for c in 0..self.out_channels {
                    let bval = bias_data[c];
                    let start = b * sample_len + c * l_out;
                    for value in &mut output[start..start + l_out] {
                        *value += bval;
                    }
                }
            }
        }

        let out_shape = vec![batch, self.out_channels, l_out];
        let needs_grad = is_grad_enabled()
            && (input.requires_grad()
                || self.weight.requires_grad()
                || self.bias.as_ref().is_some_and(|b| b.requires_grad()));

        // The output buffer is moved straight into whichever tensor is needed;
        // no intermediate tensor or extra copy of the data is created.
        let result = if needs_grad {
            let grad_fn = Arc::new(ConvTranspose1dBackward {
                input: input.clone(),
                weight: self.weight.tensor().clone(),
                bias: self.bias.as_ref().map(|b| b.tensor().clone()),
                in_channels: self.in_channels,
                out_channels: self.out_channels,
                kernel_size: self.kernel_size,
                stride: self.stride,
                padding: self.padding,
                _output_padding: self.output_padding,
                l_out,
            });
            Tensor::from_operation(TensorStorage::cpu(output), out_shape, grad_fn)?
        } else {
            Tensor::from_storage(TensorStorage::cpu(output), out_shape, false)?
        };
        result.to(input_device)
    }

    /// Weight first, then bias when present.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params = vec![&self.weight];
        if let Some(ref b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Mutable view of the parameters, in the same order as `parameters`.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params = vec![&mut self.weight];
        if let Some(ref mut b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Parameters paired with the names `"weight"` and `"bias"`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut params = vec![("weight".to_string(), &self.weight)];
        if let Some(ref b) = self.bias {
            params.push(("bias".to_string(), b));
        }
        params
    }

    /// Sets the training flag.
    fn train(&mut self) {
        self.training = true;
    }

    /// Clears the training flag.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Reports the current training flag.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Autograd node recorded by `ConvTranspose1d::forward`; holds what `backward`
/// needs to compute gradients for the input, weight, and optional bias.
#[derive(Debug)]
struct ConvTranspose1dBackward<T: Float> {
/// Input tensor the forward pass was applied to.
input: Tensor<T>,
/// Weight tensor used in the forward pass.
weight: Tensor<T>,
/// Bias tensor used in the forward pass, if the layer has one.
bias: Option<Tensor<T>>,
/// Number of input channels of the layer.
in_channels: usize,
/// Number of output channels of the layer.
out_channels: usize,
/// Kernel length.
kernel_size: usize,
/// Stride of the layer.
stride: usize,
/// Padding of the layer.
padding: usize,
/// Output padding of the layer; stored but not read by `backward`.
_output_padding: usize,
/// Output length computed in the forward pass (including output padding).
l_out: usize,
}
impl<T: Float> GradFn<T> for ConvTranspose1dBackward<T> {
    /// Computes the gradients for `[input, weight]` and, when the layer has a
    /// bias, `bias`. An entry is `None` when the matching tensor does not
    /// require grad.
    ///
    /// * Input gradient: the upstream gradient is unfolded with `im2col`
    ///   (kernel `k`, stride `s`, padding `p`) and multiplied by the weight
    ///   viewed as `[in_channels, out_channels * k]`.
    /// * Weight gradient: direct correlation of the input with the upstream
    ///   gradient at output position `il * s + dw - p`.
    /// * Bias gradient: upstream gradient summed per output channel.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let go_data = grad_output.data_vec()?;
        let batch = self.input.shape()[0];
        let l_in = self.input.shape()[2];
        let (k, s, p) = (self.kernel_size, self.stride, self.padding);
        let (c_in, c_out, l_out) = (self.in_channels, self.out_channels, self.l_out);
        let zero = <T as num_traits::Zero>::zero();

        let grad_input = if self.input.requires_grad() {
            let weight_data = self.weight.data_vec()?;
            let col_rows = c_out * k;
            let weight_2d = Tensor::from_storage(
                TensorStorage::cpu(weight_data),
                vec![c_in, col_rows],
                false,
            )?;
            let (go_cols, _go_col_rows, go_col_cols) =
                im2col(&go_data, batch, c_out, 1, l_out, 1, k, 1, s, 0, p);

            let cols_len = col_rows * go_col_cols;
            let sample_len = c_in * l_in;
            let mut gi = vec![zero; batch * sample_len];
            for b in 0..batch {
                let go_cols_b = Tensor::from_storage(
                    TensorStorage::cpu(go_cols[b * cols_len..(b + 1) * cols_len].to_vec()),
                    vec![col_rows, go_col_cols],
                    false,
                )?;
                let gi_b = mm(&weight_2d, &go_cols_b)?;
                let gi_data = gi_b.data()?;
                gi[b * sample_len..(b + 1) * sample_len].copy_from_slice(&gi_data[..sample_len]);
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gi),
                self.input.shape().to_vec(),
                false,
            )?)
        } else {
            None
        };

        let grad_weight = if self.weight.requires_grad() {
            let mut gw = vec![zero; c_in * c_out * k];
            let input_data = self.input.data_vec()?;
            for b in 0..batch {
                for ci in 0..c_in {
                    let in_base = (b * c_in + ci) * l_in;
                    for co in 0..c_out {
                        let go_base = (b * c_out + co) * l_out;
                        for dw in 0..k {
                            let mut acc = zero;
                            for il in 0..l_in {
                                // Output position touched by input `il` through tap `dw`,
                                // skipped when it falls outside the cropped output.
                                let hit = (il * s + dw).checked_sub(p).filter(|&pos| pos < l_out);
                                if let Some(pos) = hit {
                                    acc += input_data[in_base + il] * go_data[go_base + pos];
                                }
                            }
                            gw[(ci * c_out + co) * k + dw] += acc;
                        }
                    }
                }
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gw),
                vec![c_in, c_out, k],
                false,
            )?)
        } else {
            None
        };

        let grad_bias = match &self.bias {
            Some(b) if b.requires_grad() => {
                let mut sums = vec![zero; c_out];
                for sample in 0..batch {
                    for (c, slot) in sums.iter_mut().enumerate() {
                        let start = (sample * c_out + c) * l_out;
                        for &g in &go_data[start..start + l_out] {
                            *slot += g;
                        }
                    }
                }
                Some(Tensor::from_storage(
                    TensorStorage::cpu(sums),
                    vec![c_out],
                    false,
                )?)
            }
            _ => None,
        };

        let mut grads = Vec::with_capacity(3);
        grads.push(grad_input);
        grads.push(grad_weight);
        if self.bias.is_some() {
            grads.push(grad_bias);
        }
        Ok(grads)
    }

    /// Input, weight, then bias when present; matches the order of `backward`.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let mut tensors = Vec::with_capacity(3);
        tensors.push(&self.input);
        tensors.push(&self.weight);
        tensors.extend(self.bias.as_ref());
        tensors
    }

    /// Static identifier reported for this backward node.
    fn name(&self) -> &'static str {
        "ConvTranspose1dBackward"
    }
}
/// 3-D transposed convolution layer operating on `[B, C, D, H, W]` inputs.
#[derive(Debug)]
pub struct ConvTranspose3d<T: Float> {
/// Kernel weights; `new` allocates them as `[in_channels, out_channels, kd, kh, kw]`.
weight: Parameter<T>,
/// Optional per-output-channel bias; `new` allocates it as `[out_channels]`.
bias: Option<Parameter<T>>,
/// Number of channels expected in the input (checked in `forward`).
in_channels: usize,
/// Number of channels produced in the output.
out_channels: usize,
/// Kernel extent as `(depth, height, width)`.
kernel_size: (usize, usize, usize),
/// Stride as `(depth, height, width)`; `forward` uses it as the zero-insertion factor.
stride: (usize, usize, usize),
/// Padding as `(depth, height, width)`; `forward` pads the upsampled input by `kernel - 1 - padding`.
padding: (usize, usize, usize),
/// Extra trailing output positions per dimension; `new` requires each to be less than the stride.
output_padding: (usize, usize, usize),
/// Training-mode flag toggled by `train` / `eval`.
training: bool,
}
impl<T: Float> ConvTranspose3d<T> {
    /// Builds a 3-D transposed convolution layer.
    ///
    /// The weight has shape `[in_channels, out_channels, kd, kh, kw]` and is
    /// initialised with `kaiming_uniform` (ReLU gain); the optional bias has
    /// shape `[out_channels]` and is zero-initialised.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::InvalidArgument` when a channel count, a
    /// kernel extent, or a stride is zero, or when any `output_padding`
    /// component is not strictly less than the matching stride. Errors from
    /// parameter allocation or initialisation are propagated.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize, usize),
        stride: (usize, usize, usize),
        padding: (usize, usize, usize),
        output_padding: (usize, usize, usize),
        bias: bool,
    ) -> FerrotorchResult<Self> {
        let invalid = |message: &str| FerrotorchError::InvalidArgument {
            message: message.into(),
        };

        if in_channels == 0 || out_channels == 0 {
            return Err(invalid("in_channels and out_channels must be > 0"));
        }
        let (kd, kh, kw) = kernel_size;
        if [kd, kh, kw].contains(&0) {
            return Err(invalid("kernel_size must be > 0 in all dimensions"));
        }
        let (sd, sh, sw) = stride;
        if [sd, sh, sw].contains(&0) {
            return Err(invalid("stride must be > 0 in all dimensions"));
        }
        let (opd, oph, opw) = output_padding;
        if opd >= sd || oph >= sh || opw >= sw {
            return Err(invalid(
                "output_padding must be strictly less than stride in all dimensions",
            ));
        }

        let mut weight = Parameter::zeros(&[in_channels, out_channels, kd, kh, kw])?;
        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;

        let bias = if bias {
            let mut bias_param = Parameter::zeros(&[out_channels])?;
            zeros_init(&mut bias_param)?;
            Some(bias_param)
        } else {
            None
        };

        Ok(Self {
            weight,
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            training: true,
        })
    }

    /// Total number of scalar parameters: all weight entries plus one bias
    /// value per output channel when a bias is present.
    pub fn num_parameters(&self) -> usize {
        let (kd, kh, kw) = self.kernel_size;
        let weight_count = self.in_channels * self.out_channels * kd * kh * kw;
        let bias_count = self.bias.as_ref().map_or(0, |_| self.out_channels);
        weight_count + bias_count
    }
}
/// Upsamples a contiguous `[batch, channels, d, h, w]` buffer by placing each
/// input element at `(id * stride_d, ih * stride_h, iw * stride_w)` and filling
/// the gaps with zeros.
///
/// Returns `(out, d_up, h_up, w_up)`, where each upsampled extent is
/// `(n - 1) * stride + 1` for a non-empty extent `n`, and `0` for an empty
/// one. An empty extent therefore yields an empty buffer instead of
/// underflowing in `n - 1`.
#[allow(clippy::too_many_arguments)]
fn stride_insert_zeros_3d<T: Float>(
    input: &[T],
    batch: usize,
    channels: usize,
    d: usize,
    h: usize,
    w: usize,
    stride_d: usize,
    stride_h: usize,
    stride_w: usize,
) -> (Vec<T>, usize, usize, usize) {
    let upsampled_extent = |extent: usize, stride: usize| {
        if extent == 0 {
            0
        } else {
            (extent - 1) * stride + 1
        }
    };
    let d_up = upsampled_extent(d, stride_d);
    let h_up = upsampled_extent(h, stride_h);
    let w_up = upsampled_extent(w, stride_w);

    let in_volume = d * h * w;
    let out_volume = d_up * h_up * w_up;
    let mut out = vec![<T as num_traits::Zero>::zero(); batch * channels * out_volume];

    for b in 0..batch {
        for c in 0..channels {
            let src_base = (b * channels + c) * in_volume;
            let dst_base = (b * channels + c) * out_volume;
            for id in 0..d {
                for ih in 0..h {
                    for iw in 0..w {
                        let dst = dst_base
                            + (id * stride_d * h_up + ih * stride_h) * w_up
                            + iw * stride_w;
                        out[dst] = input[src_base + (id * h + ih) * w + iw];
                    }
                }
            }
        }
    }

    (out, d_up, h_up, w_up)
}
/// Converts a transposed-convolution kernel laid out as
/// `[c_in, c_out, kd, kh, kw]` into `[c_out, c_in, kd, kh, kw]` with every
/// spatial axis reversed.
///
/// Reversing all three spatial axes of a contiguous `kd * kh * kw` block is
/// the same as reversing the block's flat order, which is what the inner loop
/// does.
fn flip_kernel_3d<T: Float>(
    kernel: &[T],
    c_in: usize,
    c_out: usize,
    kd: usize,
    kh: usize,
    kw: usize,
) -> Vec<T> {
    let taps = kd * kh * kw;
    let mut flipped = vec![<T as num_traits::Zero>::zero(); c_out * c_in * taps];

    for ci in 0..c_in {
        for co in 0..c_out {
            let src_base = (ci * c_out + co) * taps;
            let dst_base = (co * c_in + ci) * taps;
            for tap in 0..taps {
                flipped[dst_base + (taps - 1 - tap)] = kernel[src_base + tap];
            }
        }
    }

    flipped
}
impl<T: Float> Module<T> for ConvTranspose3d<T> {
    /// Applies the transposed convolution to a `[B, C, D, H, W]` input and
    /// returns a `[B, out_channels, d_out, h_out, w_out]` tensor moved to the
    /// input's device.
    ///
    /// The input is upsampled with `stride_insert_zeros_3d`, then convolved
    /// (stride 1, padding `kernel - 1 - padding` per dimension) with the kernel
    /// produced by `flip_kernel_3d`. Positions added by `output_padding`
    /// receive only the bias. When gradient recording is enabled and the
    /// input, weight, or bias requires grad, the result is attached to a
    /// `ConvTranspose3dBackward` node.
    ///
    /// # Errors
    ///
    /// * `InvalidArgument` if the input is not 5-D, has a zero-sized spatial
    ///   dimension, if any padding component is not smaller than the matching
    ///   kernel extent, or if any output extent would not be positive.
    /// * `ShapeMismatch` if the channel count differs from `in_channels`.
    /// * Any error propagated from tensor construction, `mm`, or the device
    ///   transfer.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let _autocast_cat = autocast_guard("conv_transpose3d");

        if input.ndim() != 5 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose3d expects 5-D input [B, C, D, H, W], got {:?}",
                    input.shape()
                ),
            });
        }
        let batch = input.shape()[0];
        let c_in = input.shape()[1];
        let d = input.shape()[2];
        let h = input.shape()[3];
        let w = input.shape()[4];
        if c_in != self.in_channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "ConvTranspose3d: expected {} input channels, got {}",
                    self.in_channels, c_in
                ),
            });
        }

        let (kd, kh, kw) = self.kernel_size;
        let (sd, sh, sw) = self.stride;
        let (pd, ph, pw) = self.padding;
        let (opd, oph, opw) = self.output_padding;

        // The upsampled extent is `(n - 1) * stride + 1`, which has no meaning
        // for an empty dimension.
        if d == 0 || h == 0 || w == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose3d: spatial dimensions must be > 0, got ({d}, {h}, {w})"
                ),
            });
        }
        // `kernel - 1 - padding` below is unsigned arithmetic and would underflow.
        if pd >= kd || ph >= kh || pw >= kw {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose3d: padding ({pd}, {ph}, {pw}) must be smaller than kernel_size ({kd}, {kh}, {kw}) in all dimensions"
                ),
            });
        }

        let input_device = input.device();
        let input_data = input.data_vec()?;
        let (upsampled, d_up, h_up, w_up) =
            stride_insert_zeros_3d(&input_data, batch, c_in, d, h, w, sd, sh, sw);
        let weight_data = self.weight.data_vec()?;
        let flipped =
            flip_kernel_3d(&weight_data, self.in_channels, self.out_channels, kd, kh, kw);

        let internal_pad_d = kd - 1 - pd;
        let internal_pad_h = kh - 1 - ph;
        let internal_pad_w = kw - 1 - pw;
        // `im2col_3d` subtracts the kernel extent from the padded extent
        // without checking, so non-positive output extents are rejected first.
        if d_up + 2 * internal_pad_d < kd
            || h_up + 2 * internal_pad_h < kh
            || w_up + 2 * internal_pad_w < kw
        {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose3d: input ({d}, {h}, {w}) with kernel_size ({kd}, {kh}, {kw}), stride ({sd}, {sh}, {sw}), padding ({pd}, {ph}, {pw}) yields a non-positive output size"
                ),
            });
        }

        let (cols, col_rows, col_cols) = im2col_3d(
            &upsampled,
            batch,
            c_in,
            d_up,
            h_up,
            w_up,
            kd,
            kh,
            kw,
            1,
            1,
            1,
            internal_pad_d,
            internal_pad_h,
            internal_pad_w,
        );
        let d_out_base = (d_up + 2 * internal_pad_d - kd) + 1;
        let h_out_base = (h_up + 2 * internal_pad_h - kh) + 1;
        let w_out_base = (w_up + 2 * internal_pad_w - kw) + 1;
        let d_out = d_out_base + opd;
        let h_out = h_out_base + oph;
        let w_out = w_out_base + opw;

        let flipped_2d = Tensor::from_storage(
            TensorStorage::cpu(flipped),
            vec![self.out_channels, col_rows],
            false,
        )?;

        let zero = <T as num_traits::Zero>::zero();
        let spatial_out = d_out * h_out * w_out;
        let spatial_base = d_out_base * h_out_base * w_out_base;
        let sample_len = self.out_channels * spatial_out;
        let cols_len = col_rows * col_cols;
        let mut output = vec![zero; batch * sample_len];

        for b in 0..batch {
            let cols_b = Tensor::from_storage(
                TensorStorage::cpu(cols[b * cols_len..(b + 1) * cols_len].to_vec()),
                vec![col_rows, col_cols],
                false,
            )?;
            let out_b = mm(&flipped_2d, &cols_b)?;
            let out_data = out_b.data()?;
            // Copy row by row: the product uses the base extents while the
            // output rows are wider by the output padding, which stays zero.
            for c in 0..self.out_channels {
                for od in 0..d_out_base {
                    for oh in 0..h_out_base {
                        let dst = b * sample_len + c * spatial_out + (od * h_out + oh) * w_out;
                        let src = c * spatial_base + (od * h_out_base + oh) * w_out_base;
                        output[dst..dst + w_out_base]
                            .copy_from_slice(&out_data[src..src + w_out_base]);
                    }
                }
            }
        }

        if let Some(ref bias) = self.bias {
            let bias_data = bias.data_vec()?;
            for b in 0..batch {
                for c in 0..self.out_channels {
                    let bval = bias_data[c];
                    let start = b * sample_len + c * spatial_out;
                    for value in &mut output[start..start + spatial_out] {
                        *value += bval;
                    }
                }
            }
        }

        let out_shape = vec![batch, self.out_channels, d_out, h_out, w_out];
        let needs_grad = is_grad_enabled()
            && (input.requires_grad()
                || self.weight.requires_grad()
                || self.bias.as_ref().is_some_and(|b| b.requires_grad()));

        // The output buffer is moved straight into whichever tensor is needed;
        // no intermediate tensor or extra copy of the data is created.
        let result = if needs_grad {
            let grad_fn = Arc::new(ConvTranspose3dBackward {
                input: input.clone(),
                weight: self.weight.tensor().clone(),
                bias: self.bias.as_ref().map(|b| b.tensor().clone()),
                in_channels: self.in_channels,
                out_channels: self.out_channels,
                kernel_size: self.kernel_size,
                stride: self.stride,
                padding: self.padding,
                _output_padding: self.output_padding,
                d_out,
                h_out,
                w_out,
            });
            Tensor::from_operation(TensorStorage::cpu(output), out_shape, grad_fn)?
        } else {
            Tensor::from_storage(TensorStorage::cpu(output), out_shape, false)?
        };
        result.to(input_device)
    }

    /// Weight first, then bias when present.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params = vec![&self.weight];
        if let Some(ref b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Mutable view of the parameters, in the same order as `parameters`.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params = vec![&mut self.weight];
        if let Some(ref mut b) = self.bias {
            params.push(b);
        }
        params
    }

    /// Parameters paired with the names `"weight"` and `"bias"`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut params = vec![("weight".to_string(), &self.weight)];
        if let Some(ref b) = self.bias {
            params.push(("bias".to_string(), b));
        }
        params
    }

    /// Sets the training flag.
    fn train(&mut self) {
        self.training = true;
    }

    /// Clears the training flag.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Reports the current training flag.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Autograd node recorded by `ConvTranspose3d::forward`; holds what `backward`
/// needs to compute gradients for the input, weight, and optional bias.
#[derive(Debug)]
struct ConvTranspose3dBackward<T: Float> {
/// Input tensor the forward pass was applied to.
input: Tensor<T>,
/// Weight tensor used in the forward pass.
weight: Tensor<T>,
/// Bias tensor used in the forward pass, if the layer has one.
bias: Option<Tensor<T>>,
/// Number of input channels of the layer.
in_channels: usize,
/// Number of output channels of the layer.
out_channels: usize,
/// Kernel extent as `(depth, height, width)`.
kernel_size: (usize, usize, usize),
/// Stride as `(depth, height, width)`.
stride: (usize, usize, usize),
/// Padding as `(depth, height, width)`.
padding: (usize, usize, usize),
/// Output padding of the layer; stored but not read by `backward`.
_output_padding: (usize, usize, usize),
/// Output depth computed in the forward pass (including output padding).
d_out: usize,
/// Output height computed in the forward pass (including output padding).
h_out: usize,
/// Output width computed in the forward pass (including output padding).
w_out: usize,
}
impl<T: Float> GradFn<T> for ConvTranspose3dBackward<T> {
    /// Computes the gradients for `[input, weight]` and, when the layer has a
    /// bias, `bias`. An entry is `None` when the matching tensor does not
    /// require grad.
    ///
    /// * Input gradient: the upstream gradient is unfolded with `im2col_3d`
    ///   (layer kernel, stride, and padding) and multiplied by the weight
    ///   viewed as `[in_channels, out_channels * kd * kh * kw]`.
    /// * Weight gradient: direct correlation of the input with the upstream
    ///   gradient at output position `i * stride + tap - padding` per axis.
    /// * Bias gradient: upstream gradient summed per output channel.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let go_data = grad_output.data_vec()?;
        let batch = self.input.shape()[0];
        let d_in = self.input.shape()[2];
        let h_in = self.input.shape()[3];
        let w_in = self.input.shape()[4];
        let (kd, kh, kw) = self.kernel_size;
        let (sd, sh, sw) = self.stride;
        let (pd, ph, pw) = self.padding;
        let (c_in, c_out) = (self.in_channels, self.out_channels);
        let (d_out, h_out, w_out) = (self.d_out, self.h_out, self.w_out);
        let spatial_out = d_out * h_out * w_out;
        let kernel_volume = kd * kh * kw;
        let zero = <T as num_traits::Zero>::zero();

        let grad_input = if self.input.requires_grad() {
            let weight_data = self.weight.data_vec()?;
            let col_rows = c_out * kernel_volume;
            let weight_2d = Tensor::from_storage(
                TensorStorage::cpu(weight_data),
                vec![c_in, col_rows],
                false,
            )?;
            let (go_cols, _go_col_rows, go_col_cols) = im2col_3d(
                &go_data, batch, c_out, d_out, h_out, w_out, kd, kh, kw, sd, sh, sw, pd, ph, pw,
            );

            let cols_len = col_rows * go_col_cols;
            let sample_len = c_in * d_in * h_in * w_in;
            let mut gi = vec![zero; batch * sample_len];
            for b in 0..batch {
                let go_cols_b = Tensor::from_storage(
                    TensorStorage::cpu(go_cols[b * cols_len..(b + 1) * cols_len].to_vec()),
                    vec![col_rows, go_col_cols],
                    false,
                )?;
                let gi_b = mm(&weight_2d, &go_cols_b)?;
                let gi_data = gi_b.data()?;
                gi[b * sample_len..(b + 1) * sample_len].copy_from_slice(&gi_data[..sample_len]);
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gi),
                self.input.shape().to_vec(),
                false,
            )?)
        } else {
            None
        };

        let grad_weight = if self.weight.requires_grad() {
            let mut gw = vec![zero; c_in * c_out * kernel_volume];
            let input_data = self.input.data_vec()?;
            let in_plane = h_in * w_in;
            let spatial_in = d_in * in_plane;
            let out_plane = h_out * w_out;

            // Output coordinate touched by input position `pos` through kernel
            // offset `tap`, or `None` when it falls outside the cropped output.
            let hit = |pos: usize, stride: usize, tap: usize, pad: usize, extent: usize| {
                (pos * stride + tap)
                    .checked_sub(pad)
                    .filter(|&coord| coord < extent)
            };

            for b in 0..batch {
                for ci in 0..c_in {
                    let in_base = (b * c_in + ci) * spatial_in;
                    for co in 0..c_out {
                        let go_base = (b * c_out + co) * spatial_out;
                        for dd in 0..kd {
                            for dh in 0..kh {
                                for dw in 0..kw {
                                    let mut acc = zero;
                                    for id in 0..d_in {
                                        let Some(od) = hit(id, sd, dd, pd, d_out) else {
                                            continue;
                                        };
                                        for ih in 0..h_in {
                                            let Some(oh) = hit(ih, sh, dh, ph, h_out) else {
                                                continue;
                                            };
                                            for iw in 0..w_in {
                                                let Some(ow) = hit(iw, sw, dw, pw, w_out) else {
                                                    continue;
                                                };
                                                let in_idx =
                                                    in_base + id * in_plane + ih * w_in + iw;
                                                let go_idx =
                                                    go_base + od * out_plane + oh * w_out + ow;
                                                acc += input_data[in_idx] * go_data[go_idx];
                                            }
                                        }
                                    }
                                    gw[(ci * c_out + co) * kernel_volume
                                        + (dd * kh + dh) * kw
                                        + dw] += acc;
                                }
                            }
                        }
                    }
                }
            }
            Some(Tensor::from_storage(
                TensorStorage::cpu(gw),
                vec![c_in, c_out, kd, kh, kw],
                false,
            )?)
        } else {
            None
        };

        let grad_bias = match &self.bias {
            Some(b) if b.requires_grad() => {
                let mut sums = vec![zero; c_out];
                for sample in 0..batch {
                    for (c, slot) in sums.iter_mut().enumerate() {
                        let start = (sample * c_out + c) * spatial_out;
                        for &g in &go_data[start..start + spatial_out] {
                            *slot += g;
                        }
                    }
                }
                Some(Tensor::from_storage(
                    TensorStorage::cpu(sums),
                    vec![c_out],
                    false,
                )?)
            }
            _ => None,
        };

        let mut grads = Vec::with_capacity(3);
        grads.push(grad_input);
        grads.push(grad_weight);
        if self.bias.is_some() {
            grads.push(grad_bias);
        }
        Ok(grads)
    }

    /// Input, weight, then bias when present; matches the order of `backward`.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let mut tensors = Vec::with_capacity(3);
        tensors.push(&self.input);
        tensors.push(&self.weight);
        tensors.extend(self.bias.as_ref());
        tensors
    }

    /// Static identifier reported for this backward node.
    fn name(&self) -> &'static str {
        "ConvTranspose3dBackward"
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the convolution layers defined in this file and for the
    //! `im2col` / `col2im` helpers.
    //!
    //! Fixtures are built through small helpers so that buffer sizes are always
    //! derived from the shape they belong to, and so that hand-built layers
    //! take their channel counts and kernel size from the weight shape instead
    //! of repeating them.

    use super::*;
    use crate::module::Module;

    // ------------------------------------------------------------------
    // Tensor fixtures
    // ------------------------------------------------------------------

    /// Builds a CPU `f32` tensor from `data` with the given `shape`.
    fn tensor(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
        Tensor::from_storage(
            TensorStorage::cpu(data.to_vec()),
            shape.to_vec(),
            requires_grad,
        )
        .unwrap()
    }

    /// Tensor that does not require gradients.
    fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        tensor(data, shape, false)
    }

    /// Tensor that requires gradients.
    fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        tensor(data, shape, true)
    }

    /// Buffer holding `value` once per element of `shape`.
    fn filled(value: f32, shape: &[usize]) -> Vec<f32> {
        let numel: usize = shape.iter().product();
        vec![value; numel]
    }

    /// All-zero tensor of the given shape (no gradient tracking).
    fn zeros(shape: &[usize]) -> Tensor<f32> {
        t(&filled(0.0, shape), shape)
    }

    /// All-one tensor of the given shape (no gradient tracking).
    fn ones(shape: &[usize]) -> Tensor<f32> {
        t(&filled(1.0, shape), shape)
    }

    /// All-one tensor of the given shape that requires gradients.
    fn ones_leaf(shape: &[usize]) -> Tensor<f32> {
        leaf(&filled(1.0, shape), shape)
    }

    // ------------------------------------------------------------------
    // Layer fixtures
    // ------------------------------------------------------------------

    /// Parameter built from a slice; panics if the slice/shape pair is rejected.
    fn param(data: &[f32], shape: &[usize]) -> Parameter<f32> {
        Parameter::from_slice(data, shape).unwrap()
    }

    /// Optional 1-D bias parameter with one entry per provided value.
    fn bias_param(bias: Option<&[f32]>) -> Option<Parameter<f32>> {
        bias.map(|values| param(values, &[values.len()]))
    }

    /// `Conv1d` in eval mode with stride 1 and no padding.
    /// `weight_shape` is `[out_channels, in_channels, kernel]`.
    fn build_conv1d(weight: &[f32], weight_shape: [usize; 3], bias: Option<&[f32]>) -> Conv1d<f32> {
        let [out_channels, in_channels, kernel_size] = weight_shape;
        Conv1d {
            weight: param(weight, &weight_shape),
            bias: bias_param(bias),
            in_channels,
            out_channels,
            kernel_size,
            stride: 1,
            padding: 0,
            training: false,
        }
    }

    /// `Conv2d` in eval mode with stride `(1, 1)`.
    /// `weight_shape` is `[out_channels, in_channels, kernel_h, kernel_w]`.
    fn build_conv2d(
        weight: &[f32],
        weight_shape: [usize; 4],
        bias: Option<&[f32]>,
        padding: (usize, usize),
    ) -> Conv2d<f32> {
        let [out_channels, in_channels, kh, kw] = weight_shape;
        Conv2d {
            weight: param(weight, &weight_shape),
            bias: bias_param(bias),
            in_channels,
            out_channels,
            kernel_size: (kh, kw),
            stride: (1, 1),
            padding,
            training: false,
        }
    }

    /// `Conv3d` in eval mode with stride `(1, 1, 1)` and no padding.
    /// `weight_shape` is `[out_channels, in_channels, kernel_d, kernel_h, kernel_w]`.
    fn build_conv3d(weight: &[f32], weight_shape: [usize; 5], bias: Option<&[f32]>) -> Conv3d<f32> {
        let [out_channels, in_channels, kd, kh, kw] = weight_shape;
        Conv3d {
            weight: param(weight, &weight_shape),
            bias: bias_param(bias),
            in_channels,
            out_channels,
            kernel_size: (kd, kh, kw),
            stride: (1, 1, 1),
            padding: (0, 0, 0),
            training: false,
        }
    }

    /// `ConvTranspose1d` in eval mode without padding or output padding.
    /// `weight_shape` is `[in_channels, out_channels, kernel]`.
    fn build_conv_transpose1d(
        weight: &[f32],
        weight_shape: [usize; 3],
        bias: Option<&[f32]>,
        stride: usize,
    ) -> ConvTranspose1d<f32> {
        let [in_channels, out_channels, kernel_size] = weight_shape;
        ConvTranspose1d {
            weight: param(weight, &weight_shape),
            bias: bias_param(bias),
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding: 0,
            output_padding: 0,
            training: false,
        }
    }

    /// `ConvTranspose2d` in eval mode without padding or output padding.
    /// `weight_shape` is `[in_channels, out_channels, kernel_h, kernel_w]`.
    fn build_conv_transpose2d(
        weight: &[f32],
        weight_shape: [usize; 4],
        bias: Option<&[f32]>,
        stride: (usize, usize),
    ) -> ConvTranspose2d<f32> {
        let [in_channels, out_channels, kh, kw] = weight_shape;
        ConvTranspose2d {
            weight: param(weight, &weight_shape),
            bias: bias_param(bias),
            in_channels,
            out_channels,
            kernel_size: (kh, kw),
            stride,
            padding: (0, 0),
            output_padding: (0, 0),
            training: false,
        }
    }

    /// `ConvTranspose3d` in eval mode with stride `(1, 1, 1)` and no padding
    /// or output padding.
    /// `weight_shape` is `[in_channels, out_channels, kernel_d, kernel_h, kernel_w]`.
    fn build_conv_transpose3d(
        weight: &[f32],
        weight_shape: [usize; 5],
        bias: Option<&[f32]>,
    ) -> ConvTranspose3d<f32> {
        let [in_channels, out_channels, kd, kh, kw] = weight_shape;
        ConvTranspose3d {
            weight: param(weight, &weight_shape),
            bias: bias_param(bias),
            in_channels,
            out_channels,
            kernel_size: (kd, kh, kw),
            stride: (1, 1, 1),
            padding: (0, 0, 0),
            output_padding: (0, 0, 0),
            training: false,
        }
    }

    // ------------------------------------------------------------------
    // Assertion helpers
    // ------------------------------------------------------------------

    /// Asserts that both slices have the same length and that every pair of
    /// elements differs by strictly less than `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {} vs {}",
            actual.len(),
            expected.len()
        );
        for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
            let diff = (a - e).abs();
            assert!(diff < tol, "index {i}: actual={a} expected={e} (diff {diff})");
        }
    }

    /// Checks the backward node attached to `output`.
    ///
    /// Verifies the output shape and the node name, feeds an all-ones upstream
    /// gradient of `output_shape` through `backward`, and asserts that the
    /// first three returned gradients (input, weight, bias) are present and
    /// have the shapes listed in `grad_shapes`.
    fn assert_backward_shapes(
        output: &Tensor<f32>,
        output_shape: &[usize],
        node_name: &str,
        grad_shapes: [&[usize]; 3],
    ) {
        assert_eq!(output.shape(), output_shape);
        let grad_fn = output
            .grad_fn()
            .expect("forward on a grad-requiring input should attach a grad_fn");
        assert_eq!(grad_fn.name(), node_name);

        let grads = grad_fn.backward(&ones(output_shape)).unwrap();
        for (slot, expected_shape) in grad_shapes.iter().enumerate() {
            let grad = grads[slot]
                .as_ref()
                .unwrap_or_else(|| panic!("gradient slot {slot} is missing"));
            assert_eq!(grad.shape(), *expected_shape);
        }
    }

    // ------------------------------------------------------------------
    // Conv2d
    // ------------------------------------------------------------------

    #[test]
    fn test_output_shape_no_padding() {
        let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 5, 5])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 3, 3]);
    }

    #[test]
    fn test_output_shape_with_padding() {
        let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (1, 1), true).unwrap();
        let output = conv.forward(&zeros(&[2, 3, 8, 8])).unwrap();
        assert_eq!(output.shape(), &[2, 16, 8, 8]);
    }

    #[test]
    fn test_output_shape_with_stride() {
        let conv = Conv2d::<f32>::new(1, 4, (3, 3), (2, 2), (0, 0), false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 6, 6])).unwrap();
        assert_eq!(output.shape(), &[1, 4, 2, 2]);
    }

    #[test]
    fn test_1x1_conv_equals_linear() {
        // Weight rows are (out_channel -> [w_in0, w_in1]): [1,2], [3,4], [5,6].
        let conv = build_conv2d(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [3, 2, 1, 1], None, (0, 0));
        let input = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 2, 2, 2]);

        let output = conv.forward(&input).unwrap();

        assert_eq!(output.shape(), &[1, 3, 2, 2]);
        #[rustfmt::skip]
        let expected = [
            11.0, 14.0, 17.0, 20.0,
            23.0, 30.0, 37.0, 44.0,
            35.0, 46.0, 57.0, 68.0,
        ];
        assert_close(output.data().unwrap(), &expected, 1e-5);
    }

    #[test]
    fn test_bias_addition() {
        let conv = build_conv2d(&[1.0], [1, 1, 1, 1], Some(&[10.0]), (0, 0));
        let input = t(&[2.0, 3.0, 4.0, 5.0], &[1, 1, 2, 2]);
        let output = conv.forward(&input).unwrap();
        assert_close(output.data().unwrap(), &[12.0, 13.0, 14.0, 15.0], 1e-5);
    }

    #[test]
    fn test_backward_produces_correct_shapes() {
        let weight_shape = [2, 1, 3, 3];
        let conv = build_conv2d(
            &filled(1.0, &weight_shape),
            weight_shape,
            Some(&[0.0, 0.0]),
            (0, 0),
        );
        let input = ones_leaf(&[1, 1, 5, 5]);
        let output = conv.forward(&input).unwrap();
        assert_backward_shapes(
            &output,
            &[1, 2, 3, 3],
            "Conv2dBackward",
            [&[1, 1, 5, 5], &[2, 1, 3, 3], &[2]],
        );
    }

    #[test]
    fn test_parameter_count_with_bias() {
        let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (0, 0), true).unwrap();
        // 3 * 16 * 3 * 3 weights + 16 biases.
        assert_eq!(conv.num_parameters(), 448);
        assert_eq!(conv.parameters().len(), 2);
    }

    #[test]
    fn test_parameter_count_without_bias() {
        let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (0, 0), false).unwrap();
        // 3 * 16 * 3 * 3 weights only.
        assert_eq!(conv.num_parameters(), 432);
        assert_eq!(conv.parameters().len(), 1);
    }

    #[test]
    fn test_named_parameters() {
        let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), true).unwrap();
        let named = conv.named_parameters();
        assert_eq!(named.len(), 2);
        assert_eq!(named[0].0, "weight");
        assert_eq!(named[1].0, "bias");
    }

    #[test]
    fn test_train_eval() {
        let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
        assert!(conv.is_training());
        conv.eval();
        assert!(!conv.is_training());
        conv.train();
        assert!(conv.is_training());
    }

    #[test]
    fn test_invalid_input_ndim() {
        let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
        let input = t(&[1.0, 2.0, 3.0], &[3]);
        assert!(conv.forward(&input).is_err());
    }

    #[test]
    fn test_channel_mismatch() {
        let conv = Conv2d::<f32>::new(3, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
        assert!(conv.forward(&zeros(&[1, 1, 5, 5])).is_err());
    }

    #[test]
    fn test_zero_channels_rejected() {
        assert!(Conv2d::<f32>::new(0, 16, (3, 3), (1, 1), (0, 0), false).is_err());
        assert!(Conv2d::<f32>::new(3, 0, (3, 3), (1, 1), (0, 0), false).is_err());
    }

    #[test]
    fn test_zero_kernel_rejected() {
        assert!(Conv2d::<f32>::new(1, 1, (0, 3), (1, 1), (0, 0), false).is_err());
    }

    #[test]
    fn test_zero_stride_rejected() {
        assert!(Conv2d::<f32>::new(1, 1, (3, 3), (0, 1), (0, 0), false).is_err());
    }

    // ------------------------------------------------------------------
    // im2col / col2im
    // ------------------------------------------------------------------

    #[test]
    fn test_im2col_basic() {
        #[rustfmt::skip]
        let input: [f32; 9] = [
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
            7.0, 8.0, 9.0,
        ];
        // batch=1, channels=1, 3x3 image, 2x2 kernel, stride 1, no padding.
        let (cols, rows, n_cols) = im2col(&input, 1, 1, 3, 3, 2, 2, 1, 1, 0, 0);
        assert_eq!(rows, 4);
        assert_eq!(n_cols, 4);
        // One row per kernel offset, one column per output position.
        #[rustfmt::skip]
        let expected = [
            1.0, 2.0, 4.0, 5.0,
            2.0, 3.0, 5.0, 6.0,
            4.0, 5.0, 7.0, 8.0,
            5.0, 6.0, 8.0, 9.0,
        ];
        assert_close(&cols, &expected, 1e-7);
    }

    #[test]
    fn test_col2im_roundtrip_no_overlap() {
        #[rustfmt::skip]
        let input: [f32; 16] = [
            1.0,  2.0,  3.0,  4.0,
            5.0,  6.0,  7.0,  8.0,
            9.0,  10.0, 11.0, 12.0,
            13.0, 14.0, 15.0, 16.0,
        ];
        // A 2x2 kernel with stride 2 visits every pixel exactly once, so
        // scattering the columns back must reproduce the input.
        let (cols, _rows, _n_cols) = im2col(&input, 1, 1, 4, 4, 2, 2, 2, 2, 0, 0);
        let recovered = col2im(&cols, 1, 1, 4, 4, 2, 2, 2, 2, 0, 0, 2, 2);
        assert_close(&recovered, &input, 1e-7);
    }

    #[test]
    fn test_3x3_conv_forward() {
        #[rustfmt::skip]
        let input_data: [f32; 9] = [
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
            7.0, 8.0, 9.0,
        ];
        #[rustfmt::skip]
        let weight_data: [f32; 9] = [
            1.0, 0.0, -1.0,
            1.0, 0.0, -1.0,
            1.0, 0.0, -1.0,
        ];
        let conv = build_conv2d(&weight_data, [1, 1, 3, 3], None, (0, 0));
        let output = conv.forward(&t(&input_data, &[1, 1, 3, 3])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 1, 1]);
        // (1 + 4 + 7) - (3 + 6 + 9)
        assert_close(output.data().unwrap(), &[-6.0], 1e-5);
    }

    #[test]
    fn test_padding_preserves_spatial_size() {
        // Kernel that is 1 at the centre and 0 elsewhere: with padding 1 the
        // layer must return its input unchanged.
        let mut identity_kernel = [0.0f32; 9];
        identity_kernel[4] = 1.0;
        let conv = build_conv2d(&identity_kernel, [1, 1, 3, 3], None, (1, 1));

        let input_data: [f32; 9] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let output = conv.forward(&t(&input_data, &[1, 1, 3, 3])).unwrap();

        assert_eq!(output.shape(), &[1, 1, 3, 3]);
        assert_close(output.data().unwrap(), &input_data, 1e-5);
    }

    // ------------------------------------------------------------------
    // Conv1d
    // ------------------------------------------------------------------

    #[test]
    fn test_conv1d_output_shape_no_padding() {
        let conv = Conv1d::<f32>::new(1, 4, 3, 1, 0, false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 10])).unwrap();
        assert_eq!(output.shape(), &[1, 4, 8]);
    }

    #[test]
    fn test_conv1d_output_shape_with_padding() {
        let conv = Conv1d::<f32>::new(3, 8, 3, 1, 1, true).unwrap();
        let output = conv.forward(&zeros(&[2, 3, 16])).unwrap();
        assert_eq!(output.shape(), &[2, 8, 16]);
    }

    #[test]
    fn test_conv1d_output_shape_with_stride() {
        let conv = Conv1d::<f32>::new(1, 2, 3, 2, 0, false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 10])).unwrap();
        assert_eq!(output.shape(), &[1, 2, 4]);
    }

    #[test]
    fn test_conv1d_1x1_kernel_correctness() {
        let conv = build_conv1d(&[3.0, 5.0], [2, 1, 1], None);
        let output = conv.forward(&t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4])).unwrap();
        assert_eq!(output.shape(), &[1, 2, 4]);
        assert_close(
            output.data().unwrap(),
            &[3.0, 6.0, 9.0, 12.0, 5.0, 10.0, 15.0, 20.0],
            1e-5,
        );
    }

    #[test]
    fn test_conv1d_3_kernel_forward() {
        let conv = build_conv1d(&[1.0, 0.0, -1.0], [1, 1, 3], None);
        let output = conv
            .forward(&t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]))
            .unwrap();
        assert_eq!(output.shape(), &[1, 1, 3]);
        assert_close(output.data().unwrap(), &[-2.0, -2.0, -2.0], 1e-5);
    }

    #[test]
    fn test_conv1d_bias() {
        let conv = build_conv1d(&[1.0], [1, 1, 1], Some(&[10.0]));
        let output = conv.forward(&t(&[2.0, 3.0, 4.0], &[1, 1, 3])).unwrap();
        assert_close(output.data().unwrap(), &[12.0, 13.0, 14.0], 1e-5);
    }

    #[test]
    fn test_conv1d_invalid_ndim() {
        let conv = Conv1d::<f32>::new(1, 1, 3, 1, 0, false).unwrap();
        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
        assert!(conv.forward(&input).is_err());
    }

    #[test]
    fn test_conv1d_channel_mismatch() {
        let conv = Conv1d::<f32>::new(3, 1, 3, 1, 0, false).unwrap();
        assert!(conv.forward(&zeros(&[1, 1, 10])).is_err());
    }

    #[test]
    fn test_conv1d_zero_channels_rejected() {
        assert!(Conv1d::<f32>::new(0, 4, 3, 1, 0, false).is_err());
        assert!(Conv1d::<f32>::new(1, 0, 3, 1, 0, false).is_err());
    }

    #[test]
    fn test_conv1d_zero_kernel_rejected() {
        assert!(Conv1d::<f32>::new(1, 1, 0, 1, 0, false).is_err());
    }

    #[test]
    fn test_conv1d_zero_stride_rejected() {
        assert!(Conv1d::<f32>::new(1, 1, 3, 0, 0, false).is_err());
    }

    #[test]
    fn test_conv1d_parameter_count() {
        let conv = Conv1d::<f32>::new(3, 8, 5, 1, 0, true).unwrap();
        // 3 * 8 * 5 weights + 8 biases.
        assert_eq!(conv.num_parameters(), 128);
        assert_eq!(conv.parameters().len(), 2);
    }

    // ------------------------------------------------------------------
    // ConvTranspose2d
    // ------------------------------------------------------------------

    #[test]
    fn test_conv_transpose2d_output_shape_basic() {
        let conv =
            ConvTranspose2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 3, 3])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 5, 5]);
    }

    #[test]
    fn test_conv_transpose2d_output_shape_stride2() {
        let conv =
            ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (0, 0), (0, 0), false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 2, 2])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 5, 5]);
    }

    #[test]
    fn test_conv_transpose2d_output_shape_with_padding() {
        let conv =
            ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (1, 1), (0, 0), false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 3, 3])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 5, 5]);
    }

    #[test]
    fn test_conv_transpose2d_output_shape_with_output_padding() {
        let conv =
            ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 3, 3])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 6, 6]);
    }

    #[test]
    fn test_conv_transpose2d_stride2_upsamples() {
        let conv =
            ConvTranspose2d::<f32>::new(1, 1, (2, 2), (2, 2), (0, 0), (0, 0), false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 4, 4])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 8, 8]);
    }

    #[test]
    fn test_conv_transpose2d_stride2_upsamples_multichannel() {
        let conv =
            ConvTranspose2d::<f32>::new(8, 16, (2, 2), (2, 2), (0, 0), (0, 0), true).unwrap();
        let output = conv.forward(&zeros(&[2, 8, 4, 4])).unwrap();
        assert_eq!(output.shape(), &[2, 16, 8, 8]);
    }

    #[test]
    fn test_conv_transpose2d_1x1_kernel() {
        let conv = build_conv_transpose2d(&[3.0, 7.0], [1, 2, 1, 1], None, (1, 1));
        let output = conv
            .forward(&t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]))
            .unwrap();
        assert_eq!(output.shape(), &[1, 2, 2, 2]);
        assert_close(
            output.data().unwrap(),
            &[3.0, 6.0, 9.0, 12.0, 7.0, 14.0, 21.0, 28.0],
            1e-5,
        );
    }

    #[test]
    fn test_conv_transpose2d_stride2_correctness() {
        let weight_shape = [1, 1, 2, 2];
        let conv =
            build_conv_transpose2d(&filled(1.0, &weight_shape), weight_shape, None, (2, 2));
        let output = conv
            .forward(&t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]))
            .unwrap();
        assert_eq!(output.shape(), &[1, 1, 4, 4]);
        // Each input pixel is replicated into its own 2x2 output tile.
        #[rustfmt::skip]
        let expected = [
            1.0, 1.0, 2.0, 2.0,
            1.0, 1.0, 2.0, 2.0,
            3.0, 3.0, 4.0, 4.0,
            3.0, 3.0, 4.0, 4.0,
        ];
        assert_close(output.data().unwrap(), &expected, 1e-5);
    }

    #[test]
    fn test_conv_transpose2d_bias() {
        let conv = build_conv_transpose2d(&[1.0], [1, 1, 1, 1], Some(&[5.0]), (1, 1));
        let output = conv
            .forward(&t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]))
            .unwrap();
        assert_close(output.data().unwrap(), &[6.0, 7.0, 8.0, 9.0], 1e-5);
    }

    #[test]
    fn test_conv_transpose2d_invalid_ndim() {
        let conv =
            ConvTranspose2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
        let input = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
        assert!(conv.forward(&input).is_err());
    }

    #[test]
    fn test_conv_transpose2d_channel_mismatch() {
        let conv =
            ConvTranspose2d::<f32>::new(3, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
        assert!(conv.forward(&zeros(&[1, 1, 5, 5])).is_err());
    }

    #[test]
    fn test_conv_transpose2d_zero_channels_rejected() {
        assert!(ConvTranspose2d::<f32>::new(0, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).is_err());
        assert!(ConvTranspose2d::<f32>::new(1, 0, (3, 3), (1, 1), (0, 0), (0, 0), false).is_err());
    }

    #[test]
    fn test_conv_transpose2d_output_padding_too_large() {
        assert!(ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (0, 0), (2, 2), false).is_err());
    }

    #[test]
    fn test_conv_transpose2d_parameter_count() {
        let conv =
            ConvTranspose2d::<f32>::new(8, 16, (3, 3), (2, 2), (1, 1), (0, 0), true).unwrap();
        // 8 * 16 * 3 * 3 weights + 16 biases.
        assert_eq!(conv.num_parameters(), 1168);
        assert_eq!(conv.parameters().len(), 2);
    }

    // ------------------------------------------------------------------
    // Conv3d
    // ------------------------------------------------------------------

    #[test]
    fn test_conv3d_output_shape_no_padding() {
        let conv = Conv3d::<f32>::new(1, 4, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 5, 5, 5])).unwrap();
        assert_eq!(output.shape(), &[1, 4, 3, 3, 3]);
    }

    #[test]
    fn test_conv3d_output_shape_with_padding() {
        let conv = Conv3d::<f32>::new(3, 16, (3, 3, 3), (1, 1, 1), (1, 1, 1), true).unwrap();
        let output = conv.forward(&zeros(&[2, 3, 8, 8, 8])).unwrap();
        assert_eq!(output.shape(), &[2, 16, 8, 8, 8]);
    }

    #[test]
    fn test_conv3d_output_shape_with_stride() {
        let conv = Conv3d::<f32>::new(1, 4, (3, 3, 3), (2, 2, 2), (0, 0, 0), false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 6, 6, 6])).unwrap();
        assert_eq!(output.shape(), &[1, 4, 2, 2, 2]);
    }

    #[test]
    fn test_conv3d_1x1x1_kernel_correctness() {
        let conv = build_conv3d(&[3.0, 5.0], [2, 1, 1, 1, 1], None);
        let output = conv.forward(&t(&[1.0, 2.0], &[1, 1, 2, 1, 1])).unwrap();
        assert_eq!(output.shape(), &[1, 2, 2, 1, 1]);
        assert_close(output.data().unwrap(), &[3.0, 6.0, 5.0, 10.0], 1e-5);
    }

    #[test]
    fn test_conv3d_3x3x3_kernel_forward() {
        let weight_shape = [1, 1, 3, 3, 3];
        let conv = build_conv3d(&filled(1.0, &weight_shape), weight_shape, None);
        let output = conv.forward(&ones(&[1, 1, 3, 3, 3])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 1, 1, 1]);
        // 27 products of 1 * 1.
        assert_close(output.data().unwrap(), &[27.0], 1e-5);
    }

    #[test]
    fn test_conv3d_bias() {
        let conv = build_conv3d(&[1.0], [1, 1, 1, 1, 1], Some(&[10.0]));
        let output = conv.forward(&t(&[2.0, 3.0], &[1, 1, 2, 1, 1])).unwrap();
        assert_close(output.data().unwrap(), &[12.0, 13.0], 1e-5);
    }

    #[test]
    fn test_conv3d_backward_produces_correct_shapes() {
        let weight_shape = [2, 1, 3, 3, 3];
        let conv = build_conv3d(&filled(1.0, &weight_shape), weight_shape, Some(&[0.0, 0.0]));
        let input = ones_leaf(&[1, 1, 5, 5, 5]);
        let output = conv.forward(&input).unwrap();
        assert_backward_shapes(
            &output,
            &[1, 2, 3, 3, 3],
            "Conv3dBackward",
            [&[1, 1, 5, 5, 5], &[2, 1, 3, 3, 3], &[2]],
        );
    }

    #[test]
    fn test_conv3d_invalid_ndim() {
        let conv = Conv3d::<f32>::new(1, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).unwrap();
        assert!(conv.forward(&zeros(&[1, 1, 5, 5])).is_err());
    }

    #[test]
    fn test_conv3d_channel_mismatch() {
        let conv = Conv3d::<f32>::new(3, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).unwrap();
        assert!(conv.forward(&zeros(&[1, 1, 5, 5, 5])).is_err());
    }

    #[test]
    fn test_conv3d_zero_channels_rejected() {
        assert!(Conv3d::<f32>::new(0, 16, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).is_err());
        assert!(Conv3d::<f32>::new(3, 0, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).is_err());
    }

    #[test]
    fn test_conv3d_zero_kernel_rejected() {
        assert!(Conv3d::<f32>::new(1, 1, (0, 3, 3), (1, 1, 1), (0, 0, 0), false).is_err());
    }

    #[test]
    fn test_conv3d_zero_stride_rejected() {
        assert!(Conv3d::<f32>::new(1, 1, (3, 3, 3), (0, 1, 1), (0, 0, 0), false).is_err());
    }

    #[test]
    fn test_conv3d_parameter_count() {
        let conv = Conv3d::<f32>::new(3, 8, (3, 3, 3), (1, 1, 1), (0, 0, 0), true).unwrap();
        // 3 * 8 * 27 weights + 8 biases.
        assert_eq!(conv.num_parameters(), 656);
        assert_eq!(conv.parameters().len(), 2);
    }

    #[test]
    fn test_conv3d_named_parameters() {
        let conv = Conv3d::<f32>::new(1, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), true).unwrap();
        let named = conv.named_parameters();
        assert_eq!(named.len(), 2);
        assert_eq!(named[0].0, "weight");
        assert_eq!(named[1].0, "bias");
    }

    // ------------------------------------------------------------------
    // ConvTranspose1d
    // ------------------------------------------------------------------

    #[test]
    fn test_conv_transpose1d_output_shape_basic() {
        let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 1, 0, 0, false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 5])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 7]);
    }

    #[test]
    fn test_conv_transpose1d_output_shape_stride2() {
        let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 2, 0, 0, false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 3])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 7]);
    }

    #[test]
    fn test_conv_transpose1d_output_shape_with_padding() {
        let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 2, 1, 0, false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 5])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 9]);
    }

    #[test]
    fn test_conv_transpose1d_output_shape_with_output_padding() {
        let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 2, 1, 1, false).unwrap();
        let output = conv.forward(&zeros(&[1, 1, 5])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 10]);
    }

    #[test]
    fn test_conv_transpose1d_1x1_kernel() {
        let conv = build_conv_transpose1d(&[3.0, 7.0], [1, 2, 1], None, 1);
        let output = conv.forward(&t(&[1.0, 2.0, 3.0], &[1, 1, 3])).unwrap();
        assert_eq!(output.shape(), &[1, 2, 3]);
        assert_close(
            output.data().unwrap(),
            &[3.0, 6.0, 9.0, 7.0, 14.0, 21.0],
            1e-5,
        );
    }

    #[test]
    fn test_conv_transpose1d_stride2_correctness() {
        let conv = build_conv_transpose1d(&[1.0, 1.0], [1, 1, 2], None, 2);
        let output = conv.forward(&t(&[1.0, 2.0], &[1, 1, 2])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 4]);
        assert_close(output.data().unwrap(), &[1.0, 1.0, 2.0, 2.0], 1e-5);
    }

    #[test]
    fn test_conv_transpose1d_bias() {
        let conv = build_conv_transpose1d(&[1.0], [1, 1, 1], Some(&[5.0]), 1);
        let output = conv.forward(&t(&[1.0, 2.0, 3.0], &[1, 1, 3])).unwrap();
        assert_close(output.data().unwrap(), &[6.0, 7.0, 8.0], 1e-5);
    }

    #[test]
    fn test_conv_transpose1d_backward_produces_gradients() {
        let conv = build_conv_transpose1d(&[1.0, 1.0, 1.0], [1, 1, 3], Some(&[0.0]), 1);
        let input = leaf(&[1.0, 2.0, 3.0], &[1, 1, 3]);
        let output = conv.forward(&input).unwrap();
        assert_backward_shapes(
            &output,
            &[1, 1, 5],
            "ConvTranspose1dBackward",
            [&[1, 1, 3], &[1, 1, 3], &[1]],
        );
    }

    #[test]
    fn test_conv_transpose1d_invalid_ndim() {
        let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 1, 0, 0, false).unwrap();
        assert!(conv.forward(&zeros(&[1, 1, 2, 2])).is_err());
    }

    #[test]
    fn test_conv_transpose1d_channel_mismatch() {
        let conv = ConvTranspose1d::<f32>::new(3, 1, 3, 1, 0, 0, false).unwrap();
        assert!(conv.forward(&zeros(&[1, 1, 10])).is_err());
    }

    #[test]
    fn test_conv_transpose1d_zero_channels_rejected() {
        assert!(ConvTranspose1d::<f32>::new(0, 1, 3, 1, 0, 0, false).is_err());
        assert!(ConvTranspose1d::<f32>::new(1, 0, 3, 1, 0, 0, false).is_err());
    }

    #[test]
    fn test_conv_transpose1d_output_padding_too_large() {
        assert!(ConvTranspose1d::<f32>::new(1, 1, 3, 2, 0, 2, false).is_err());
    }

    #[test]
    fn test_conv_transpose1d_parameter_count() {
        let conv = ConvTranspose1d::<f32>::new(8, 16, 5, 2, 1, 0, true).unwrap();
        // 8 * 16 * 5 weights + 16 biases.
        assert_eq!(conv.num_parameters(), 656);
        assert_eq!(conv.parameters().len(), 2);
    }

    // ------------------------------------------------------------------
    // ConvTranspose3d
    // ------------------------------------------------------------------

    #[test]
    fn test_conv_transpose3d_output_shape_basic() {
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
                .unwrap();
        let output = conv.forward(&zeros(&[1, 1, 3, 3, 3])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 5, 5, 5]);
    }

    #[test]
    fn test_conv_transpose3d_output_shape_stride2() {
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (2, 2, 2), (0, 0, 0), (0, 0, 0), false)
                .unwrap();
        let output = conv.forward(&zeros(&[1, 1, 2, 2, 2])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 5, 5, 5]);
    }

    #[test]
    fn test_conv_transpose3d_output_shape_with_padding() {
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0), false)
                .unwrap();
        let output = conv.forward(&zeros(&[1, 1, 3, 3, 3])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 5, 5, 5]);
    }

    #[test]
    fn test_conv_transpose3d_output_shape_with_output_padding() {
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1), false)
                .unwrap();
        let output = conv.forward(&zeros(&[1, 1, 3, 3, 3])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 6, 6, 6]);
    }

    #[test]
    fn test_conv_transpose3d_stride2_upsamples() {
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (2, 2, 2), (2, 2, 2), (0, 0, 0), (0, 0, 0), false)
                .unwrap();
        let output = conv.forward(&zeros(&[1, 1, 4, 4, 4])).unwrap();
        assert_eq!(output.shape(), &[1, 1, 8, 8, 8]);
    }

    #[test]
    fn test_conv_transpose3d_1x1x1_kernel() {
        let conv = build_conv_transpose3d(&[3.0, 7.0], [1, 2, 1, 1, 1], None);
        let output = conv.forward(&t(&[1.0, 2.0], &[1, 1, 2, 1, 1])).unwrap();
        assert_eq!(output.shape(), &[1, 2, 2, 1, 1]);
        assert_close(output.data().unwrap(), &[3.0, 6.0, 7.0, 14.0], 1e-5);
    }

    #[test]
    fn test_conv_transpose3d_bias() {
        let conv = build_conv_transpose3d(&[1.0], [1, 1, 1, 1, 1], Some(&[5.0]));
        let output = conv.forward(&t(&[1.0, 2.0], &[1, 1, 2, 1, 1])).unwrap();
        assert_close(output.data().unwrap(), &[6.0, 7.0], 1e-5);
    }

    #[test]
    fn test_conv_transpose3d_backward_produces_gradients() {
        let weight_shape = [1, 1, 2, 2, 2];
        let conv =
            build_conv_transpose3d(&filled(1.0, &weight_shape), weight_shape, Some(&[0.0]));
        let input = ones_leaf(&[1, 1, 2, 2, 2]);
        let output = conv.forward(&input).unwrap();
        assert_backward_shapes(
            &output,
            &[1, 1, 3, 3, 3],
            "ConvTranspose3dBackward",
            [&[1, 1, 2, 2, 2], &[1, 1, 2, 2, 2], &[1]],
        );
    }

    #[test]
    fn test_conv_transpose3d_invalid_ndim() {
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
                .unwrap();
        assert!(conv.forward(&zeros(&[1, 1, 5, 5])).is_err());
    }

    #[test]
    fn test_conv_transpose3d_channel_mismatch() {
        let conv =
            ConvTranspose3d::<f32>::new(3, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
                .unwrap();
        assert!(conv.forward(&zeros(&[1, 1, 5, 5, 5])).is_err());
    }

    #[test]
    fn test_conv_transpose3d_zero_channels_rejected() {
        assert!(
            ConvTranspose3d::<f32>::new(0, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
                .is_err()
        );
        assert!(
            ConvTranspose3d::<f32>::new(1, 0, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
                .is_err()
        );
    }

    #[test]
    fn test_conv_transpose3d_output_padding_too_large() {
        assert!(
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (2, 2, 2), (0, 0, 0), (2, 2, 2), false)
                .is_err()
        );
    }

    #[test]
    fn test_conv_transpose3d_parameter_count() {
        let conv =
            ConvTranspose3d::<f32>::new(8, 16, (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0), true)
                .unwrap();
        // 8 * 16 * 27 weights + 16 biases.
        assert_eq!(conv.num_parameters(), 3472);
        assert_eq!(conv.parameters().len(), 2);
    }
}