use crate::init::{self, Init, DEFAULT_KAIMING_NORMAL};
use rai_core::{AsDevice, Shape, Tensor, Type};
use rai_derive::Module;
use std::fmt::Debug;
/// Hyper-parameters for a 1-D convolution, forwarded to `Tensor::conv1d`.
///
/// Defaults: `padding = 0`, `stride = 1`, `dilation = 1`, `groups = 1`.
/// Derives `Copy`/`PartialEq`/`Eq` like the transposed-convolution configs in this module,
/// so configs can be passed around by value and compared.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Conv1dConfig {
    /// Padding passed to `Tensor::conv1d`.
    pub padding: usize,
    /// Stride passed to `Tensor::conv1d`.
    pub stride: usize,
    /// Dilation passed to `Tensor::conv1d`.
    pub dilation: usize,
    /// Number of groups; the weight's input-channel dimension is `in_channels / groups`.
    pub groups: usize,
}

impl Default for Conv1dConfig {
    /// No padding, unit stride and dilation, a single group.
    fn default() -> Self {
        Self {
            padding: 0,
            stride: 1,
            dilation: 1,
            groups: 1,
        }
    }
}
/// A 1-D convolution layer holding a weight, an optional bias and its hyper-parameters.
#[derive(Clone, Debug, Module)]
#[module(crate = rai_core)]
pub struct Conv1d {
// Kernel; `Conv1d::new` creates it with shape `[out_channels, in_channels / groups, kernel_size]`.
weight: Tensor,
// Optional bias; `Conv1d::new` creates it with shape `[out_channels]`.
bias: Option<Tensor>,
// Hyper-parameters used by `fwd`; `#[param(skip)]` presumably keeps this out of the
// module's parameter set (NOTE(review): confirm against rai_derive).
#[param(skip)]
config: Conv1dConfig,
}
/// Anything that can be turned into a [`Conv1dConfig`]; accepted by `Conv1d::new`.
pub trait IntoConv1dConfig: Debug {
    /// Consumes `self` and yields the equivalent [`Conv1dConfig`].
    fn into_conv1d_config(self) -> Conv1dConfig;
}

/// Identity conversion: a config already is a config.
impl IntoConv1dConfig for Conv1dConfig {
    fn into_conv1d_config(self) -> Self {
        self
    }
}
impl Conv1d {
    /// Creates a 1-D convolution layer.
    ///
    /// The weight has shape `[out_channels, in_channels / groups, kernel_size]`. When `has_bias`
    /// is true, the bias has shape `[out_channels]`. Both are filled with `Tensor::rand`.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        config: impl IntoConv1dConfig,
        has_bias: bool,
        dtype: impl Type,
        device: impl AsDevice,
    ) -> Self {
        let device = device.device();
        let config = config.into_conv1d_config();
        let weight = Tensor::rand(
            [out_channels, in_channels / config.groups, kernel_size],
            dtype,
            device,
        );
        let bias = if has_bias {
            Some(Tensor::rand([out_channels], dtype, device))
        } else {
            None
        };
        Self {
            weight,
            bias,
            config,
        }
    }

    /// Applies the 1-D convolution to `x`, then adds the bias if one is present.
    ///
    /// The bias is reshaped to `[1, out_channels, 1]` so that it broadcasts over the batch and
    /// length axes of the convolution output. That output is assumed to be
    /// `[batch, channels, length]`, consistent with the 3-D weight.
    pub fn fwd(&self, x: &Tensor) -> Tensor {
        let x = x.conv1d(
            &self.weight,
            self.config.padding,
            self.config.stride,
            self.config.dilation,
            self.config.groups,
        );
        match &self.bias {
            Some(bias) => {
                // Must be 3-D. A 4-D `[1, C, 1, 1]` bias (the conv2d layout) would, under
                // right-aligned broadcasting, line its C axis up with the output's batch axis
                // and add a spurious leading dimension.
                let bias = bias.reshape([1, bias.shape_at(0), 1]);
                x + bias
            }
            None => x,
        }
    }
}
/// Hyper-parameters for a 2-D convolution, forwarded to `Tensor::conv2d`.
///
/// Defaults: `padding = [0, 0]`, `stride = [1, 1]`, `dilation = [1, 1]`, `groups = 1`.
/// Derives `Copy`/`PartialEq`/`Eq` like the transposed-convolution configs in this module,
/// so configs can be passed around by value and compared.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Conv2dConfig {
    /// Per-axis padding passed to `Tensor::conv2d`.
    pub padding: [usize; 2],
    /// Per-axis stride passed to `Tensor::conv2d`.
    pub stride: [usize; 2],
    /// Per-axis dilation passed to `Tensor::conv2d`.
    pub dilation: [usize; 2],
    /// Number of groups; the weight's input-channel dimension is `in_channels / groups`.
    pub groups: usize,
}

