use crate::nn::init::Initializer;
use crate::nn::module::Module;
use crate::tensor::{GraphContext, Tensor};
use std::cell::RefCell;
use std::rc::Rc;
/// Hyper-parameters describing a 2-D convolution layer.
///
/// Usually built with `Conv2dConfig::new` and refined through the `with_*`
/// builder methods. Two configurations compare equal when every field matches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Conv2dConfig {
    /// Number of input channels; divided by `groups` to size the weight.
    pub in_channels: usize,
    /// Number of output channels; first dimension of the weight and the
    /// length of the bias.
    pub out_channels: usize,
    /// Kernel extent along the two spatial axes; `.0` and `.1` become the
    /// third and fourth weight dimensions.
    pub kernel_size: (usize, usize),
    /// Stride passed to the convolution op, one value per spatial axis.
    pub stride: (usize, usize),
    /// Padding passed to the convolution op, one value per spatial axis.
    pub padding: (usize, usize),
    /// Dilation passed to the convolution op, one value per spatial axis.
    pub dilation: (usize, usize),
    /// Number of channel groups passed to the convolution op.
    pub groups: usize,
    /// Whether a bias parameter is created alongside the weight.
    pub bias: bool,
}
impl Default for Conv2dConfig {
fn default() -> Self {
Self {
in_channels: 1,
out_channels: 1,
kernel_size: (3, 3),
stride: (1, 1),
padding: (0, 0),
dilation: (1, 1),
groups: 1,
bias: true,
}
}
}
impl Conv2dConfig {
    /// Creates a configuration with the given channel counts and kernel size.
    ///
    /// Every other field takes the value provided by `Default`.
    pub fn new(in_channels: usize, out_channels: usize, kernel_size: (usize, usize)) -> Self {
        Self {
            in_channels,
            out_channels,
            kernel_size,
            ..Self::default()
        }
    }

    /// Returns the configuration with `stride` replaced.
    pub fn with_stride(self, stride: (usize, usize)) -> Self {
        Self { stride, ..self }
    }

    /// Returns the configuration with `padding` replaced.
    pub fn with_padding(self, padding: (usize, usize)) -> Self {
        Self { padding, ..self }
    }

    /// Returns the configuration with `dilation` replaced.
    pub fn with_dilation(self, dilation: (usize, usize)) -> Self {
        Self { dilation, ..self }
    }

    /// Returns the configuration with `groups` replaced.
    pub fn with_groups(self, groups: usize) -> Self {
        Self { groups, ..self }
    }

    /// Returns the configuration with the `bias` flag replaced.
    pub fn with_bias(self, bias: bool) -> Self {
        Self { bias, ..self }
    }
}
/// 2-D convolution layer: a weight parameter, an optional bias parameter and
/// the configuration they were created from.
pub struct Conv2d {
    /// Kernel parameter, created with shape
    /// `[out_channels, in_channels / groups, kernel_size.0, kernel_size.1]`.
    pub weight: Tensor,
    /// Bias parameter of shape `[out_channels]`; `None` when the
    /// configuration's `bias` flag was false at construction.
    pub bias: Option<Tensor>,
    /// Hyper-parameters; stride, padding, dilation and groups are read from
    /// here on every forward pass.
    pub config: Conv2dConfig,
}
impl Conv2d {
    /// Creates a convolution layer with the default stride, padding,
    /// dilation, group count and a bias term.
    ///
    /// Parameters are registered under `"{name}.weight"` and `"{name}.bias"`.
    pub fn new(
        context: &Rc<RefCell<GraphContext>>,
        name: &str,
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
    ) -> Self {
        let config = Conv2dConfig::new(in_channels, out_channels, kernel_size);
        Self::from_config(context, name, config)
    }

