use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;
/// A fully connected (affine) layer holding a weight matrix and an optional bias vector.
///
/// Construct it with [`Linear::new`] and apply it with [`Linear::forward`].
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass(get_all, set_all))]
pub struct Linear {
/// Weight tensor; `Linear::new` creates it with shape `[out_features, in_features]`.
pub weight: Tensor,
/// Optional bias tensor; when created by `Linear::new` it has shape `[out_features]`.
pub bias: Option<Tensor>,
}
impl Linear {
    /// Creates a fully connected layer mapping `in_features` inputs to `out_features` outputs.
    ///
    /// The weight has shape `[out_features, in_features]`. When `bias` is `true`, a bias of
    /// shape `[out_features]` is also created. Both tensors are sampled with
    /// `Tensor::uniform` over the range `-k..k`, where `k = sqrt(1 / in_features)`, and
    /// then cast to `dtype`.
    ///
    /// NOTE(review): when `in_features == 0` the bound `k` evaluates to `f32::INFINITY`;
    /// confirm how `Tensor::uniform` behaves with an infinite range.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Tensor::uniform`.
    pub fn new(
        in_features: u64,
        out_features: u64,
        bias: bool,
        dtype: DType,
    ) -> Result<Linear, ZyxError> {
        // Half-width of the sampling interval, shared by weight and bias.
        let bound = (1.0 / (in_features as f32)).sqrt();

        // Sample the weight before the bias. If `Tensor::uniform` draws from a shared
        // generator, this order determines which values end up in each tensor.
        let weight = Tensor::uniform([out_features, in_features], -bound..bound)?.cast(dtype);
        let bias = bias
            .then(|| Tensor::uniform([out_features], -bound..bound))
            .transpose()?
            .map(|sampled| sampled.cast(dtype));

        Ok(Self { weight, bias })
    }

    /// Applies the layer: computes `x.dot(weight.t())` and adds the bias when one is present.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Tensor::dot` (presumably raised when the shape of `x`
    /// is incompatible with `weight.t()` — confirm against zyx).
    pub fn forward(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let projected = x.into().dot(self.weight.t())?;
        Ok(match &self.bias {
            Some(bias) => projected + bias,
            None => projected,
        })
    }
}
/// Smoke test: builds a 4 -> 16 layer with bias, prints its parameters, and runs a
/// forward pass on an `[8, 4]` random input followed by ReLU.
#[test]
fn linear() -> Result<(), ZyxError> {
    let layer = Linear::new(4, 16, true, DType::F32)?;
    // `new` was called with `bias = true`, so the bias must be present.
    let bias = layer.bias.as_ref().unwrap();
    println!("{}\n{}", layer.weight, bias);

    let input = Tensor::randn([8, 4], DType::F32)?;
    let y = layer.forward(input)?.relu();
    println!("{y}");
    Ok(())
}