// ferrum_interfaces/kernel_ops.rs
//! Kernel backend abstraction layer for LLM-specific fused operations.
//!
//! This module defines a mid-level abstraction between raw `KernelExecutor`
//! (too low-level: grid/block sizes) and `TensorOps` (too high-level: no
//! LLM-specific fused ops). It enables pluggable CUDA/Metal/CPU backends
//! through six focused sub-traits composed into one umbrella `KernelOps`.

use crate::TensorRef;
use ferrum_types::Result;

// ---------------------------------------------------------------------------
// Configuration structs
// ---------------------------------------------------------------------------

/// Rotary position embedding configuration.
///
/// All fields are plain scalars, so the type is `Copy` and cheaply passed
/// by value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RoPEConfig {
    /// Dimensionality of each attention head.
    pub head_dim: usize,
    /// Maximum sequence length the cache covers.
    pub max_seq_len: usize,
    /// Base frequency (default 10000.0 for standard RoPE).
    pub theta: f32,
}

/// Parameters describing a single attention call.
///
/// All fields are plain scalars, so the type is `Copy` and cheaply passed
/// by value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AttentionParams {
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads (equals `num_heads` for MHA, smaller for GQA).
    pub num_kv_heads: usize,
    /// Dimensionality of each attention head.
    pub head_dim: usize,
    /// Softmax scale (typically `1 / sqrt(head_dim)`).
    pub softmax_scale: f32,
    /// Whether to apply a causal mask.
    pub causal: bool,
}

/// Quantization scheme descriptor for quantized linear ops.
///
/// Variant names follow the GGML naming convention (`Q4_0`, `Q8_0`), hence
/// the underscores.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantScheme {
    /// 4-bit quantization with group size (e.g. Q4_0 uses group_size=32).
    Q4_0 { group_size: usize },
    /// 8-bit quantization.
    Q8_0 { group_size: usize },
}

/// Sampling parameters for GPU-side token sampling.
#[derive(Debug, Clone, PartialEq)]
pub struct SamplingParams {
    /// Softmax temperature; lower is greedier.
    pub temperature: f32,
    /// Keep only the `top_k` most likely tokens (0 conventionally disables).
    pub top_k: usize,
    /// Nucleus sampling probability mass cutoff.
    pub top_p: f32,
    /// Penalty applied to tokens already seen in context.
    pub repetition_penalty: f32,
    /// Token IDs that have appeared in context (for repetition penalty).
    pub repetition_token_ids: Vec<u32>,
    /// Frequency count for each token in `repetition_token_ids`.
    /// Invariant: same length as `repetition_token_ids` (parallel arrays).
    pub repetition_token_freqs: Vec<u32>,
    /// Per-step RNG seed.
    pub rng_seed: u32,
}

// ---------------------------------------------------------------------------
// Sub-traits
// ---------------------------------------------------------------------------

