ferrum_interfaces/kernel_ops.rs
1//! Kernel backend abstraction layer for LLM-specific fused operations.
2//!
3//! This module defines a mid-level abstraction between raw `KernelExecutor`
4//! (too low-level: grid/block sizes) and `TensorOps` (too high-level: no
5//! LLM-specific fused ops). It enables pluggable CUDA/Metal/CPU backends
6//! through six focused sub-traits composed into one umbrella `KernelOps`.
7
8use crate::TensorRef;
9use ferrum_types::Result;
10
11// ---------------------------------------------------------------------------
12// Configuration structs
13// ---------------------------------------------------------------------------
14
/// Rotary position embedding configuration.
#[derive(Debug, Clone)]
pub struct RoPEConfig {
    /// Dimensionality of each attention head.
    pub head_dim: usize,
    /// Maximum sequence length the cos/sin cache covers.
    pub max_seq_len: usize,
    /// Base frequency (`DEFAULT_THETA` = 10000.0 for standard RoPE).
    pub theta: f32,
}

impl RoPEConfig {
    /// Standard RoPE base frequency.
    pub const DEFAULT_THETA: f32 = 10000.0;

    /// Build a config with the conventional `theta = 10000.0`, so callers
    /// that don't use a scaled/extended RoPE variant can't mistype the base.
    pub fn new(head_dim: usize, max_seq_len: usize) -> Self {
        Self {
            head_dim,
            max_seq_len,
            theta: Self::DEFAULT_THETA,
        }
    }
}
25
/// Parameters describing a single attention call.
#[derive(Debug, Clone)]
pub struct AttentionParams {
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads (< `num_heads` for GQA/MQA).
    pub num_kv_heads: usize,
    /// Dimensionality of each attention head.
    pub head_dim: usize,
    /// Softmax scale (typically `1 / sqrt(head_dim)`).
    pub softmax_scale: f32,
    /// Whether to apply a causal mask.
    pub causal: bool,
}

impl AttentionParams {
    /// Convenience constructor that derives the conventional
    /// `softmax_scale = 1 / sqrt(head_dim)` so callers cannot get it wrong.
    ///
    /// `head_dim` should be non-zero; a zero value yields an infinite scale.
    /// Callers needing a non-standard scale can still build the struct
    /// directly — all fields remain public.
    pub fn new(num_heads: usize, num_kv_heads: usize, head_dim: usize, causal: bool) -> Self {
        Self {
            num_heads,
            num_kv_heads,
            head_dim,
            softmax_scale: 1.0 / (head_dim as f32).sqrt(),
            causal,
        }
    }
}
37
/// Quantization scheme descriptor for quantized linear ops.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantScheme {
    /// 4-bit quantization with group size (e.g. Q4_0 uses group_size=32).
    Q4_0 { group_size: usize },
    /// 8-bit quantization.
    Q8_0 { group_size: usize },
}

impl QuantScheme {
    /// Number of weights sharing one scale factor, regardless of bit width.
    /// Avoids forcing callers to match on every variant just to read the
    /// common field.
    pub fn group_size(&self) -> usize {
        match self {
            QuantScheme::Q4_0 { group_size } | QuantScheme::Q8_0 { group_size } => *group_size,
        }
    }

