use svod_dtype::DType;
use crate::Tensor;
use crate::nn::Layer;
type Result<T> = crate::Result<T>;
/// One-dimensional convolution layer: a kernel tensor, an optional bias, and
/// the stride/padding settings forwarded to the convolution op in `forward`.
pub struct Conv1d {
/// Convolution kernel. `with_dims` lays it out as
/// `[out_channels, in_channels, kernel]`; `new` stores whatever the caller passes.
pub weight: Tensor,
/// Optional bias tensor. `with_dims` creates one of shape `[out_channels]`
/// filled with `0.0`; `None` means no bias is handed to the convolution.
pub bias: Option<Tensor>,
/// Convolution stride. Both constructors default it to `1`.
pub stride: usize,
/// Padding pair, defaulted to `(0, 0)` by both constructors.
/// NOTE(review): presumably (before, after) padding on the single spatial
/// axis — confirm against the convolution op's contract.
pub padding: (isize, isize),
}
impl Conv1d {
    /// Wraps an existing kernel and optional bias in a layer with stride `1`
    /// and padding `(0, 0)`.
    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
        let stride = 1;
        let padding = (0, 0);
        Self { weight, bias, stride, padding }
    }

    /// Builds a layer with deterministic placeholder parameters.
    ///
    /// The kernel holds `in_channels * out_channels * kernel` values, where
    /// element `i` is `sin(i * 0.1) * 0.1`, reshaped to
    /// `[out_channels, in_channels, kernel]`. The bias has shape
    /// `[out_channels]` and is filled with `0.0`.
    ///
    /// NOTE(review): `dtype` is only applied to the bias; the kernel is built
    /// from `f32` data and never converted here — confirm whether a cast to
    /// `dtype` is intended.
    ///
    /// # Panics
    ///
    /// Panics if reshaping the kernel or creating the bias tensor fails.
    pub fn with_dims(in_channels: usize, out_channels: usize, kernel: usize, dtype: DType) -> Self {
        let element_count = in_channels * out_channels * kernel;

        // Fill the kernel with a small, reproducible sine pattern.
        let mut values: Vec<f32> = Vec::with_capacity(element_count);
        for idx in 0..element_count {
            let phase = (idx as f32) * 0.1;
            values.push(phase.sin() * 0.1);
        }

        let kernel_shape = [out_channels as isize, in_channels as isize, kernel as isize];
        let weight = Tensor::from_slice(&values)
            .try_reshape(kernel_shape)
            .expect("conv1d weight reshape failed");

        let bias_shape = [out_channels];
        let bias = Tensor::full(&bias_shape, 0.0, dtype).expect("conv1d bias creation failed");

        Self {
            weight,
            bias: Some(bias),
            stride: 1,
            padding: (0, 0),
        }
    }

    /// Returns the layer with its stride replaced by `stride`.
    pub fn with_stride(self, stride: usize) -> Self {
        Self { stride, ..self }
    }

    /// Returns the layer with its padding pair replaced by `padding`.
    pub fn with_padding(self, padding: (isize, isize)) -> Self {
        Self { padding, ..self }
    }
}
impl Layer for Conv1d {
    /// Applies the convolution to `x` using this layer's weight, optional
    /// bias, stride and padding, returning the result of the underlying op.
    ///
    /// NOTE(review): this goes through `Tensor::conv2d` with one-element
    /// stride/padding slices; presumably that op is dimension-generic and
    /// infers the spatial rank from these slices — confirm.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // One entry per spatial axis; a 1-D convolution has exactly one.
        let strides = [self.stride];
        let pads = [self.padding];
        let bias = self.bias.as_ref();

        x.conv2d()
            .weight(&self.weight)
            .maybe_bias(bias)
            .stride(&strides)
            .padding(&pads)
            .call()
    }
}