66/// Normalization operations.
67pub trait NormOps: Send + Sync {
68    /// RMS normalization: `x / rms(x) * weight`.
69    fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef>;
70
71    /// Fused RMS normalization with residual add:
72    /// `output = rms_norm(input + residual, weight, eps)`.
73    /// Returns `(normed_output, updated_residual)`.
74    fn rms_norm_residual(
75        &self,
76        input: &TensorRef,
77        residual: &TensorRef,
78        weight: &TensorRef,
79        eps: f32,
80    ) -> Result<(TensorRef, TensorRef)> {
81        // Default: add then norm separately.
82        let _ = (input, residual, weight, eps);
83        Err(ferrum_types::FerrumError::unsupported(
84            "rms_norm_residual not implemented",
85        ))
86    }
87}
88
89/// Positional encoding operations.
90pub trait PositionOps: Send + Sync {
91    /// Apply rotary position embedding to a Q or K tensor.
92    ///
93    /// `x` shape: `[batch, seq_len, num_heads, head_dim]`
94    /// `position_ids`: position indices for each token in the sequence.
95    /// `cos_cache` / `sin_cache`: precomputed `[max_seq_len, head_dim/2]`.
96    fn rotary_embedding(
97        &self,
98        x: &TensorRef,
99        cos_cache: &TensorRef,
100        sin_cache: &TensorRef,
101        position_ids: &[usize],
102    ) -> Result<TensorRef>;
103}
104
105/// Attention operations.
106pub trait AttentionOps: Send + Sync {
107    /// Standard multi-head / grouped-query attention.
108    ///
109    /// * `q` — `[batch, seq_q, num_heads, head_dim]`
110    /// * `k` — `[batch, seq_kv, num_kv_heads, head_dim]`
111    /// * `v` — `[batch, seq_kv, num_kv_heads, head_dim]`
112    ///
113    /// Returns attention output `[batch, seq_q, num_heads, head_dim]`.
114    fn attention(
115        &self,
116        q: &TensorRef,
117        k: &TensorRef,
118        v: &TensorRef,
119        params: &AttentionParams,
120    ) -> Result<TensorRef>;
121
122    /// Paged attention for KV-cache-based decode.
123    ///
124    /// Default returns unsupported — backends opt in.
125    fn paged_attention(
126        &self,
127        _q: &TensorRef,
128        _k_cache: &TensorRef,
129        _v_cache: &TensorRef,
130        _block_table: &[u32],
131        _params: &AttentionParams,
132    ) -> Result<TensorRef> {
133        Err(ferrum_types::FerrumError::unsupported(
134            "paged_attention not implemented",
135        ))
136    }
137}
138
139/// Activation function operations (including fused variants).
140pub trait ActivationOps: Send + Sync {
141    /// Fused SiLU-multiply: `silu(gate) * up`.
142    ///
143    /// This is the SwiGLU building block used in LLaMA/Qwen MLPs.
144    fn silu_mul(&self, gate: &TensorRef, up: &TensorRef) -> Result<TensorRef>;
145
146    /// GELU activation.
147    fn gelu(&self, input: &TensorRef) -> Result<TensorRef>;
148}
149
150/// Linear / matrix-multiply operations.
151pub trait LinearOps: Send + Sync {
152    /// Dense linear projection (no bias): `input @ weight^T`.
153    ///
154    /// * `input`  — `[..., in_features]`
155    /// * `weight` — `[out_features, in_features]`
156    fn linear(&self, input: &TensorRef, weight: &TensorRef) -> Result<TensorRef>;
157
158    /// Quantized linear projection.
159    ///
160    /// `packed_weight` is backend-specific packed data (e.g. Q4_0 blocks).
161    fn quantized_linear(
162        &self,
163        _input: &TensorRef,
164        _packed_weight: &TensorRef,
165        _scheme: &QuantScheme,
166    ) -> Result<TensorRef> {
167        Err(ferrum_types::FerrumError::unsupported(
168            "quantized_linear not implemented",
169        ))
170    }
171}
172
173/// Token sampling operations (GPU-side when possible).
174pub trait SamplingOps: Send + Sync {
175    /// Sample a single token from logits using the full sampling pipeline.
176    ///
177    /// `logits` shape: `[vocab_size]` or `[1, vocab_size]` (last-token logits).
178    fn sample_token(&self, logits: &TensorRef, params: &SamplingParams) -> Result<u32>;
179
180    /// Greedy argmax over the last dimension.
181    fn argmax(&self, logits: &TensorRef) -> Result<u32>;
182}
183
// ---------------------------------------------------------------------------
// Umbrella trait
// ---------------------------------------------------------------------------

188/// Unified kernel operations interface.
189///
190/// Backends implement whichever sub-traits they support and return `None` for
191/// the rest. Callers use `KernelOpsDispatch` (below) to get automatic fallback
192/// to `TensorOps` when a sub-trait is unavailable.
193pub trait KernelOps: Send + Sync {
194    fn norm_ops(&self) -> Option<&dyn NormOps> {
195        None
196    }
197    fn position_ops(&self) -> Option<&dyn PositionOps> {
198        None
199    }
200    fn attention_ops(&self) -> Option<&dyn AttentionOps> {
201        None
202    }
203    fn activation_ops(&self) -> Option<&dyn ActivationOps> {
204        None
205    }
206    fn linear_ops(&self) -> Option<&dyn LinearOps> {
207        None
208    }
209    fn sampling_ops(&self) -> Option<&dyn SamplingOps> {
210        None
211    }
212
213    /// Human-readable backend identifier (e.g. `"candle"`, `"metal"`, `"cuda"`).
214    fn backend_name(&self) -> &str;
215}
216
// ---------------------------------------------------------------------------
// Dispatch helper
// ---------------------------------------------------------------------------