impl Default for Conv2dConfig {
    /// No padding, unit stride and dilation on both axes, a single group.
    fn default() -> Self {
        Self {
            padding: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
            groups: 1,
        }
    }
}
/// A 2-D convolution layer holding a weight, an optional bias and its hyper-parameters.
#[derive(Clone, Debug, Module)]
#[module(crate = rai_core)]
pub struct Conv2d {
// Kernel; `Conv2d::new_with_init` creates it with shape
// `[out_channels, in_channels / groups, kernel_size, kernel_size]`.
weight: Tensor,
// Optional bias; created with shape `[out_channels]` when a bias initializer is given.
bias: Option<Tensor>,
// Hyper-parameters used by `fwd`; `#[param(skip)]` presumably keeps this out of the
// module's parameter set (NOTE(review): confirm against rai_derive).
#[param(skip)]
config: Conv2dConfig,
}
/// Anything that can be turned into a [`Conv2dConfig`]; accepted by `Conv2d::new`.
pub trait IntoConv2dConfig: Debug {
    /// Consumes `self` and yields the equivalent [`Conv2dConfig`].
    fn into_conv2d_config(self) -> Conv2dConfig;
}

/// Identity conversion: a config already is a config.
impl IntoConv2dConfig for Conv2dConfig {
    fn into_conv2d_config(self) -> Self {
        self
    }
}
impl Conv2d {
    /// Creates a 2-D convolution layer with a square `kernel_size x kernel_size` kernel.
    ///
    /// The weight uses [`DEFAULT_KAIMING_NORMAL`]. When `has_bias` is true, the bias is drawn
    /// from `init::Uniform::new(-limit, limit)` with `limit = 1 / sqrt(in_channels)`.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        config: impl IntoConv2dConfig,
        has_bias: bool,
        dtype: impl Type,
        device: impl AsDevice,
    ) -> Self {
        let limit = 1. / (in_channels as f64).sqrt();
        let bias_init = has_bias.then(|| init::Uniform::new(-limit, limit));
        Self::new_with_init(
            in_channels,
            out_channels,
            kernel_size,
            config,
            dtype,
            device,
            DEFAULT_KAIMING_NORMAL,
            bias_init,
        )
    }

    /// Creates a 2-D convolution layer with caller-supplied initializers.
    ///
    /// `weight_init` produces a tensor of shape
    /// `[out_channels, in_channels / groups, kernel_size, kernel_size]`. If `bias_init` is
    /// `Some`, it produces a bias of shape `[out_channels]`; otherwise the layer has no bias.
    #[allow(clippy::too_many_arguments)] // mirrors `new` plus the two initializers
    pub fn new_with_init(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        config: impl IntoConv2dConfig,
        dtype: impl Type,
        device: impl AsDevice,
        weight_init: impl Init,
        bias_init: Option<impl Init>,
    ) -> Self {
        let device = device.device();
        let config = config.into_conv2d_config();
        let in_per_group = in_channels / config.groups;
        let weight_shape = [out_channels, in_per_group, kernel_size, kernel_size];
        let weight = weight_init.new_tensor(weight_shape, dtype, device);
        let bias = bias_init.map(|b_init| b_init.new_tensor([out_channels], dtype, device));
        Self {
            weight,
            bias,
            config,
        }
    }

    /// Applies the 2-D convolution to `x`, then adds the bias (reshaped to
    /// `[1, out_channels, 1, 1]`) if one is present.
    pub fn fwd(&self, x: &Tensor) -> Tensor {
        let cfg = &self.config;
        let out = x.conv2d(
            &self.weight,
            cfg.padding,
            cfg.stride,
            cfg.dilation,
            cfg.groups,
        );
        if let Some(b) = &self.bias {
            out + b.reshape([1, b.shape_at(0), 1, 1])
        } else {
            out
        }
    }
}
/// Hyper-parameters for a 1-D transposed convolution, forwarded to `Tensor::conv_transpose1d`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ConvTranspose1dConfig {
    /// Padding passed to `Tensor::conv_transpose1d`.
    pub padding: usize,
    /// Output padding passed to `Tensor::conv_transpose1d`.
    pub output_padding: usize,
    /// Stride passed to `Tensor::conv_transpose1d`.
    pub stride: usize,
    /// Dilation passed to `Tensor::conv_transpose1d`.
    pub dilation: usize,
    /// Number of groups; the weight's output-channel dimension is `out_channels / groups`.
    pub groups: usize,
}

