//! 2D convolutional layer with a naive direct-convolution reference implementation.
use crate::activations::Activation;
use crate::device::Device;
use crate::error::{NnlError, Result};
use crate::layers::{Layer, TrainingMode, WeightInit};
use crate::tensor::Tensor;
use std::fmt;
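/// A 2D convolutional layer operating on NCHW tensors
/// (`[batch, channels, height, width]`).
///
/// Weights are laid out as `[out_channels, in_channels, kernel_h, kernel_w]`.
/// The forward and backward passes use a naive direct-convolution loop over
/// host copies of the data; see `conv2d_forward`.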
#[derive(Debug)]
pub struct Conv2DLayer {
weights: Tensor,
bias: Option<Tensor>,
weight_grad: Tensor,
bias_grad: Option<Tensor>,
in_channels: usize,
out_channels: usize,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
dilation: (usize, usize),
activation: Activation,
use_bias: bool,
cached_input: Option<Tensor>,
cached_pre_activation: Option<Tensor>,
training: bool,
}
impl Conv2DLayer {
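    /// Creates a new Conv2D layer on an automatically selected device.
    ///
    /// See [`Self::new_on_device`] for the parameter validation rules.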
#[allow(clippy::too_many_arguments)]
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
dilation: (usize, usize),
activation: Activation,
use_bias: bool,
weight_init: WeightInit,
) -> Result<Self> {
let device = Device::auto_select()?;
Self::new_on_device(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
activation,
use_bias,
weight_init,
device,
)
}
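    /// Creates a new Conv2D layer on the given device.
    ///
    /// Channel counts must be positive, and `kernel_size`, `stride`, and
    /// `dilation` must be non-zero in both dimensions; otherwise a config
    /// error is returned. Weights are initialized with `weight_init` using
    /// `fan_in = in_channels * kH * kW` and `fan_out = out_channels * kH * kW`.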
pub fn new_on_device(
in_channels: usize,
out_channels: usize,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
dilation: (usize, usize),
activation: Activation,
use_bias: bool,
weight_init: WeightInit,
device: Device,
) -> Result<Self> {
if in_channels == 0 || out_channels == 0 {
return Err(NnlError::config("Channel counts must be positive"));
}
if kernel_size.0 == 0 || kernel_size.1 == 0 {
return Err(NnlError::config("Kernel size must be positive"));
}
if stride.0 == 0 || stride.1 == 0 {
return Err(NnlError::config("Stride must be positive"));
}
if dilation.0 == 0 || dilation.1 == 0 {
return Err(NnlError::config("Dilation must be positive"));
}
let weight_shape = [out_channels, in_channels, kernel_size.0, kernel_size.1];
let mut weights = Tensor::zeros_on_device(&weight_shape, device.clone())?;
let fan_in = in_channels * kernel_size.0 * kernel_size.1;
let fan_out = out_channels * kernel_size.0 * kernel_size.1;
weight_init.initialize(&mut weights, fan_in, fan_out)?;
let weight_grad = Tensor::zeros_on_device(&weight_shape, device.clone())?;
        let (bias, bias_grad) = if use_bias {
            // Biases start at zero; `zeros_on_device` already provides that,
            // so no separate initialization pass is needed.
            let bias_tensor = Tensor::zeros_on_device(&[out_channels], device.clone())?;
            let bias_grad_tensor = Tensor::zeros_on_device(&[out_channels], device)?;
(Some(bias_tensor), Some(bias_grad_tensor))
} else {
(None, None)
};
Ok(Self {
weights,
bias,
weight_grad,
bias_grad,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
activation,
use_bias,
cached_input: None,
cached_pre_activation: None,
training: true,
})
}
pub fn in_channels(&self) -> usize {
self.in_channels
}
pub fn out_channels(&self) -> usize {
self.out_channels
}
pub fn kernel_size(&self) -> (usize, usize) {
self.kernel_size
}
pub fn stride(&self) -> (usize, usize) {
self.stride
}
pub fn padding(&self) -> (usize, usize) {
self.padding
}
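    /// Computes the spatial output size using the standard convolution formula
    /// (integer division, i.e. floor):
    ///
    /// `out = (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1`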
fn calculate_output_size(&self, input_height: usize, input_width: usize) -> (usize, usize) {
let output_height =
(input_height + 2 * self.padding.0 - self.dilation.0 * (self.kernel_size.0 - 1) - 1)
/ self.stride.0
+ 1;
let output_width =
(input_width + 2 * self.padding.1 - self.dilation.1 * (self.kernel_size.1 - 1) - 1)
/ self.stride.1
+ 1;
(output_height, output_width)
}
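    /// Naive direct convolution on host copies of the tensor data.
    ///
    /// Complexity is O(batch * out_channels * in_channels * out_h * out_w * kH * kW);
    /// this is a reference implementation, not an optimized (e.g. im2col/GEMM) path.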
fn conv2d_forward(&self, input: &Tensor) -> Result<Tensor> {
let input_shape = input.shape();
if input_shape.len() != 4 {
return Err(NnlError::tensor(
"Expected 4D input [batch, channels, height, width]",
));
}
let batch_size = input_shape[0];
let in_channels = input_shape[1];
let input_height = input_shape[2];
let input_width = input_shape[3];
if in_channels != self.in_channels {
return Err(NnlError::shape_mismatch(
&[self.in_channels],
&[in_channels],
));
}
let (output_height, output_width) = self.calculate_output_size(input_height, input_width);
let output_shape = [batch_size, self.out_channels, output_height, output_width];
let input_data = input.to_vec()?;
let weights_data = self.weights.to_vec()?;
let output_size = output_shape.iter().product::<usize>();
let mut output_data = vec![0.0; output_size];
let (kernel_h, kernel_w) = self.kernel_size;
let (stride_h, stride_w) = self.stride;
let (pad_h, pad_w) = self.padding;
let (dilation_h, dilation_w) = self.dilation;
for batch in 0..batch_size {
for out_ch in 0..self.out_channels {
for out_h in 0..output_height {
for out_w in 0..output_width {
let mut sum = 0.0;
for in_ch in 0..in_channels {
for kh in 0..kernel_h {
for kw in 0..kernel_w {
let in_h =
(out_h * stride_h + kh * dilation_h) as i32 - pad_h as i32;
let in_w =
(out_w * stride_w + kw * dilation_w) as i32 - pad_w as i32;
if in_h >= 0
&& in_w >= 0
&& (in_h as usize) < input_height
&& (in_w as usize) < input_width
{
let input_idx =
batch * in_channels * input_height * input_width
+ in_ch * input_height * input_width
+ (in_h as usize) * input_width
+ (in_w as usize);
let weight_idx = out_ch * in_channels * kernel_h * kernel_w
+ in_ch * kernel_h * kernel_w
+ kh * kernel_w
+ kw;
if input_idx < input_data.len()
&& weight_idx < weights_data.len()
{
sum += input_data[input_idx] * weights_data[weight_idx];
}
}
}
}
}
let output_idx = batch * self.out_channels * output_height * output_width
+ out_ch * output_height * output_width
+ out_h * output_width
+ out_w;
if output_idx < output_data.len() {
output_data[output_idx] = sum;
}
}
}
}
}
let mut output =
Tensor::from_slice_on_device(&output_data, &output_shape, input.device().clone())?;
if let Some(ref bias) = self.bias {
let bias_data = bias.to_vec()?;
let mut output_data = output.to_vec()?;
for batch in 0..batch_size {
for out_ch in 0..self.out_channels {
let bias_val = bias_data[out_ch];
for h in 0..output_height {
for w in 0..output_width {
let idx = batch * (self.out_channels * output_height * output_width)
+ out_ch * (output_height * output_width)
+ h * output_width
+ w;
if idx < output_data.len() {
output_data[idx] += bias_val;
}
}
}
}
}
output =
Tensor::from_slice_on_device(&output_data, &output_shape, input.device().clone())?;
}
Ok(output)
}
}
impl Layer for Conv2DLayer {
fn forward(&mut self, input: &Tensor, training: TrainingMode) -> Result<Tensor> {
self.training = matches!(training, TrainingMode::Training);
if self.training {
self.cached_input = Some(input.clone_data()?);
}
let conv_output = self.conv2d_forward(input)?;
if self.training {
self.cached_pre_activation = Some(conv_output.clone_data()?);
}
conv_output.activation(self.activation)
}
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor> {
        // NOTE: `cached_pre_activation` is stored during forward, but the
        // activation derivative is not applied here; `grad_output` is treated
        // as the gradient w.r.t. the convolution's pre-activation output
        // (exact only for a linear/identity activation).
let input = self
.cached_input
.as_ref()
.ok_or_else(|| NnlError::training("No cached input for backward pass"))?;
let input_shape = input.shape();
let grad_output_shape = grad_output.shape();
let input_data = input.to_vec()?;
let grad_output_data = grad_output.to_vec()?;
let weights_data = self.weights.to_vec()?;
let batch_size = input_shape[0];
let in_channels = input_shape[1];
let input_height = input_shape[2];
let input_width = input_shape[3];
let output_height = grad_output_shape[2];
let output_width = grad_output_shape[3];
        let mut grad_input_data = vec![0.0; input_data.len()];
        let mut grad_weights_data = vec![0.0; weights_data.len()];
let (kernel_h, kernel_w) = self.kernel_size;
let (stride_h, stride_w) = self.stride;
let (pad_h, pad_w) = self.padding;
let (dilation_h, dilation_w) = self.dilation;
for batch in 0..batch_size {
for in_ch in 0..in_channels {
for in_h in 0..input_height {
for in_w in 0..input_width {
let mut grad_sum = 0.0;
for out_ch in 0..self.out_channels {
for kh in 0..kernel_h {
for kw in 0..kernel_w {
let out_h = (in_h + pad_h) as i32 - (kh * dilation_h) as i32;
let out_w = (in_w + pad_w) as i32 - (kw * dilation_w) as i32;
if out_h >= 0
&& out_w >= 0
&& out_h % stride_h as i32 == 0
&& out_w % stride_w as i32 == 0
{
let out_h = (out_h / stride_h as i32) as usize;
let out_w = (out_w / stride_w as i32) as usize;
if out_h < output_height && out_w < output_width {
let grad_out_idx = batch
* self.out_channels
* output_height
* output_width
+ out_ch * output_height * output_width
+ out_h * output_width
+ out_w;
let weight_idx =
out_ch * in_channels * kernel_h * kernel_w
+ in_ch * kernel_h * kernel_w
+ kh * kernel_w
+ kw;
if grad_out_idx < grad_output_data.len()
&& weight_idx < weights_data.len()
{
grad_sum += grad_output_data[grad_out_idx]
* weights_data[weight_idx];
}
}
}
}
}
}
let grad_in_idx = batch * in_channels * input_height * input_width
+ in_ch * input_height * input_width
+ in_h * input_width
+ in_w;
if grad_in_idx < grad_input_data.len() {
grad_input_data[grad_in_idx] = grad_sum;
}
}
}
}
}
        // Accumulate weight gradients:
        // dL/dW[o][i][kh][kw] = sum over batch and output positions of
        // grad_output * the input value that weight was multiplied with.
        // (Previously this buffer was never written, so `weight_grad` stayed zero.)
        for out_ch in 0..self.out_channels {
            for in_ch in 0..in_channels {
                for kh in 0..kernel_h {
                    for kw in 0..kernel_w {
                        let mut grad_sum = 0.0;
                        for batch in 0..batch_size {
                            for out_h in 0..output_height {
                                for out_w in 0..output_width {
                                    let in_h = (out_h * stride_h + kh * dilation_h) as i32
                                        - pad_h as i32;
                                    let in_w = (out_w * stride_w + kw * dilation_w) as i32
                                        - pad_w as i32;
                                    if in_h >= 0
                                        && in_w >= 0
                                        && (in_h as usize) < input_height
                                        && (in_w as usize) < input_width
                                    {
                                        let input_idx = batch
                                            * in_channels
                                            * input_height
                                            * input_width
                                            + in_ch * input_height * input_width
                                            + (in_h as usize) * input_width
                                            + (in_w as usize);
                                        let grad_out_idx = batch
                                            * self.out_channels
                                            * output_height
                                            * output_width
                                            + out_ch * output_height * output_width
                                            + out_h * output_width
                                            + out_w;
                                        if input_idx < input_data.len()
                                            && grad_out_idx < grad_output_data.len()
                                        {
                                            grad_sum += grad_output_data[grad_out_idx]
                                                * input_data[input_idx];
                                        }
                                    }
                                }
                            }
                        }
                        let weight_idx = out_ch * in_channels * kernel_h * kernel_w
                            + in_ch * kernel_h * kernel_w
                            + kh * kernel_w
                            + kw;
                        grad_weights_data[weight_idx] = grad_sum;
                    }
                }
            }
        }
        self.weight_grad = Tensor::from_slice_on_device(
            &grad_weights_data,
            self.weights.shape(),
            input.device().clone(),
        )?;
if self.bias.is_some() {
let mut grad_bias_data = vec![0.0; self.out_channels];
for out_ch in 0..self.out_channels {
let mut bias_grad = 0.0;
for batch in 0..batch_size {
for out_h in 0..output_height {
for out_w in 0..output_width {
let grad_out_idx =
batch * self.out_channels * output_height * output_width
+ out_ch * output_height * output_width
+ out_h * output_width
+ out_w;
if grad_out_idx < grad_output_data.len() {
bias_grad += grad_output_data[grad_out_idx];
}
}
}
}
grad_bias_data[out_ch] = bias_grad;
}
self.bias_grad = Some(Tensor::from_slice_on_device(
&grad_bias_data,
&[self.out_channels],
input.device().clone(),
)?);
}
let grad_input =
Tensor::from_slice_on_device(&grad_input_data, input_shape, input.device().clone())?;
Ok(grad_input)
}
fn parameters(&self) -> Vec<&Tensor> {
let mut params = vec![&self.weights];
if let Some(ref bias) = self.bias {
params.push(bias);
}
params
}
fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
let mut params = vec![&mut self.weights];
if let Some(ref mut bias) = self.bias {
params.push(bias);
}
params
}
fn gradients(&self) -> Vec<&Tensor> {
let mut grads = vec![&self.weight_grad];
if let Some(ref bias_grad) = self.bias_grad {
grads.push(bias_grad);
}
grads
}
fn gradients_mut(&mut self) -> Vec<&mut Tensor> {
let mut grads = vec![&mut self.weight_grad];
if let Some(ref mut bias_grad) = self.bias_grad {
grads.push(bias_grad);
}
grads
}
fn zero_grad(&mut self) {
if let Err(e) = self.weight_grad.fill(0.0) {
eprintln!("Warning: Failed to zero weight gradients: {}", e);
}
if let Some(ref mut bias_grad) = self.bias_grad {
if let Err(e) = bias_grad.fill(0.0) {
eprintln!("Warning: Failed to zero bias gradients: {}", e);
}
}
}
fn name(&self) -> &str {
"Conv2D"
}
fn output_shape(&self, input_shape: &[usize]) -> Result<Vec<usize>> {
if input_shape.len() != 4 {
return Err(NnlError::tensor(
"Expected 4D input [batch, channels, height, width]",
));
}
let batch_size = input_shape[0];
let in_channels = input_shape[1];
let input_height = input_shape[2];
let input_width = input_shape[3];
if in_channels != self.in_channels {
return Err(NnlError::shape_mismatch(
&[self.in_channels],
&[in_channels],
));
}
let (output_height, output_width) = self.calculate_output_size(input_height, input_width);
Ok(vec![
batch_size,
self.out_channels,
output_height,
output_width,
])
}
fn set_training(&mut self, training: bool) {
self.training = training;
}
fn training(&self) -> bool {
self.training
}
fn to_device(&mut self, device: Device) -> Result<()> {
self.weights = self.weights.to_device(device.clone())?;
self.weight_grad = self.weight_grad.to_device(device.clone())?;
if let Some(ref bias) = self.bias {
self.bias = Some(bias.to_device(device.clone())?);
}
if let Some(ref bias_grad) = self.bias_grad {
self.bias_grad = Some(bias_grad.to_device(device)?);
}
Ok(())
}
    fn clone_layer(&self) -> Result<Box<dyn Layer>> {
        // Rebuild on the same device as the existing weights so the clone's
        // gradient buffers do not land on a freshly auto-selected device.
        let mut cloned = Conv2DLayer::new_on_device(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.activation,
            self.use_bias,
            WeightInit::Zeros,
            self.weights.device().clone(),
        )?;
cloned.weights = self.weights.clone_data()?;
if let Some(ref bias) = self.bias {
cloned.bias = Some(bias.clone_data()?);
}
cloned.training = self.training;
Ok(Box::new(cloned))
}
}
impl fmt::Display for Conv2DLayer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Conv2D({} → {}, kernel={}×{}, stride={}×{}, padding={}×{}, {})",
self.in_channels,
self.out_channels,
self.kernel_size.0,
self.kernel_size.1,
self.stride.0,
self.stride.1,
self.padding.0,
self.padding.1,
self.activation
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::activations::Activation;
use crate::layers::WeightInit;
#[test]
fn test_conv2d_creation() {
let layer = Conv2DLayer::new(
3,
64,
(3, 3),
(1, 1),
(1, 1),
(1, 1),
Activation::ReLU,
true,
WeightInit::Xavier,
);
assert!(layer.is_ok());
let layer = layer.unwrap();
assert_eq!(layer.in_channels(), 3);
assert_eq!(layer.out_channels(), 64);
assert_eq!(layer.kernel_size(), (3, 3));
assert_eq!(layer.stride(), (1, 1));
assert_eq!(layer.padding(), (1, 1));
}
#[test]
fn test_conv2d_output_shape() {
let layer = Conv2DLayer::new(
3,
64,
(3, 3),
(1, 1),
(1, 1),
(1, 1),
Activation::ReLU,
true,
WeightInit::Xavier,
)
.unwrap();
        let input_shape = vec![1, 3, 32, 32];
        let output_shape = layer.output_shape(&input_shape).unwrap();
        assert_eq!(output_shape, vec![1, 64, 32, 32]);
    }
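    // Sanity-check the output-size formula with stride 2 and no padding:
    // (32 + 0 - 1 * (3 - 1) - 1) / 2 + 1 = 15.
    #[test]
    fn test_conv2d_output_shape_strided() {
        let layer = Conv2DLayer::new(
            3,
            8,
            (3, 3),
            (2, 2),
            (0, 0),
            (1, 1),
            Activation::ReLU,
            true,
            WeightInit::Xavier,
        )
        .unwrap();
        let output_shape = layer.output_shape(&[1, 3, 32, 32]).unwrap();
        assert_eq!(output_shape, vec![1, 8, 15, 15]);
    }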
#[test]
fn test_conv2d_forward() {
let mut layer = Conv2DLayer::new(
3,
64,
(3, 3),
(1, 1),
(1, 1),
(1, 1),
Activation::ReLU,
true,
WeightInit::Xavier,
)
.unwrap();
let input = Tensor::randn(&[1, 3, 32, 32]).unwrap();
let output = layer.forward(&input, TrainingMode::Inference);
assert!(output.is_ok());
let output = output.unwrap();
assert_eq!(output.shape(), &[1, 64, 32, 32]);
}
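    // Numeric sanity check for the naive convolution loop: with all-ones input
    // and all-ones 3x3 weights (no bias, no padding), every output element is
    // 9.0, and ReLU leaves positive values unchanged. Assumes `Tensor::fill`
    // works on the tensors returned by `parameters_mut`, as it does for the
    // gradient buffers in `zero_grad`.
    #[test]
    fn test_conv2d_forward_known_values() {
        let mut layer = Conv2DLayer::new(
            1,
            1,
            (3, 3),
            (1, 1),
            (0, 0),
            (1, 1),
            Activation::ReLU,
            false,
            WeightInit::Zeros,
        )
        .unwrap();
        layer.parameters_mut()[0].fill(1.0).unwrap();
        let mut input = Tensor::randn(&[1, 1, 4, 4]).unwrap();
        input.fill(1.0).unwrap();
        let output = layer.forward(&input, TrainingMode::Inference).unwrap();
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
        for value in output.to_vec().unwrap() {
            assert!((value - 9.0).abs() < 1e-6);
        }
    }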
#[test]
fn test_conv2d_parameters() {
let layer = Conv2DLayer::new(
3,
64,
(3, 3),
(1, 1),
(1, 1),
(1, 1),
Activation::ReLU,
true,
WeightInit::Xavier,
)
.unwrap();
let params = layer.parameters();
assert_eq!(params.len(), 2);
assert_eq!(params[0].shape(), &[64, 3, 3, 3]);
assert_eq!(params[1].shape(), &[64]);
assert_eq!(layer.num_parameters(), 64 * 3 * 3 * 3 + 64);
}
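    // Exercise the backward pass: run a training-mode forward to populate the
    // cached input, then feed an all-ones upstream gradient. With grad_output
    // all ones, the bias gradient for each channel is batch * out_h * out_w =
    // 1 * 32 * 32 = 1024. (The activation derivative is not applied in
    // `backward`; see the note there.)
    #[test]
    fn test_conv2d_backward_shapes_and_bias_grad() {
        let mut layer = Conv2DLayer::new(
            3,
            4,
            (3, 3),
            (1, 1),
            (1, 1),
            (1, 1),
            Activation::ReLU,
            true,
            WeightInit::Xavier,
        )
        .unwrap();
        let input = Tensor::randn(&[1, 3, 32, 32]).unwrap();
        let output = layer.forward(&input, TrainingMode::Training).unwrap();
        let mut grad_output = Tensor::randn(output.shape()).unwrap();
        grad_output.fill(1.0).unwrap();
        let grad_input = layer.backward(&grad_output).unwrap();
        assert_eq!(grad_input.shape(), input.shape());
        let grads = layer.gradients();
        assert_eq!(grads[0].shape(), &[4, 3, 3, 3]);
        for value in grads[1].to_vec().unwrap() {
            assert!((value - 1024.0).abs() < 1e-3);
        }
    }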
#[test]
fn test_conv2d_without_bias() {
let layer = Conv2DLayer::new(
3,
64,
(3, 3),
(1, 1),
(1, 1),
(1, 1),
Activation::ReLU,
false,
WeightInit::Xavier,
)
.unwrap();
let params = layer.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(layer.num_parameters(), 64 * 3 * 3 * 3);
}
#[test]
fn test_conv2d_invalid_parameters() {
assert!(
Conv2DLayer::new(
0,
64,
(3, 3),
(1, 1),
(1, 1),
(1, 1),
Activation::ReLU,
true,
WeightInit::Xavier
)
.is_err()
);
assert!(
Conv2DLayer::new(
3,
64,
(0, 3),
(1, 1),
(1, 1),
(1, 1),
Activation::ReLU,
true,
WeightInit::Xavier
)
.is_err()
);
assert!(
Conv2DLayer::new(
3,
64,
(3, 3),
(0, 1),
(1, 1),
(1, 1),
Activation::ReLU,
true,
WeightInit::Xavier
)
.is_err()
);
}
}