221/// Dispatch wrapper that tries `KernelOps` first, then falls back to
222/// `TensorOps` for operations that have a `TensorOps` equivalent.
223///
224/// This enables gradual migration: callers use the dispatch without caring
225/// which path actually runs.
226pub struct KernelOpsDispatch<'a> {
227    kernel_ops: Option<&'a dyn KernelOps>,
228    tensor_ops: &'a dyn crate::TensorOps,
229}
230
231impl<'a> KernelOpsDispatch<'a> {
232    pub fn new(
233        kernel_ops: Option<&'a dyn KernelOps>,
234        tensor_ops: &'a dyn crate::TensorOps,
235    ) -> Self {
236        Self {
237            kernel_ops,
238            tensor_ops,
239        }
240    }
241
242    /// RMS norm: prefer `KernelOps::NormOps`, fall back to `TensorOps::rms_norm`.
243    pub fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef> {
244        if let Some(ko) = self.kernel_ops {
245            if let Some(norm) = ko.norm_ops() {
246                return norm.rms_norm(input, weight, eps);
247            }
248        }
249        self.tensor_ops.rms_norm(input, weight, eps)
250    }
251
252    /// GELU: prefer `KernelOps::ActivationOps`, fall back to `TensorOps::gelu`.
253    pub fn gelu(&self, input: &TensorRef) -> Result<TensorRef> {
254        if let Some(ko) = self.kernel_ops {
255            if let Some(act) = ko.activation_ops() {
256                return act.gelu(input);
257            }
258        }
259        self.tensor_ops.gelu(input)
260    }
261
262    /// SiLU: prefer `KernelOps::ActivationOps::silu_mul` is *fused* so there
263    /// is no direct `TensorOps` equivalent. This helper exposes the non-fused
264    /// `TensorOps::silu` for callers that only need plain SiLU.
265    pub fn silu(&self, input: &TensorRef) -> Result<TensorRef> {
266        self.tensor_ops.silu(input)
267    }
268
269    /// Fused SiLU-multiply (SwiGLU building block).
270    /// Falls back to `silu(gate) * up` via TensorOps when kernel is unavailable.
271    pub fn silu_mul(&self, gate: &TensorRef, up: &TensorRef) -> Result<TensorRef> {
272        if let Some(ko) = self.kernel_ops {
273            if let Some(act) = ko.activation_ops() {
274                return act.silu_mul(gate, up);
275            }
276        }
277        // Fallback: silu(gate) * up
278        let activated = self.tensor_ops.silu(gate)?;
279        self.tensor_ops.mul(&activated, up)
280    }
281
282    /// Dense linear (no bias).
283    /// Falls back to `TensorOps::matmul`.
284    pub fn linear(&self, input: &TensorRef, weight: &TensorRef) -> Result<TensorRef> {
285        if let Some(ko) = self.kernel_ops {
286            if let Some(lin) = ko.linear_ops() {
287                return lin.linear(input, weight);
288            }
289        }
290        self.tensor_ops.matmul(input, weight)
291    }
292
293    /// Softmax: always via `TensorOps` (no kernel sub-trait for plain softmax).
294    pub fn softmax(&self, input: &TensorRef, dim: i32) -> Result<TensorRef> {
295        self.tensor_ops.softmax(input, dim)
296    }
297
298    /// Access the underlying `KernelOps` (if any) for ops that have no
299    /// `TensorOps` fallback (e.g. rotary_embedding, attention, sampling).
300    pub fn kernel_ops(&self) -> Option<&'a dyn KernelOps> {
301        self.kernel_ops
302    }
303
304    /// Access the underlying `TensorOps`.
305    pub fn tensor_ops(&self) -> &'a dyn crate::TensorOps {
306        self.tensor_ops
307    }
308}
309
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `KernelOps` impl: every sub-trait accessor keeps its
    /// `None` default.
    struct EmptyKernelOps;

    impl KernelOps for EmptyKernelOps {
        fn backend_name(&self) -> &str {
            "empty"
        }
    }

    #[test]
    fn test_empty_kernel_ops_returns_none() {
        let ops = EmptyKernelOps;
        assert!(ops.norm_ops().is_none());
        assert!(ops.position_ops().is_none());
        assert!(ops.attention_ops().is_none());
        assert!(ops.activation_ops().is_none());
        assert!(ops.linear_ops().is_none());
        assert!(ops.sampling_ops().is_none());
        assert_eq!(ops.backend_name(), "empty");
    }

    #[test]
    fn test_rope_config_default() {
        let cfg = RoPEConfig {
            head_dim: 128,
            max_seq_len: 2048,
            theta: 10000.0,
        };
        assert_eq!(cfg.head_dim, 128);
    }

    #[test]
    fn test_attention_params() {
        let params = AttentionParams {
            num_heads: 32,
            num_kv_heads: 8,
            head_dim: 128,
            softmax_scale: 1.0 / (128.0_f32).sqrt(),
            causal: true,
        };
        assert!(params.causal);
        // 32 query heads over 8 KV heads => GQA group of 4.
        assert_eq!(params.num_heads / params.num_kv_heads, 4);
    }

    #[test]
    fn test_quant_scheme_variants() {
        let q4 = QuantScheme::Q4_0 { group_size: 32 };
        let q8 = QuantScheme::Q8_0 { group_size: 128 };
        assert!(matches!(q4, QuantScheme::Q4_0 { group_size: 32 }));
        assert!(matches!(q8, QuantScheme::Q8_0 { group_size: 128 }));
    }
}