    /// Creates a convolution layer from a full configuration.
    ///
    /// The weight is created with shape
    /// `[out_channels, in_channels / groups, kernel_size.0, kernel_size.1]`
    /// using Kaiming-uniform initialisation. When `config.bias` is true, a
    /// zero-initialised bias of shape `[out_channels]` is created as well.
    ///
    /// # Panics
    ///
    /// Panics if `config.groups` is zero, or if `config.in_channels` or
    /// `config.out_channels` is not a multiple of `config.groups`. Without
    /// these checks a zero group count would surface as a bare
    /// divide-by-zero, and a non-divisible channel count would silently
    /// truncate the weight shape.
    pub fn from_config(
        context: &Rc<RefCell<GraphContext>>,
        name: &str,
        config: Conv2dConfig,
    ) -> Self {
        assert!(
            config.groups > 0,
            "Conv2d `{}`: groups must be at least 1",
            name
        );
        assert!(
            config.in_channels % config.groups == 0,
            "Conv2d `{}`: in_channels ({}) must be divisible by groups ({})",
            name,
            config.in_channels,
            config.groups
        );
        assert!(
            config.out_channels % config.groups == 0,
            "Conv2d `{}`: out_channels ({}) must be divisible by groups ({})",
            name,
            config.out_channels,
            config.groups
        );

        // Each group only sees its own slice of the input channels.
        let c_in_per_group = config.in_channels / config.groups;
        let weight_shape = vec![
            config.out_channels,
            c_in_per_group,
            config.kernel_size.0,
            config.kernel_size.1,
        ];
        let weight = Tensor::new_parameter_with_shape(
            context,
            &format!("{}.weight", name),
            weight_shape,
            Initializer::KaimingUniform,
        );
        let bias = if config.bias {
            Some(Tensor::new_parameter_with_shape(
                context,
                &format!("{}.bias", name),
                vec![config.out_channels],
                Initializer::Zeros,
            ))
        } else {
            None
        };
        Self {
            weight,
            bias,
            config,
        }
    }

    /// Sets the stride used by the forward pass.
    pub fn with_stride(mut self, stride: (usize, usize)) -> Self {
        self.config.stride = stride;
        self
    }

    /// Sets the padding used by the forward pass.
    pub fn with_padding(mut self, padding: (usize, usize)) -> Self {
        self.config.padding = padding;
        self
    }

    /// Sets the dilation used by the forward pass.
    pub fn with_dilation(mut self, dilation: (usize, usize)) -> Self {
        self.config.dilation = dilation;
        self
    }

    /// Sets the group count used by the forward pass.
    ///
    /// This only updates `config.groups`. The weight tensor was already
    /// created by the constructor from the group count in effect at that
    /// time and is NOT rebuilt here, so changing the group count after
    /// construction leaves the weight's second dimension sized for the old
    /// value. To build a grouped convolution, set the group count on a
    /// `Conv2dConfig` and pass it to `from_config` instead.
    pub fn with_groups(mut self, groups: usize) -> Self {
        self.config.groups = groups;
        self
    }
}
impl Module for Conv2d {
    /// Applies the convolution to `inputs` using this layer's weight,
    /// optional bias and the stride, padding, dilation and group count
    /// stored in its configuration.
    fn forward(&self, inputs: &Tensor) -> Tensor {
        let cfg = &self.config;
        let bias = self.bias.as_ref();
        inputs.conv2d(
            &self.weight,
            bias,
            cfg.stride,
            cfg.padding,
            cfg.dilation,
            cfg.groups,
        )
    }

    /// Returns the layer's parameters: the weight first, followed by the
    /// bias when the layer has one.
    fn parameters(&self) -> Vec<Tensor> {
        match &self.bias {
            Some(bias) => vec![self.weight.clone(), bias.clone()],
            None => vec![self.weight.clone()],
        }
    }
}
/// 2-D transposed-convolution layer: a weight parameter, an optional bias
/// parameter and the settings handed to the op on every forward pass.
pub struct ConvTranspose2d {
    /// Kernel parameter, created with shape
    /// `[in_channels, out_channels, kernel_size.0, kernel_size.1]`.
    pub weight: Tensor,
    /// Bias parameter of shape `[out_channels]`, or `None` if it was removed.
    pub bias: Option<Tensor>,
    /// Stride passed to the op; the constructor sets `(1, 1)`.
    pub stride: (usize, usize),
    /// Padding passed to the op; the constructor sets `(0, 0)`.
    pub padding: (usize, usize),
    /// Output padding passed to the op; the constructor sets `(0, 0)`.
    pub output_padding: (usize, usize),
    /// Dilation passed to the op; the constructor sets `(1, 1)`.
    pub dilation: (usize, usize),
    /// Group count passed to the op; the constructor sets `1`.
    pub groups: usize,
}
impl ConvTranspose2d {
    /// Creates a transposed-convolution layer with unit stride and dilation,
    /// no padding or output padding, a single group and a bias term.
    ///
    /// The weight is created under `"{name}.weight"` with shape
    /// `[in_channels, out_channels, kernel_size.0, kernel_size.1]` and
    /// Kaiming-uniform initialisation. The bias is created under
    /// `"{name}.bias"` with shape `[out_channels]` and zero initialisation.
    pub fn new(
        context: &Rc<RefCell<GraphContext>>,
        name: &str,
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
    ) -> Self {
        let weight = Tensor::new_parameter_with_shape(
            context,
            &format!("{}.weight", name),
            vec![in_channels, out_channels, kernel_size.0, kernel_size.1],
            Initializer::KaimingUniform,
        );
        let bias = Some(Tensor::new_parameter_with_shape(
            context,
            &format!("{}.bias", name),
            vec![out_channels],
            Initializer::Zeros,
        ));
        Self {
            weight,
            bias,
            stride: (1, 1),
            padding: (0, 0),
            output_padding: (0, 0),
            dilation: (1, 1),
            groups: 1,
        }
    }

