Skip to main content

ferrum_kernels/
linear.rs

//! `Linear<B>` trait — weight-bearing projection abstraction.
//!
//! Lives in ferrum-kernels alongside `Backend` because:
//!   1. `Backend::layer_forward_fused` and other "standard transformer layer"
//!      helpers want to accept `&dyn Linear<Self>` as their projection
//!      parameter, so the trait must be visible here.
//!   2. `ferrum-quantization` already depends on ferrum-kernels, so defining
//!      the trait in quantization would make kernels depend on quantization
//!      in turn — a circular dependency. Model code in `ferrum-models`
//!      depends on both crates, so it reaches the trait here either way.
//!
//! Concrete implementations (DenseLinear, GptqLinear, AwqLinear, GgufLinear)
//! live in `ferrum-quantization`, which depends on `ferrum-kernels` for this
//! trait and for the `Backend` it parameterises over.
14
15use crate::backend::Backend;
16
17/// A weight-bearing linear projection.
18///
19/// `forward` computes `out[m, out_features] = input[m, in_features] @ W^T`.
20/// Implementations are responsible for calling the right backend kernel
21/// (`B::gemm` for dense, `B::gemm_quant` for quantized variants).
22pub trait Linear<B: Backend>: Send + Sync {
23    fn in_features(&self) -> usize;
24    fn out_features(&self) -> usize;
25
26    /// Append GEMM work onto `ctx`. Caller flushes the context when results
27    /// must be materialised.
28    fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize);
29}