Skip to main content

ferrum_kernels/backend/
kv_layer.rs

1//! `KvLayer<B>` — per-K-dtype trait that picks the cache layout type and
2//! the K-specific paged write / read launchers (Dim 5 PR C trait-based
3//! dispatch).
4//!
5//! ## Why a trait, not an enum
6//!
7//! - `K` carries an associated `Layer` type — FP16 → `KvCache<B, KvFp16>`,
8//!   INT8 → `KvCacheQuant<B, KvInt8>`.
9//! - K-specific launchers (paged write + paged decode attention; contig
10//!   write + contig decode for FP16) are trait methods. The model bound
11//!   `where K: KvLayer<B>` lets `K::method(layer, ...)` dispatch directly
12//!   to the right backend launcher per (B, K) at monomorphization time —
13//!   no runtime tag, no enum match, no panicking accessors.
14//! - `LlamaFamilyModel<CpuBackend, KvInt8>` is a compile error because
15//!   `KvInt8: KvLayer<CpuBackend>` doesn't hold (CPU backend has no
16//!   `BackendInt8KvOps` impl).
17
18use ferrum_types::{FerrumError, Result};
19
20use crate::backend::{Backend, BackendInt8KvOps, KvCache, KvCacheQuant};
21use ferrum_interfaces::kv_dtype::{KvDtypeKind, KvFp16, KvInt8};
22
23/// Per-K-dtype dispatch trait.
24#[allow(clippy::too_many_arguments)]
25pub trait KvLayer<B: Backend>: KvDtypeKind {
26    /// Per-layer cache type (FP16 → `KvCache`, INT8 → `KvCacheQuant`).
27    type Layer: Send + Sync;
28
29    /// Allocate a paged cache layer for one sequence.
30    fn alloc_paged(
31        max_blocks_per_seq: usize,
32        block_size: usize,
33        num_kv_heads: usize,
34        head_dim: usize,
35    ) -> Self::Layer;
36
37    /// Allocate a contiguous cache layer (FP16 only; INT8 panics).
38    fn alloc_contig(capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self::Layer;
39
40    // Metadata accessors (variant-agnostic).
41    fn len(layer: &Self::Layer) -> usize;
42    fn set_len(layer: &mut Self::Layer, new_len: usize);
43    fn capacity(layer: &Self::Layer) -> usize;
44    fn block_size(layer: &Self::Layer) -> usize;
45    fn num_kv_heads(layer: &Self::Layer) -> usize;
46    fn head_dim(layer: &Self::Layer) -> usize;
47    fn block_table(layer: &Self::Layer) -> Option<&B::Buffer>;
48    fn block_table_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer>;
49    fn context_lens(layer: &Self::Layer) -> Option<&B::Buffer>;
50    fn context_lens_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer>;
51    fn paged_block_indices(layer: &Self::Layer) -> &[u32];
52    fn paged_block_indices_mut(layer: &mut Self::Layer) -> &mut Vec<u32>;
53
54    fn is_paged(layer: &Self::Layer) -> bool {
55        Self::block_size(layer) > 0
56    }
57
58    /// Paged write: split QKV → norm → RoPE → write K/V into the paged
59    /// pool. FP16 uses `B::split_qkv_norm_rope_into_paged_cache`. INT8
60    /// uses `B::split_qkv_norm_rope` + `B::int8_kv_append_paged`.
61    fn paged_write(
62        ctx: &mut B::Context,
63        layer: &mut Self::Layer,
64        qkv: &B::Buffer,
65        q_norm_w: &B::Buffer,
66        k_norm_w: &B::Buffer,
67        cos: &B::Buffer,
68        sin: &B::Buffer,
69        q_out: &mut B::Buffer,
70        k_scratch: &mut B::Buffer,
71        v_scratch: &mut B::Buffer,
72        pool_k: &mut B::Buffer,
73        pool_v: &mut B::Buffer,
74        tokens: usize,
75        num_q_heads: usize,
76        num_kv_heads: usize,
77        head_dim: usize,
78        pos_offset: usize,
79        eps: f32,
80        qk_mode: i32,
81    ) -> Result<()>;
82
83    /// Paged decode attention. Reads from the per-layer cache, writes the
84    /// attended output to `output`. FP16 reads from `pool_k`/`pool_v`;
85    /// INT8 reads from layer-internal INT8 buffers (pool args ignored).
86    fn paged_decode_attention(
87        ctx: &mut B::Context,
88        layer: &mut Self::Layer,
89        q: &B::Buffer,
90        pool_k: &B::Buffer,
91        pool_v: &B::Buffer,
92        output: &mut B::Buffer,
93        num_q_heads: usize,
94        num_kv_heads: usize,
95        head_dim: usize,
96        final_kv_len: usize,
97        tokens: usize,
98    ) -> Result<()>;
99
100    /// Contig write: FP16 only. INT8 inherits the panic default —
101    /// `KvInt8::alloc_contig` panics in `ensure_kv`, so this branch is
102    /// dead code on the INT8 path.
103    fn contig_write(
104        _ctx: &mut B::Context,
105        _layer: &mut Self::Layer,
106        _qkv: &B::Buffer,
107        _q_norm_w: &B::Buffer,
108        _k_norm_w: &B::Buffer,
109        _cos: &B::Buffer,
110        _sin: &B::Buffer,
111        _q_out: &mut B::Buffer,
112        _k_scratch: &mut B::Buffer,
113        _v_scratch: &mut B::Buffer,
114        _q_buf: &mut B::Buffer,
115        _k_buf: &mut B::Buffer,
116        _v_buf: &mut B::Buffer,
117        _tokens: usize,
118        _num_q_heads: usize,
119        _num_kv_heads: usize,
120        _head_dim: usize,
121        _pos_offset: usize,
122        _eps: f32,
123        _qk_mode: i32,
124    ) -> Result<()> {
125        unimplemented!("contig_write: not supported for this K dtype")
126    }
127
128    /// Contig decode attention: FP16 only.
129    fn contig_decode_attention(
130        _ctx: &mut B::Context,
131        _layer: &Self::Layer,
132        _q: &B::Buffer,
133        _output: &mut B::Buffer,
134        _attn_cfg: crate::backend::AttnConfig,
135        _tokens: usize,
136        _pos_offset: usize,
137    ) -> Result<()> {
138        unimplemented!("contig_decode_attention: not supported for this K dtype")
139    }
140}
141
142// ─────────────────────────────────────────────────────────────────────
143// FP16 impl
144// ─────────────────────────────────────────────────────────────────────
145
146impl<B: Backend + crate::backend::BackendPagedKv> KvLayer<B> for KvFp16 {
147    type Layer = KvCache<B, KvFp16>;
148
149    fn alloc_paged(
150        max_blocks_per_seq: usize,
151        block_size: usize,
152        num_kv_heads: usize,
153        head_dim: usize,
154    ) -> Self::Layer {
155        let block_table = B::alloc_typed(crate::backend::Dtype::U32, max_blocks_per_seq);
156        let mut context_lens = B::alloc_typed(crate::backend::Dtype::U32, 1);
157        let mut bt_ctx = B::new_context();
158        B::write_typed::<u32>(&mut bt_ctx, &mut context_lens, &[0u32]);
159        B::sync(&mut bt_ctx);
160        KvCache {
161            k: B::alloc(1),
162            v: B::alloc(1),
163            len: 0,
164            capacity: max_blocks_per_seq * block_size,
165            num_kv_heads,
166            head_dim,
167            block_size,
168            block_table: Some(block_table),
169            context_lens: Some(context_lens),
170            paged_block_indices: Vec::new(),
171            _kv_dtype: std::marker::PhantomData,
172        }
173    }
174
175    fn alloc_contig(capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self::Layer {
176        KvCache {
177            k: B::alloc(num_kv_heads * capacity * head_dim),
178            v: B::alloc(num_kv_heads * capacity * head_dim),
179            len: 0,
180            capacity,
181            num_kv_heads,
182            head_dim,
183            block_size: 0,
184            block_table: None,
185            context_lens: None,
186            paged_block_indices: Vec::new(),
187            _kv_dtype: std::marker::PhantomData,
188        }
189    }
190
191    fn len(layer: &Self::Layer) -> usize {
192        layer.len
193    }
194    fn set_len(layer: &mut Self::Layer, new_len: usize) {
195        layer.len = new_len;
196    }
197    fn capacity(layer: &Self::Layer) -> usize {
198        layer.capacity
199    }
200    fn block_size(layer: &Self::Layer) -> usize {
201        layer.block_size
202    }
203    fn num_kv_heads(layer: &Self::Layer) -> usize {
204        layer.num_kv_heads
205    }
206    fn head_dim(layer: &Self::Layer) -> usize {
207        layer.head_dim
208    }
209    fn block_table(layer: &Self::Layer) -> Option<&B::Buffer> {
210        layer.block_table.as_ref()
211    }
212    fn block_table_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer> {
213        layer.block_table.as_mut()
214    }
215    fn context_lens(layer: &Self::Layer) -> Option<&B::Buffer> {
216        layer.context_lens.as_ref()
217    }
218    fn context_lens_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer> {
219        layer.context_lens.as_mut()
220    }
221    fn paged_block_indices(layer: &Self::Layer) -> &[u32] {
222        &layer.paged_block_indices
223    }
224    fn paged_block_indices_mut(layer: &mut Self::Layer) -> &mut Vec<u32> {
225        &mut layer.paged_block_indices
226    }
227
228    fn paged_write(
229        ctx: &mut B::Context,
230        layer: &mut Self::Layer,
231        qkv: &B::Buffer,
232        q_norm_w: &B::Buffer,
233        k_norm_w: &B::Buffer,
234        cos: &B::Buffer,
235        sin: &B::Buffer,
236        q_out: &mut B::Buffer,
237        _k_scratch: &mut B::Buffer,
238        _v_scratch: &mut B::Buffer,
239        pool_k: &mut B::Buffer,
240        pool_v: &mut B::Buffer,
241        tokens: usize,
242        num_q_heads: usize,
243        num_kv_heads: usize,
244        head_dim: usize,
245        pos_offset: usize,
246        eps: f32,
247        qk_mode: i32,
248    ) -> Result<()> {
249        let block_size = layer.block_size;
250        let cache_len_before = layer.len;
251        let num_blocks_per_seq = layer.capacity / block_size;
252        let bt = layer
253            .block_table
254            .as_ref()
255            .ok_or_else(|| FerrumError::model("FP16 paged_write: missing block_table"))?;
256        B::split_qkv_norm_rope_into_paged_cache(
257            ctx,
258            qkv,
259            0,
260            q_norm_w,
261            k_norm_w,
262            cos,
263            sin,
264            q_out,
265            0,
266            pool_k,
267            pool_v,
268            bt,
269            tokens,
270            num_q_heads,
271            num_kv_heads,
272            head_dim,
273            pos_offset,
274            eps,
275            qk_mode,
276            cache_len_before,
277            block_size,
278            num_blocks_per_seq,
279        )
280    }
281
282    fn paged_decode_attention(
283        ctx: &mut B::Context,
284        layer: &mut Self::Layer,
285        q: &B::Buffer,
286        pool_k: &B::Buffer,
287        pool_v: &B::Buffer,
288        output: &mut B::Buffer,
289        num_q_heads: usize,
290        num_kv_heads: usize,
291        head_dim: usize,
292        final_kv_len: usize,
293        tokens: usize,
294    ) -> Result<()> {
295        let block_size = layer.block_size;
296        let num_blocks_per_seq = layer.capacity / block_size;
297        let bt_ptr = layer
298            .block_table
299            .as_ref()
300            .ok_or_else(|| FerrumError::model("FP16 paged_decode: missing block_table"))?
301            as *const B::Buffer;
302        let cl_buf = layer
303            .context_lens
304            .as_mut()
305            .ok_or_else(|| FerrumError::model("FP16 paged_decode: missing context_lens"))?;
306        B::write_typed::<u32>(ctx, cl_buf, &[final_kv_len as u32]);
307        // SAFETY: block_table outlives the call.
308        let bt = unsafe { &*bt_ptr };
309        let cl = layer.context_lens.as_ref().unwrap();
310        B::paged_decode_attention(
311            ctx,
312            q,
313            pool_k,
314            pool_v,
315            output,
316            bt,
317            cl,
318            1,
319            num_q_heads,
320            num_kv_heads,
321            head_dim,
322            block_size,
323            num_blocks_per_seq,
324            tokens,
325        )
326    }
327
328    fn contig_write(
329        ctx: &mut B::Context,
330        layer: &mut Self::Layer,
331        qkv: &B::Buffer,
332        q_norm_w: &B::Buffer,
333        k_norm_w: &B::Buffer,
334        cos: &B::Buffer,
335        sin: &B::Buffer,
336        q_out: &mut B::Buffer,
337        k_scratch: &mut B::Buffer,
338        v_scratch: &mut B::Buffer,
339        q_buf: &mut B::Buffer,
340        k_buf: &mut B::Buffer,
341        v_buf: &mut B::Buffer,
342        tokens: usize,
343        num_q_heads: usize,
344        num_kv_heads: usize,
345        head_dim: usize,
346        pos_offset: usize,
347        eps: f32,
348        qk_mode: i32,
349    ) -> Result<()> {
350        let cache_len_before = layer.len;
351        let cache_capacity = layer.capacity;
352        let used_into_cache = B::split_qkv_norm_rope_into_cache(
353            ctx,
354            qkv,
355            q_norm_w,
356            k_norm_w,
357            cos,
358            sin,
359            q_out,
360            &mut layer.k,
361            &mut layer.v,
362            tokens,
363            num_q_heads,
364            num_kv_heads,
365            head_dim,
366            pos_offset,
367            eps,
368            qk_mode,
369            cache_len_before,
370            cache_capacity,
371        )
372        .is_ok();
373        if used_into_cache {
374            return Ok(());
375        }
376        let used_fused_qkv = B::split_qkv_norm_rope(
377            ctx,
378            qkv,
379            q_norm_w,
380            k_norm_w,
381            cos,
382            sin,
383            q_out,
384            k_scratch,
385            v_scratch,
386            tokens,
387            num_q_heads,
388            num_kv_heads,
389            head_dim,
390            pos_offset,
391            eps,
392            qk_mode,
393        )
394        .is_ok();
395        if !used_fused_qkv {
396            let q_dim = num_q_heads * head_dim;
397            let kv_dim = num_kv_heads * head_dim;
398            B::split_qkv(ctx, qkv, q_buf, k_buf, v_buf, tokens, q_dim, kv_dim);
399            B::qk_norm_rope(
400                ctx,
401                q_buf,
402                q_norm_w,
403                cos,
404                sin,
405                q_out,
406                tokens,
407                num_q_heads,
408                head_dim,
409                pos_offset,
410                eps,
411                qk_mode,
412            );
413            B::qk_norm_rope(
414                ctx,
415                k_buf,
416                k_norm_w,
417                cos,
418                sin,
419                k_scratch,
420                tokens,
421                num_kv_heads,
422                head_dim,
423                pos_offset,
424                eps,
425                qk_mode,
426            );
427            B::qk_norm_rope(
428                ctx,
429                v_buf,
430                q_norm_w,
431                cos,
432                sin,
433                v_scratch,
434                tokens,
435                num_kv_heads,
436                head_dim,
437                pos_offset,
438                eps,
439                0,
440            );
441        }
442        B::kv_cache_append_head_major(
443            ctx,
444            &mut layer.k,
445            &mut layer.v,
446            cache_len_before,
447            cache_capacity,
448            k_scratch,
449            v_scratch,
450            tokens,
451            num_kv_heads,
452            head_dim,
453        );
454        Ok(())
455    }
456
457    fn contig_decode_attention(
458        ctx: &mut B::Context,
459        layer: &Self::Layer,
460        q: &B::Buffer,
461        output: &mut B::Buffer,
462        attn_cfg: crate::backend::AttnConfig,
463        tokens: usize,
464        pos_offset: usize,
465    ) -> Result<()> {
466        let kv_len = layer.len;
467        B::flash_attention(
468            ctx, q, &layer.k, &layer.v, output, 1, tokens, kv_len, pos_offset, &attn_cfg,
469        );
470        Ok(())
471    }
472}
473
474// ─────────────────────────────────────────────────────────────────────
475// INT8 impl
476// ─────────────────────────────────────────────────────────────────────
477
478impl<B: Backend + BackendInt8KvOps> KvLayer<B> for KvInt8 {
479    type Layer = KvCacheQuant<B, KvInt8>;
480
481    fn alloc_paged(
482        max_blocks_per_seq: usize,
483        block_size: usize,
484        num_kv_heads: usize,
485        head_dim: usize,
486    ) -> Self::Layer {
487        B::alloc_paged_int8_layer(max_blocks_per_seq, block_size, num_kv_heads, head_dim)
488    }
489
490    fn alloc_contig(_capacity: usize, _num_kv_heads: usize, _head_dim: usize) -> Self::Layer {
491        panic!("KvInt8::alloc_contig: INT8 KV is paged-only")
492    }
493
494    fn len(layer: &Self::Layer) -> usize {
495        layer.len
496    }
497    fn set_len(layer: &mut Self::Layer, new_len: usize) {
498        layer.len = new_len;
499    }
500    fn capacity(layer: &Self::Layer) -> usize {
501        layer.capacity
502    }
503    fn block_size(layer: &Self::Layer) -> usize {
504        layer.block_size
505    }
506    fn num_kv_heads(layer: &Self::Layer) -> usize {
507        layer.num_kv_heads
508    }
509    fn head_dim(layer: &Self::Layer) -> usize {
510        layer.head_dim
511    }
512    fn block_table(layer: &Self::Layer) -> Option<&B::Buffer> {
513        layer.block_table.as_ref()
514    }
515    fn block_table_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer> {
516        layer.block_table.as_mut()
517    }
518    fn context_lens(layer: &Self::Layer) -> Option<&B::Buffer> {
519        layer.context_lens.as_ref()
520    }
521    fn context_lens_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer> {
522        layer.context_lens.as_mut()
523    }
524    fn paged_block_indices(layer: &Self::Layer) -> &[u32] {
525        &layer.paged_block_indices
526    }
527    fn paged_block_indices_mut(layer: &mut Self::Layer) -> &mut Vec<u32> {
528        &mut layer.paged_block_indices
529    }
530
531    fn paged_write(
532        ctx: &mut B::Context,
533        layer: &mut Self::Layer,
534        qkv: &B::Buffer,
535        q_norm_w: &B::Buffer,
536        k_norm_w: &B::Buffer,
537        cos: &B::Buffer,
538        sin: &B::Buffer,
539        q_out: &mut B::Buffer,
540        k_scratch: &mut B::Buffer,
541        v_scratch: &mut B::Buffer,
542        _pool_k: &mut B::Buffer,
543        _pool_v: &mut B::Buffer,
544        tokens: usize,
545        num_q_heads: usize,
546        num_kv_heads: usize,
547        head_dim: usize,
548        pos_offset: usize,
549        eps: f32,
550        qk_mode: i32,
551    ) -> Result<()> {
552        // 1. split + norm + RoPE → FP16 head-major scratch (k/v_scratch).
553        B::split_qkv_norm_rope(
554            ctx,
555            qkv,
556            q_norm_w,
557            k_norm_w,
558            cos,
559            sin,
560            q_out,
561            k_scratch,
562            v_scratch,
563            tokens,
564            num_q_heads,
565            num_kv_heads,
566            head_dim,
567            pos_offset,
568            eps,
569            qk_mode,
570        )?;
571        // 2. quantize FP16 → INT8 + per-token scales, paged append.
572        // `paged_block_indices` is the host-side mirror populated at
573        // `ensure_kv` time — passing it directly avoids the D2H + sync
574        // barrier that would otherwise dominate per-token overhead.
575        let cache_len_before = layer.len;
576        let block_size = layer.block_size;
577        // Clone the host indices (small Vec<u32>) so we don't hold a
578        // borrow on `layer` while passing &mut layer.k/v/scales below.
579        let paged_indices: Vec<u32> = layer.paged_block_indices.clone();
580        B::int8_kv_append_paged(
581            ctx,
582            k_scratch,
583            v_scratch,
584            &mut layer.k,
585            &mut layer.v,
586            &mut layer.k_scales,
587            &mut layer.v_scales,
588            &paged_indices,
589            cache_len_before,
590            tokens,
591            block_size,
592            num_kv_heads,
593            head_dim,
594        )
595    }
596
597    fn paged_decode_attention(
598        ctx: &mut B::Context,
599        layer: &mut Self::Layer,
600        q: &B::Buffer,
601        _pool_k: &B::Buffer,
602        _pool_v: &B::Buffer,
603        output: &mut B::Buffer,
604        num_q_heads: usize,
605        num_kv_heads: usize,
606        head_dim: usize,
607        final_kv_len: usize,
608        _tokens: usize,
609    ) -> Result<()> {
610        let block_size = layer.block_size;
611        let cl_buf = layer
612            .context_lens
613            .as_mut()
614            .ok_or_else(|| FerrumError::model("INT8 paged_decode: missing context_lens"))?;
615        B::write_typed::<u32>(ctx, cl_buf, &[final_kv_len as u32]);
616        let bt = layer
617            .block_table
618            .as_ref()
619            .ok_or_else(|| FerrumError::model("INT8 paged_decode: missing block_table"))?;
620        let scale = (head_dim as f32).sqrt().recip();
621        B::int8_paged_decode_attention(
622            ctx,
623            q,
624            &layer.k,
625            &layer.v,
626            &layer.k_scales,
627            &layer.v_scales,
628            bt,
629            output,
630            num_q_heads,
631            num_kv_heads,
632            head_dim,
633            final_kv_len,
634            block_size,
635            scale,
636        )
637    }
638}