impl Default for ConvTranspose1dConfig {
    /// Unit stride and dilation, a single group, and no padding of either kind.
    fn default() -> Self {
        Self {
            stride: 1,
            dilation: 1,
            groups: 1,
            padding: 0,
            output_padding: 0,
        }
    }
}
/// A 1-D transposed-convolution layer holding a weight, an optional bias and its hyper-parameters.
#[derive(Clone, Debug, Module)]
#[module(crate = rai_core)]
pub struct ConvTranspose1d {
// Kernel; `ConvTranspose1d::new` creates it with shape
// `[in_channels, out_channels / groups, kernel_size]`.
weight: Tensor,
// Optional bias; `ConvTranspose1d::new` creates it with shape `[out_channels]`.
bias: Option<Tensor>,
// Hyper-parameters used by `fwd`; `#[param(skip)]` presumably keeps this out of the
// module's parameter set (NOTE(review): confirm against rai_derive).
#[param(skip)]
config: ConvTranspose1dConfig,
}
/// Anything that can be turned into a [`ConvTranspose1dConfig`]; accepted by
/// `ConvTranspose1d::new`.
pub trait IntoConvTranspose1dConfig: Debug {
    /// Consumes `self` and yields the equivalent [`ConvTranspose1dConfig`].
    fn into_conv_transpose1d_config(self) -> ConvTranspose1dConfig;
}

