use ferrum_kernels::backend::Backend;
use crate::traits::Linear;
/// A fully connected layer whose weight and optional bias are stored as backend `B` buffers.
///
/// Values are built with the constructors in the inherent `impl` block. They are evaluated
/// through the `Linear` implementation, which issues `B::gemm` followed by `B::add_bias`
/// when a bias is present.
pub struct DenseLinear<B>
where
    B: Backend,
{
    /// Weight buffer. The `from_rows*` constructors expect `out_features * in_features` values.
    weight: B::Buffer,
    /// Optional bias applied after the matrix multiply.
    /// `from_rows_with_bias` expects `out_features` values.
    bias: Option<B::Buffer>,
    /// Number of input features, forwarded to `B::gemm`.
    in_features: usize,
    /// Number of output features, forwarded to `B::gemm` and `B::add_bias`.
    out_features: usize,
}
impl<B: Backend> DenseLinear<B> {
    /// Creates a bias-free layer by uploading `weight_row_major` with `B::from_slice`.
    ///
    /// `weight_row_major` must contain exactly `out_features * in_features` values.
    /// NOTE(review): the row orientation is presumably `out_features` rows of
    /// `in_features` values. It is ultimately defined by the backend's `gemm`, so confirm there.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `weight_row_major.len() != out_features * in_features`.
    pub fn from_rows(weight_row_major: &[f32], out_features: usize, in_features: usize) -> Self {
        debug_assert_eq!(
            weight_row_major.len(),
            out_features * in_features,
            "DenseLinear weight length mismatch"
        );
        Self::from_buffer(B::from_slice(weight_row_major), out_features, in_features)
    }

    /// Creates a layer with a bias, uploading both slices with `B::from_slice`.
    ///
    /// `weight_row_major` must contain `out_features * in_features` values.
    /// `bias` must contain `out_features` values.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `bias.len() != out_features` or if
    /// `weight_row_major.len() != out_features * in_features`.
    pub fn from_rows_with_bias(
        weight_row_major: &[f32],
        bias: &[f32],
        out_features: usize,
        in_features: usize,
    ) -> Self {
        debug_assert_eq!(bias.len(), out_features, "DenseLinear bias length mismatch");
        // Delegating to `from_rows` applies the same weight-length check as the
        // bias-free constructor, so a malformed weight slice is caught on this path too.
        Self::from_rows(weight_row_major, out_features, in_features)
            .with_bias(B::from_slice(bias))
    }

    /// Wraps an already-uploaded weight buffer, taking ownership of it. No bias is set.
    ///
    /// No length validation is performed here. The caller is responsible for `weight`
    /// matching `out_features` / `in_features` in the layout the backend's `gemm` expects.
    pub fn from_buffer(weight: B::Buffer, out_features: usize, in_features: usize) -> Self {
        Self {
            weight,
            bias: None,
            in_features,
            out_features,
        }
    }

    /// Sets the bias buffer, replacing any existing one, and returns the updated layer.
    ///
    /// The length of `bias` is not validated here.
    pub fn with_bias(mut self, bias: B::Buffer) -> Self {
        self.bias = Some(bias);
        self
    }

    /// Borrows the stored weight buffer.
    pub fn weight(&self) -> &B::Buffer {
        &self.weight
    }

    /// Borrows the bias buffer, or returns `None` if the layer has no bias.
    pub fn bias(&self) -> Option<&B::Buffer> {
        self.bias.as_ref()
    }
}
impl<B: Backend> Linear<B> for DenseLinear<B> {
    /// Evaluates the layer on `m` input rows.
    ///
    /// First calls `B::gemm` with `input`, the stored weight and `out`. The three size
    /// arguments are `m`, `out_features` and `in_features`, in that order. When a bias is
    /// stored, `B::add_bias` is then applied to `out` with sizes `m` and `out_features`.
    /// NOTE(review): the buffer layouts of `input` and `out` are defined by the backend.
    fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
        let Self {
            weight,
            bias,
            in_features,
            out_features,
        } = self;
        let (n, k) = (*out_features, *in_features);

        B::gemm(ctx, input, weight, out, m, n, k);

        if let Some(bias_buf) = bias {
            B::add_bias(ctx, out, bias_buf, m, n);
        }
    }

    /// Number of input features this layer was constructed with.
    fn in_features(&self) -> usize {
        self.in_features
    }

    /// Number of output features this layer was constructed with.
    fn out_features(&self) -> usize {
        self.out_features
    }
}