    /// Bits per quantized weight for this scheme.
    pub fn bits(&self) -> u32 {
        match self {
            QuantScheme::Q4_0 { .. } => 4,
            QuantScheme::Q8_0 { .. } => 8,
        }
    }
}
46
/// Sampling parameters for GPU-side token sampling.
#[derive(Debug, Clone)]
pub struct SamplingParams {
    /// Softmax temperature for logit scaling.
    // NOTE(review): whether 0.0 / 1.0 disables scaling is backend-defined —
    // confirm against `SamplingOps::sample_token` implementations.
    pub temperature: f32,
    /// Keep only the `top_k` most probable tokens.
    // presumably 0 disables the filter — TODO confirm with backends.
    pub top_k: usize,
    /// Nucleus (top-p) cumulative-probability threshold.
    pub top_p: f32,
    /// Penalty applied to tokens that already appeared in context.
    pub repetition_penalty: f32,
    /// Token IDs that have appeared in context (for repetition penalty).
    pub repetition_token_ids: Vec<u32>,
    /// Frequency count for each token in `repetition_token_ids`.
    // Invariant: parallel array — must be the same length as
    // `repetition_token_ids`; element `i` is the count for token `i`.
    pub repetition_token_freqs: Vec<u32>,
    /// Per-step RNG seed.
    pub rng_seed: u32,
}
61
62// ---------------------------------------------------------------------------
63// Sub-traits
64// ---------------------------------------------------------------------------
65
66/// Normalization operations.
67pub trait NormOps: Send + Sync {
68 /// RMS normalization: `x / rms(x) * weight`.
69 fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef>;
70
71 /// Fused RMS normalization with residual add:
72 /// `output = rms_norm(input + residual, weight, eps)`.
73 /// Returns `(normed_output, updated_residual)`.
74 fn rms_norm_residual(
75 &self,
76 input: &TensorRef,
77 residual: &TensorRef,
78 weight: &TensorRef,
79 eps: f32,
80 ) -> Result<(TensorRef, TensorRef)> {
81 // Default: add then norm separately.
82 let _ = (input, residual, weight, eps);
83 Err(ferrum_types::FerrumError::unsupported(
84 "rms_norm_residual not implemented",
85 ))
86 }
87}
88
/// Positional encoding operations.
pub trait PositionOps: Send + Sync {
    /// Apply rotary position embedding to a Q or K tensor.
    ///
    /// `x` shape: `[batch, seq_len, num_heads, head_dim]`
    /// `position_ids`: position indices for each token in the sequence
    /// (presumably one entry per token in `seq_len`, each indexing a row of
    /// the caches — TODO confirm against backend implementations).
    /// `cos_cache` / `sin_cache`: precomputed `[max_seq_len, head_dim/2]`.
    ///
    /// Returns the rotated tensor; whether `x` may also be modified in
    /// place is backend-defined (not specified by this trait).
    fn rotary_embedding(
        &self,
        x: &TensorRef,
        cos_cache: &TensorRef,
        sin_cache: &TensorRef,
        position_ids: &[usize],
    ) -> Result<TensorRef>;
}
104
/// Attention operations.
pub trait AttentionOps: Send + Sync {
    /// Standard multi-head / grouped-query attention.
    ///
    /// * `q` — `[batch, seq_q, num_heads, head_dim]`
    /// * `k` — `[batch, seq_kv, num_kv_heads, head_dim]`
    /// * `v` — `[batch, seq_kv, num_kv_heads, head_dim]`
    ///
    /// `num_kv_heads` may be smaller than `num_heads` (GQA/MQA); how KV
    /// heads are repeated across query heads is left to the backend.
    ///
    /// Returns attention output `[batch, seq_q, num_heads, head_dim]`.
    fn attention(
        &self,
        q: &TensorRef,
        k: &TensorRef,
        v: &TensorRef,
        params: &AttentionParams,
    ) -> Result<TensorRef>;

    /// Paged attention for KV-cache-based decode.
    ///
    /// `_block_table` maps logical sequence blocks to physical cache blocks.
    // NOTE(review): block size and the layout of `_k_cache`/`_v_cache` are
    // not specified by this trait — confirm the contract per backend.
    ///
    /// Default returns unsupported — backends opt in.
    fn paged_attention(
        &self,
        _q: &TensorRef,
        _k_cache: &TensorRef,
        _v_cache: &TensorRef,
        _block_table: &[u32],
        _params: &AttentionParams,
    ) -> Result<TensorRef> {
        Err(ferrum_types::FerrumError::unsupported(
            "paged_attention not implemented",
        ))
    }
}
138
/// Activation function operations (including fused variants).
pub trait ActivationOps: Send + Sync {
    /// Fused SiLU-multiply: `silu(gate) * up`.
    ///
    /// This is the SwiGLU building block used in LLaMA/Qwen MLPs.
    /// `gate` and `up` are combined element-wise, so they are expected to
    /// have matching shapes (whether backends broadcast is not specified
    /// here — TODO confirm).
    fn silu_mul(&self, gate: &TensorRef, up: &TensorRef) -> Result<TensorRef>;

    /// GELU activation, applied element-wise.
    fn gelu(&self, input: &TensorRef) -> Result<TensorRef>;
}
149
/// Linear / matrix-multiply operations.
pub trait LinearOps: Send + Sync {
    /// Dense linear projection (no bias): `input @ weight^T`.
    ///
    /// * `input` — `[..., in_features]`
    /// * `weight` — `[out_features, in_features]`
    ///
    /// Output is `[..., out_features]` (implied by the shapes above).
    fn linear(&self, input: &TensorRef, weight: &TensorRef) -> Result<TensorRef>;

