use svod_dtype::DType;
use crate::Tensor;
use crate::nn::Layer;
// Module-local shorthand for the crate-wide `Result` type, so signatures below can
// write `Result<T>` without the `crate::` prefix.
type Result<T> = crate::Result<T>;
/// A linear layer holding a weight tensor and a bias tensor.
///
/// Both tensors are handed to the tensor `linear()` builder in the `Layer::forward`
/// implementation for this type.
pub struct Linear {
/// Weight tensor. `Linear::with_dims` builds it with shape `[out_features, in_features]`.
pub weight: Tensor,
/// Bias tensor. `Linear::with_dims` builds it with `out_features` elements, all `0.0`.
pub bias: Tensor,
}
impl Linear {
pub fn new(weight: Tensor, bias: Tensor) -> Self {
Self { weight, bias }
}
pub fn with_dims(in_features: usize, out_features: usize, dtype: DType) -> Self {
let weight_data: Vec<f32> = (0..in_features * out_features).map(|i| ((i as f32) * 0.1).sin() * 0.1).collect();
let weight = Tensor::from_slice(&weight_data)
.try_reshape([out_features as isize, in_features as isize])
.expect("linear weight reshape failed");
let bias = Tensor::full(&[out_features], 0.0, dtype).expect("linear bias creation failed");
Self { weight, bias }
}
}
impl Layer for Linear {
    /// Runs the layer on `x` by configuring the tensor's `linear()` builder with this
    /// layer's weight and bias and invoking it.
    ///
    /// Any error produced by the builder's `call()` is returned unchanged.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let Self { weight, bias } = self;
        let op = x.linear().weight(weight).bias(bias);
        op.call()
    }
}