/// Identity conversion: a config already is a config.
impl IntoConvTranspose1dConfig for ConvTranspose1dConfig {
    fn into_conv_transpose1d_config(self) -> Self {
        self
    }
}
impl ConvTranspose1d {
    /// Creates a 1-D transposed-convolution layer.
    ///
    /// The weight has shape `[in_channels, out_channels / groups, kernel_size]`. When
    /// `has_bias` is true, the bias has shape `[out_channels]`. Both are filled with
    /// `Tensor::rand`.
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        config: impl IntoConvTranspose1dConfig,
        has_bias: bool,
        dtype: impl Type,
        device: impl AsDevice,
    ) -> Self {
        let device = device.device();
        let config = config.into_conv_transpose1d_config();
        let weight = Tensor::rand(
            [in_channels, out_channels / config.groups, kernel_size],
            dtype,
            device,
        );
        let bias = if has_bias {
            Some(Tensor::rand([out_channels], dtype, device))
        } else {
            None
        };
        Self {
            weight,
            bias,
            config,
        }
    }

    /// Applies the 1-D transposed convolution to `x`, then adds the bias if one is present.
    ///
    /// The bias is reshaped to `[1, out_channels, 1]` so that it broadcasts over the batch and
    /// length axes of the output. That output is assumed to be `[batch, channels, length]`,
    /// consistent with the 3-D weight.
    pub fn fwd(&self, x: &Tensor) -> Tensor {
        let x = x.conv_transpose1d(
            &self.weight,
            self.config.padding,
            self.config.output_padding,
            self.config.stride,
            self.config.dilation,
            self.config.groups,
        );
        match &self.bias {
            Some(bias) => {
                // Must be 3-D. A 4-D `[1, C, 1, 1]` bias (the 2-D layout) would, under
                // right-aligned broadcasting, line its C axis up with the output's batch axis
                // and add a spurious leading dimension.
                let bias = bias.reshape([1, bias.shape_at(0), 1]);
                x + bias
            }
            None => x,
        }
    }
}
/// Hyper-parameters for a 2-D transposed convolution, forwarded to `Tensor::conv_transpose2d`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ConvTranspose2dConfig {
    /// Per-axis padding passed to `Tensor::conv_transpose2d`.
    pub padding: [usize; 2],
    /// Per-axis output padding passed to `Tensor::conv_transpose2d`.
    pub output_padding: [usize; 2],
    /// Per-axis stride passed to `Tensor::conv_transpose2d`.
    pub stride: [usize; 2],
    /// Per-axis dilation passed to `Tensor::conv_transpose2d`.
    pub dilation: [usize; 2],
    /// Number of groups; the weight's output-channel dimension is `out_channels / groups`.
    pub groups: usize,
}

impl Default for ConvTranspose2dConfig {
    /// Unit stride and dilation on both axes, a single group, and no padding of either kind.
    fn default() -> Self {
        Self {
            stride: [1, 1],
            dilation: [1, 1],
            groups: 1,
            padding: [0, 0],
            output_padding: [0, 0],
        }
    }
}
/// A 2-D transposed-convolution layer holding a weight, an optional bias and its hyper-parameters.
#[derive(Clone, Debug, Module)]
#[module(crate = rai_core)]
pub struct ConvTranspose2d {
// Kernel; `ConvTranspose2d::new` creates it with shape
// `[in_channels, out_channels / groups, kernel_size, kernel_size]`.
weight: Tensor,
// Optional bias; `ConvTranspose2d::new` creates it with shape `[out_channels]`.
bias: Option<Tensor>,
// Hyper-parameters used by `fwd`; `#[param(skip)]` presumably keeps this out of the
// module's parameter set (NOTE(review): confirm against rai_derive).
#[param(skip)]
config: ConvTranspose2dConfig,
}
/// Anything that can be turned into a [`ConvTranspose2dConfig`]; accepted by
/// `ConvTranspose2d::new`.
pub trait IntoConvTranspose2dConfig: Debug {
    /// Consumes `self` and yields the equivalent [`ConvTranspose2dConfig`].
    fn into_conv_transpose2d_config(self) -> ConvTranspose2dConfig;
}

/// Identity conversion: a config already is a config.
impl IntoConvTranspose2dConfig for ConvTranspose2dConfig {
    fn into_conv_transpose2d_config(self) -> Self {
        self
    }
}
impl ConvTranspose2d {
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
config: impl IntoConvTranspose2dConfig,
has_bias: bool,
dtype: impl Type,
device: impl AsDevice,
) -> Self {
let device = device.device();
let config = config.into_conv_transpose2d_config();
let weight = Tensor::rand(
[
in_channels,
out_channels / config.groups,
kernel_size,
kernel_size,
],
dtype,
device,
);
let bias = if has_bias {
Some(Tensor::rand([out_channels], dtype, device))
} else {
None
};
Self {
weight,
bias,
config,
}
}
pub fn fwd(&self, x: &Tensor) -> Tensor {
let x = x.conv_transpose2d(
&self.weight,
self.config.padding,
self.config.output_padding,
self.config.stride,
self.config.dilation,
self.config.groups,
);
match &self.bias {
Some(bias) => {
let bias = bias.reshape([1, bias.shape_at(0), 1, 1]);
x + bias
}
None => x,
}
}
}