    /// Quantized linear projection.
    ///
    /// `packed_weight` is backend-specific packed data (e.g. Q4_0 blocks);
    /// `scheme` tells the backend how to interpret the packing.
    ///
    /// Default returns unsupported — backends opt in.
    fn quantized_linear(
        &self,
        _input: &TensorRef,
        _packed_weight: &TensorRef,
        _scheme: &QuantScheme,
    ) -> Result<TensorRef> {
        Err(ferrum_types::FerrumError::unsupported(
            "quantized_linear not implemented",
        ))
    }
}
172
/// Token sampling operations (GPU-side when possible).
pub trait SamplingOps: Send + Sync {
    /// Sample a single token from logits using the full sampling pipeline
    /// (temperature, top-k, top-p, repetition penalty — see `SamplingParams`).
    ///
    /// `logits` shape: `[vocab_size]` or `[1, vocab_size]` (last-token logits).
    /// Returns the sampled token ID.
    fn sample_token(&self, logits: &TensorRef, params: &SamplingParams) -> Result<u32>;

    /// Greedy argmax over the last dimension; returns the winning token ID.
    fn argmax(&self, logits: &TensorRef) -> Result<u32>;
}
183
184// ---------------------------------------------------------------------------
185// Umbrella trait
186// ---------------------------------------------------------------------------
187
/// Unified kernel operations interface.
///
/// Backends implement whichever sub-traits they support and return `None` for
/// the rest. Callers use `KernelOpsDispatch` (below) to get automatic fallback
/// to `TensorOps` when a sub-trait is unavailable.
pub trait KernelOps: Send + Sync {
    /// Normalization kernels, or `None` if this backend provides none.
    fn norm_ops(&self) -> Option<&dyn NormOps> {
        None
    }
    /// Positional-encoding (RoPE) kernels, or `None`.
    fn position_ops(&self) -> Option<&dyn PositionOps> {
        None
    }
    /// Attention kernels, or `None`.
    fn attention_ops(&self) -> Option<&dyn AttentionOps> {
        None
    }
    /// Activation kernels, or `None`.
    fn activation_ops(&self) -> Option<&dyn ActivationOps> {
        None
    }
    /// Linear / matmul kernels, or `None`.
    fn linear_ops(&self) -> Option<&dyn LinearOps> {
        None
    }
    /// Sampling kernels, or `None`.
    fn sampling_ops(&self) -> Option<&dyn SamplingOps> {
        None
    }

