ferrum_kernels/linear.rs
//! `Linear<B>` trait — weight-bearing projection abstraction.
//!
//! Lives in ferrum-kernels alongside `Backend` because:
//! 1. `Backend::layer_forward_fused` and other "standard transformer layer"
//!    helpers want to accept `&dyn Linear<Self>` as their projection
//!    parameter, so the trait must be visible here.
//! 2. Model code in `ferrum-models` depends on both ferrum-kernels and
//!    ferrum-quantization, so keeping the trait in kernels avoids any
//!    circular dependency between kernels and quantization.
//!
//! Concrete implementations (DenseLinear, GptqLinear, AwqLinear, GgufLinear)
//! live in `ferrum-quantization`, which depends on `ferrum-kernels` for this
//! trait and for the `Backend` it parameterises over.
14
15use crate::backend::Backend;
16
17/// A weight-bearing linear projection.
18///
19/// `forward` computes `out[m, out_features] = input[m, in_features] @ W^T`.
20/// Implementations are responsible for calling the right backend kernel
21/// (`B::gemm` for dense, `B::gemm_quant` for quantized variants).
22pub trait Linear<B: Backend>: Send + Sync {
23 fn in_features(&self) -> usize;
24 fn out_features(&self) -> usize;
25
26 /// Append GEMM work onto `ctx`. Caller flushes the context when results
27 /// must be materialised.
28 fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize);
29}