use zyx::{DType, IntoShape, Tensor, ZyxError};
use zyx_derive::Module;
/// A 2D convolution layer.
///
/// Stores the learnable `weight` (and optional `bias`) together with the
/// convolution hyperparameters that `forward` passes to `Tensor::conv`.
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct Conv2d {
/// Convolution stride, collected from the `stride` argument of `Conv2d::new`.
stride: Vec<u64>,
/// Kernel dilation, collected from the `dilation` argument of `Conv2d::new`.
dilation: Vec<u64>,
/// Number of channel groups; the weight's second dimension is `in_channels / groups`.
groups: u64,
/// Padding, collected from the `padding` argument of `Conv2d::new`.
padding: Vec<u64>,
/// Kernel weights. `Conv2d::new` builds them with leading dims
/// `[out_channels, in_channels / groups]` followed by the kernel dims.
pub weight: Tensor,
/// Optional bias with `out_channels` elements; `None` when created with `bias = false`.
pub bias: Option<Tensor>,
}
impl Conv2d {
    /// Creates a 2D convolution layer with uniformly initialized parameters.
    ///
    /// The weight has shape `[out_channels, in_channels / groups, kh, kw]`. The weight
    /// and the optional bias (`out_channels` elements) are both sampled from
    /// `U(-k, k)` with `k = 1 / sqrt(fan_in)` and
    /// `fan_in = (in_channels / groups) * kh * kw`. They are then cast to `dtype`.
    ///
    /// # Arguments
    /// * `in_channels` / `out_channels` - number of input / output channels.
    /// * `kernel_size` - a single value for a square kernel, or an explicit kernel shape.
    /// * `stride`, `padding`, `dilation` - collected as given and forwarded to `Tensor::conv`.
    /// * `groups` - number of channel groups; must be non-zero and divide `in_channels`.
    /// * `bias` - whether to create a bias tensor.
    /// * `dtype` - dtype the initialized parameters are cast to.
    ///
    /// # Errors
    /// Propagates any error returned by `Tensor::uniform`.
    ///
    /// # Panics
    /// Panics if `groups` is zero or `in_channels` is not divisible by `groups`.
    pub fn new(
        in_channels: u64,
        out_channels: u64,
        kernel_size: impl IntoShape,
        stride: impl IntoShape,
        padding: impl IntoShape,
        dilation: impl IntoShape,
        groups: u64,
        bias: bool,
        dtype: DType,
    ) -> Result<Self, ZyxError> {
        assert!(groups != 0, "Conv2d: groups must be non-zero");
        assert!(
            in_channels % groups == 0,
            "Conv2d: in_channels ({in_channels}) must be divisible by groups ({groups})"
        );

        let mut kernel_size: Vec<u64> = kernel_size.into_shape().collect();
        // A single value describes a square kernel: expand it to both spatial dims.
        if kernel_size.len() == 1 {
            let k = kernel_size[0];
            kernel_size.push(k);
        }

        let channels_per_group = in_channels / groups;
        // Fan-in of one output unit: every weight slice covers `in_channels / groups`
        // channels times the kernel area.
        let fan_in = channels_per_group * kernel_size.iter().product::<u64>();
        let scale = 1f32 / (fan_in as f32).sqrt();

        let mut weight_shape = vec![out_channels, channels_per_group];
        weight_shape.extend(kernel_size);

        Ok(Conv2d {
            stride: stride.into_shape().collect(),
            dilation: dilation.into_shape().collect(),
            groups,
            padding: padding.into_shape().collect(),
            weight: Tensor::uniform(weight_shape, -scale..scale)?.cast(dtype),
            bias: if bias {
                Some(Tensor::uniform(out_channels, -scale..scale)?.cast(dtype))
            } else {
                None
            },
        })
    }

    /// Applies the convolution to `x`.
    ///
    /// Delegates to `Tensor::conv` with this layer's weight, optional bias, groups,
    /// stride, dilation and padding.
    /// NOTE(review): the expected input layout (presumably `[batch, in_channels, h, w]`)
    /// is defined by `Tensor::conv`; confirm there.
    ///
    /// # Errors
    /// Propagates any error returned by `Tensor::conv`.
    pub fn forward(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        x.into().conv(
            &self.weight,
            self.bias.as_ref(),
            self.groups,
            &self.stride,
            &self.dilation,
            &self.padding,
        )
    }
}