    /// Human-readable backend identifier (e.g. `"candle"`, `"metal"`, `"cuda"`).
    fn backend_name(&self) -> &str;
}
216
217// ---------------------------------------------------------------------------
218// Dispatch helper (Step 3)
219// ---------------------------------------------------------------------------
220
/// Dispatch wrapper that tries `KernelOps` first, then falls back to
/// `TensorOps` for operations that have a `TensorOps` equivalent.
///
/// This enables gradual migration: callers use the dispatch without caring
/// which path actually runs.
pub struct KernelOpsDispatch<'a> {
    /// Optional fused-kernel backend; `None` means every op falls back.
    kernel_ops: Option<&'a dyn KernelOps>,
    /// Always-available baseline implementation.
    tensor_ops: &'a dyn crate::TensorOps,
}
230
231impl<'a> KernelOpsDispatch<'a> {
232 pub fn new(
233 kernel_ops: Option<&'a dyn KernelOps>,
234 tensor_ops: &'a dyn crate::TensorOps,
235 ) -> Self {
236 Self {
237 kernel_ops,
238 tensor_ops,
239 }
240 }
241
242 /// RMS norm: prefer `KernelOps::NormOps`, fall back to `TensorOps::rms_norm`.
243 pub fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef> {
244 if let Some(ko) = self.kernel_ops {
245 if let Some(norm) = ko.norm_ops() {
246 return norm.rms_norm(input, weight, eps);
247 }
248 }
249 self.tensor_ops.rms_norm(input, weight, eps)
250 }
251
252 /// GELU: prefer `KernelOps::ActivationOps`, fall back to `TensorOps::gelu`.
253 pub fn gelu(&self, input: &TensorRef) -> Result<TensorRef> {
254 if let Some(ko) = self.kernel_ops {
255 if let Some(act) = ko.activation_ops() {
256 return act.gelu(input);
257 }
258 }
259 self.tensor_ops.gelu(input)
260 }
261
262 /// SiLU: prefer `KernelOps::ActivationOps::silu_mul` is *fused* so there
263 /// is no direct `TensorOps` equivalent. This helper exposes the non-fused
264 /// `TensorOps::silu` for callers that only need plain SiLU.
265 pub fn silu(&self, input: &TensorRef) -> Result<TensorRef> {
266 self.tensor_ops.silu(input)
267 }
268
269 /// Fused SiLU-multiply (SwiGLU building block).
270 /// Falls back to `silu(gate) * up` via TensorOps when kernel is unavailable.
271 pub fn silu_mul(&self, gate: &TensorRef, up: &TensorRef) -> Result<TensorRef> {
272 if let Some(ko) = self.kernel_ops {
273 if let Some(act) = ko.activation_ops() {
274 return act.silu_mul(gate, up);
275 }
276 }
277 // Fallback: silu(gate) * up
278 let activated = self.tensor_ops.silu(gate)?;
279 self.tensor_ops.mul(&activated, up)
280 }
281
282 /// Dense linear (no bias).
283 /// Falls back to `TensorOps::matmul`.
284 pub fn linear(&self, input: &TensorRef, weight: &TensorRef) -> Result<TensorRef> {
285 if let Some(ko) = self.kernel_ops {
286 if let Some(lin) = ko.linear_ops() {
287 return lin.linear(input, weight);
288 }
289 }
290 self.tensor_ops.matmul(input, weight)
291 }
292
293 /// Softmax: always via `TensorOps` (no kernel sub-trait for plain softmax).
294 pub fn softmax(&self, input: &TensorRef, dim: i32) -> Result<TensorRef> {
295 self.tensor_ops.softmax(input, dim)
296 }
297
298 /// Access the underlying `KernelOps` (if any) for ops that have no
299 /// `TensorOps` fallback (e.g. rotary_embedding, attention, sampling).
300 pub fn kernel_ops(&self) -> Option<&'a dyn KernelOps> {
301 self.kernel_ops
302 }
303
304 /// Access the underlying `TensorOps`.
305 pub fn tensor_ops(&self) -> &'a dyn crate::TensorOps {
306 self.tensor_ops
307 }
308}
309
310// ---------------------------------------------------------------------------
311// Tests
312// ---------------------------------------------------------------------------
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Backend stub implementing only the mandatory `backend_name`,
    /// leaving every sub-trait accessor at its default `None`.
    struct EmptyKernelOps;

    impl KernelOps for EmptyKernelOps {
        fn backend_name(&self) -> &str {
            "empty"
        }
    }

    #[test]
    fn test_empty_kernel_ops_returns_none() {
        let ops = EmptyKernelOps;
        assert_eq!(ops.backend_name(), "empty");
        // All six accessors fall through to the default `None` bodies.
        assert!(
            ops.norm_ops().is_none()
                && ops.position_ops().is_none()
                && ops.attention_ops().is_none()
                && ops.activation_ops().is_none()
                && ops.linear_ops().is_none()
                && ops.sampling_ops().is_none()
        );
    }

    #[test]
    fn test_rope_config_default() {
        let cfg = RoPEConfig {
            head_dim: 128,
            max_seq_len: 2048,
            theta: 10000.0,
        };
        assert_eq!(cfg.head_dim, 128);
    }

    #[test]
    fn test_attention_params() {
        let head_dim = 128usize;
        let params = AttentionParams {
            num_heads: 32,
            num_kv_heads: 8,
            head_dim,
            softmax_scale: 1.0 / (head_dim as f32).sqrt(),
            causal: true,
        };
        assert!(params.causal);
        // 32 query heads shared over 8 KV heads => GQA group of 4.
        assert_eq!(params.num_heads / params.num_kv_heads, 4);
    }

    #[test]
    fn test_quant_scheme_variants() {
        let q4 = QuantScheme::Q4_0 { group_size: 32 };
        let q8 = QuantScheme::Q8_0 { group_size: 128 };
        assert!(matches!(q4, QuantScheme::Q4_0 { group_size: 32 }));
        assert!(matches!(q8, QuantScheme::Q8_0 { group_size: 128 }));
    }
}