    /// Sets the stride used by the forward pass.
    pub fn with_stride(mut self, stride: (usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    /// Sets the padding used by the forward pass.
    pub fn with_padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }

    /// Sets the output padding used by the forward pass.
    pub fn with_output_padding(mut self, output_padding: (usize, usize)) -> Self {
        self.output_padding = output_padding;
        self
    }

    /// Sets the dilation used by the forward pass.
    ///
    /// The weight shape does not depend on dilation, so this is safe to call
    /// after construction.
    pub fn with_dilation(mut self, dilation: (usize, usize)) -> Self {
        self.dilation = dilation;
        self
    }

    /// Drops the layer's bias so the forward pass runs without one.
    ///
    /// The bias tensor has already been created by `new`; this only discards
    /// the layer's handle to it.
    /// NOTE(review): whether the parameter stays registered in the graph
    /// context depends on `Tensor::new_parameter_with_shape` — confirm.
    pub fn without_bias(mut self) -> Self {
        self.bias = None;
        self
    }
}
impl Module for ConvTranspose2d {
    /// Applies the transposed convolution to `inputs` using this layer's
    /// weight, optional bias and stored stride, padding, output padding,
    /// dilation and group count.
    fn forward(&self, inputs: &Tensor) -> Tensor {
        let bias = self.bias.as_ref();
        inputs.conv_transpose2d(
            &self.weight,
            bias,
            self.stride,
            self.padding,
            self.output_padding,
            self.dilation,
            self.groups,
        )
    }

    /// Returns the layer's parameters: the weight first, followed by the
    /// bias when the layer has one.
    fn parameters(&self) -> Vec<Tensor> {
        match &self.bias {
            Some(bias) => vec![self.weight.clone(), bias.clone()],
            None => vec![self.weight.clone()],
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh graph context for a single test.
    fn fresh_context() -> Rc<RefCell<GraphContext>> {
        Rc::new(RefCell::new(GraphContext::new()))
    }

    /// The builder chain keeps the constructor arguments and records padding.
    #[test]
    fn test_conv2d_creation() {
        let context = fresh_context();
        let layer = Conv2d::new(&context, "conv1", 3, 64, (3, 3))
            .with_padding((1, 1))
            .with_stride((1, 1));

        let cfg = &layer.config;
        assert_eq!(cfg.in_channels, 3);
        assert_eq!(cfg.out_channels, 64);
        assert_eq!(cfg.kernel_size, (3, 3));
        assert_eq!(cfg.padding, (1, 1));
        assert!(layer.bias.is_some());
    }

    /// Running the forward pass leaves more than two nodes in the main graph.
    #[test]
    fn test_conv2d_forward() {
        let context = fresh_context();
        let input = Tensor::new_input(&context, "input");
        let layer = Conv2d::new(&context, "conv1", 3, 64, (3, 3));

        let _ = layer.forward(&input);

        let graph = context.borrow().main_graph().clone();
        assert!(graph.nodes.len() > 2);
    }

    /// Disabling the bias in the configuration yields a weight-only layer.
    #[test]
    fn test_conv2d_no_bias() {
        let context = fresh_context();
        let config = Conv2dConfig::new(3, 64, (3, 3)).with_bias(false);
        let layer = Conv2d::from_config(&context, "conv1", config);

        assert!(layer.bias.is_none());
        assert_eq!(layer.parameters().len(), 1);
    }

    /// A transposed convolution exposes both its weight and its bias.
    #[test]
    fn test_conv_transpose2d() {
        let context = fresh_context();
        let input = Tensor::new_input(&context, "input");
        let layer = ConvTranspose2d::new(&context, "deconv1", 64, 3, (4, 4))
            .with_stride((2, 2))
            .with_padding((1, 1));

        let _ = layer.forward(&input);

        assert_eq!(layer.parameters().len(), 2);
    }
}