1use ferrum_types::{FerrumError, Result};
19
20use crate::backend::{Backend, BackendInt8KvOps, KvCache, KvCacheQuant};
21use ferrum_interfaces::kv_dtype::{KvDtypeKind, KvFp16, KvInt8};
22
23#[allow(clippy::too_many_arguments)]
25pub trait KvLayer<B: Backend>: KvDtypeKind {
26 type Layer: Send + Sync;
28
29 fn alloc_paged(
31 max_blocks_per_seq: usize,
32 block_size: usize,
33 num_kv_heads: usize,
34 head_dim: usize,
35 ) -> Self::Layer;
36
37 fn alloc_contig(capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self::Layer;
39
40 fn len(layer: &Self::Layer) -> usize;
42 fn set_len(layer: &mut Self::Layer, new_len: usize);
43 fn capacity(layer: &Self::Layer) -> usize;
44 fn block_size(layer: &Self::Layer) -> usize;
45 fn num_kv_heads(layer: &Self::Layer) -> usize;
46 fn head_dim(layer: &Self::Layer) -> usize;
47 fn block_table(layer: &Self::Layer) -> Option<&B::Buffer>;
48 fn block_table_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer>;
49 fn context_lens(layer: &Self::Layer) -> Option<&B::Buffer>;
50 fn context_lens_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer>;
51 fn paged_block_indices(layer: &Self::Layer) -> &[u32];
52 fn paged_block_indices_mut(layer: &mut Self::Layer) -> &mut Vec<u32>;
53
54 fn is_paged(layer: &Self::Layer) -> bool {
55 Self::block_size(layer) > 0
56 }
57
58 fn paged_write(
62 ctx: &mut B::Context,
63 layer: &mut Self::Layer,
64 qkv: &B::Buffer,
65 q_norm_w: &B::Buffer,
66 k_norm_w: &B::Buffer,
67 cos: &B::Buffer,
68 sin: &B::Buffer,
69 q_out: &mut B::Buffer,
70 k_scratch: &mut B::Buffer,
71 v_scratch: &mut B::Buffer,
72 pool_k: &mut B::Buffer,
73 pool_v: &mut B::Buffer,
74 tokens: usize,
75 num_q_heads: usize,
76 num_kv_heads: usize,
77 head_dim: usize,
78 pos_offset: usize,
79 eps: f32,
80 qk_mode: i32,
81 ) -> Result<()>;
82
83 fn paged_decode_attention(
87 ctx: &mut B::Context,
88 layer: &mut Self::Layer,
89 q: &B::Buffer,
90 pool_k: &B::Buffer,
91 pool_v: &B::Buffer,
92 output: &mut B::Buffer,
93 num_q_heads: usize,
94 num_kv_heads: usize,
95 head_dim: usize,
96 final_kv_len: usize,
97 tokens: usize,
98 ) -> Result<()>;
99
100 fn contig_write(
104 _ctx: &mut B::Context,
105 _layer: &mut Self::Layer,
106 _qkv: &B::Buffer,
107 _q_norm_w: &B::Buffer,
108 _k_norm_w: &B::Buffer,
109 _cos: &B::Buffer,
110 _sin: &B::Buffer,
111 _q_out: &mut B::Buffer,
112 _k_scratch: &mut B::Buffer,
113 _v_scratch: &mut B::Buffer,
114 _q_buf: &mut B::Buffer,
115 _k_buf: &mut B::Buffer,
116 _v_buf: &mut B::Buffer,
117 _tokens: usize,
118 _num_q_heads: usize,
119 _num_kv_heads: usize,
120 _head_dim: usize,
121 _pos_offset: usize,
122 _eps: f32,
123 _qk_mode: i32,
124 ) -> Result<()> {
125 unimplemented!("contig_write: not supported for this K dtype")
126 }
127
128 fn contig_decode_attention(
130 _ctx: &mut B::Context,
131 _layer: &Self::Layer,
132 _q: &B::Buffer,
133 _output: &mut B::Buffer,
134 _attn_cfg: crate::backend::AttnConfig,
135 _tokens: usize,
136 _pos_offset: usize,
137 ) -> Result<()> {
138 unimplemented!("contig_decode_attention: not supported for this K dtype")
139 }
140}
141
142impl<B: Backend + crate::backend::BackendPagedKv> KvLayer<B> for KvFp16 {
147 type Layer = KvCache<B, KvFp16>;
148
149 fn alloc_paged(
150 max_blocks_per_seq: usize,
151 block_size: usize,
152 num_kv_heads: usize,
153 head_dim: usize,
154 ) -> Self::Layer {
155 let block_table = B::alloc_typed(crate::backend::Dtype::U32, max_blocks_per_seq);
156 let mut context_lens = B::alloc_typed(crate::backend::Dtype::U32, 1);
157 let mut bt_ctx = B::new_context();
158 B::write_typed::<u32>(&mut bt_ctx, &mut context_lens, &[0u32]);
159 B::sync(&mut bt_ctx);
160 KvCache {
161 k: B::alloc(1),
162 v: B::alloc(1),
163 len: 0,
164 capacity: max_blocks_per_seq * block_size,
165 num_kv_heads,
166 head_dim,
167 block_size,
168 block_table: Some(block_table),
169 context_lens: Some(context_lens),
170 paged_block_indices: Vec::new(),
171 _kv_dtype: std::marker::PhantomData,
172 }
173 }
174
175 fn alloc_contig(capacity: usize, num_kv_heads: usize, head_dim: usize) -> Self::Layer {
176 KvCache {
177 k: B::alloc(num_kv_heads * capacity * head_dim),
178 v: B::alloc(num_kv_heads * capacity * head_dim),
179 len: 0,
180 capacity,
181 num_kv_heads,
182 head_dim,
183 block_size: 0,
184 block_table: None,
185 context_lens: None,
186 paged_block_indices: Vec::new(),
187 _kv_dtype: std::marker::PhantomData,
188 }
189 }
190
191 fn len(layer: &Self::Layer) -> usize {
192 layer.len
193 }
194 fn set_len(layer: &mut Self::Layer, new_len: usize) {
195 layer.len = new_len;
196 }
197 fn capacity(layer: &Self::Layer) -> usize {
198 layer.capacity
199 }
200 fn block_size(layer: &Self::Layer) -> usize {
201 layer.block_size
202 }
203 fn num_kv_heads(layer: &Self::Layer) -> usize {
204 layer.num_kv_heads
205 }
206 fn head_dim(layer: &Self::Layer) -> usize {
207 layer.head_dim
208 }
209 fn block_table(layer: &Self::Layer) -> Option<&B::Buffer> {
210 layer.block_table.as_ref()
211 }
212 fn block_table_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer> {
213 layer.block_table.as_mut()
214 }
215 fn context_lens(layer: &Self::Layer) -> Option<&B::Buffer> {
216 layer.context_lens.as_ref()
217 }
218 fn context_lens_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer> {
219 layer.context_lens.as_mut()
220 }
221 fn paged_block_indices(layer: &Self::Layer) -> &[u32] {
222 &layer.paged_block_indices
223 }
224 fn paged_block_indices_mut(layer: &mut Self::Layer) -> &mut Vec<u32> {
225 &mut layer.paged_block_indices
226 }
227
228 fn paged_write(
229 ctx: &mut B::Context,
230 layer: &mut Self::Layer,
231 qkv: &B::Buffer,
232 q_norm_w: &B::Buffer,
233 k_norm_w: &B::Buffer,
234 cos: &B::Buffer,
235 sin: &B::Buffer,
236 q_out: &mut B::Buffer,
237 _k_scratch: &mut B::Buffer,
238 _v_scratch: &mut B::Buffer,
239 pool_k: &mut B::Buffer,
240 pool_v: &mut B::Buffer,
241 tokens: usize,
242 num_q_heads: usize,
243 num_kv_heads: usize,
244 head_dim: usize,
245 pos_offset: usize,
246 eps: f32,
247 qk_mode: i32,
248 ) -> Result<()> {
249 let block_size = layer.block_size;
250 let cache_len_before = layer.len;
251 let num_blocks_per_seq = layer.capacity / block_size;
252 let bt = layer
253 .block_table
254 .as_ref()
255 .ok_or_else(|| FerrumError::model("FP16 paged_write: missing block_table"))?;
256 B::split_qkv_norm_rope_into_paged_cache(
257 ctx,
258 qkv,
259 0,
260 q_norm_w,
261 k_norm_w,
262 cos,
263 sin,
264 q_out,
265 0,
266 pool_k,
267 pool_v,
268 bt,
269 tokens,
270 num_q_heads,
271 num_kv_heads,
272 head_dim,
273 pos_offset,
274 eps,
275 qk_mode,
276 cache_len_before,
277 block_size,
278 num_blocks_per_seq,
279 )
280 }
281
282 fn paged_decode_attention(
283 ctx: &mut B::Context,
284 layer: &mut Self::Layer,
285 q: &B::Buffer,
286 pool_k: &B::Buffer,
287 pool_v: &B::Buffer,
288 output: &mut B::Buffer,
289 num_q_heads: usize,
290 num_kv_heads: usize,
291 head_dim: usize,
292 final_kv_len: usize,
293 tokens: usize,
294 ) -> Result<()> {
295 let block_size = layer.block_size;
296 let num_blocks_per_seq = layer.capacity / block_size;
297 let bt_ptr = layer
298 .block_table
299 .as_ref()
300 .ok_or_else(|| FerrumError::model("FP16 paged_decode: missing block_table"))?
301 as *const B::Buffer;
302 let cl_buf = layer
303 .context_lens
304 .as_mut()
305 .ok_or_else(|| FerrumError::model("FP16 paged_decode: missing context_lens"))?;
306 B::write_typed::<u32>(ctx, cl_buf, &[final_kv_len as u32]);
307 let bt = unsafe { &*bt_ptr };
309 let cl = layer.context_lens.as_ref().unwrap();
310 B::paged_decode_attention(
311 ctx,
312 q,
313 pool_k,
314 pool_v,
315 output,
316 bt,
317 cl,
318 1,
319 num_q_heads,
320 num_kv_heads,
321 head_dim,
322 block_size,
323 num_blocks_per_seq,
324 tokens,
325 )
326 }
327
328 fn contig_write(
329 ctx: &mut B::Context,
330 layer: &mut Self::Layer,
331 qkv: &B::Buffer,
332 q_norm_w: &B::Buffer,
333 k_norm_w: &B::Buffer,
334 cos: &B::Buffer,
335 sin: &B::Buffer,
336 q_out: &mut B::Buffer,
337 k_scratch: &mut B::Buffer,
338 v_scratch: &mut B::Buffer,
339 q_buf: &mut B::Buffer,
340 k_buf: &mut B::Buffer,
341 v_buf: &mut B::Buffer,
342 tokens: usize,
343 num_q_heads: usize,
344 num_kv_heads: usize,
345 head_dim: usize,
346 pos_offset: usize,
347 eps: f32,
348 qk_mode: i32,
349 ) -> Result<()> {
350 let cache_len_before = layer.len;
351 let cache_capacity = layer.capacity;
352 let used_into_cache = B::split_qkv_norm_rope_into_cache(
353 ctx,
354 qkv,
355 q_norm_w,
356 k_norm_w,
357 cos,
358 sin,
359 q_out,
360 &mut layer.k,
361 &mut layer.v,
362 tokens,
363 num_q_heads,
364 num_kv_heads,
365 head_dim,
366 pos_offset,
367 eps,
368 qk_mode,
369 cache_len_before,
370 cache_capacity,
371 )
372 .is_ok();
373 if used_into_cache {
374 return Ok(());
375 }
376 let used_fused_qkv = B::split_qkv_norm_rope(
377 ctx,
378 qkv,
379 q_norm_w,
380 k_norm_w,
381 cos,
382 sin,
383 q_out,
384 k_scratch,
385 v_scratch,
386 tokens,
387 num_q_heads,
388 num_kv_heads,
389 head_dim,
390 pos_offset,
391 eps,
392 qk_mode,
393 )
394 .is_ok();
395 if !used_fused_qkv {
396 let q_dim = num_q_heads * head_dim;
397 let kv_dim = num_kv_heads * head_dim;
398 B::split_qkv(ctx, qkv, q_buf, k_buf, v_buf, tokens, q_dim, kv_dim);
399 B::qk_norm_rope(
400 ctx,
401 q_buf,
402 q_norm_w,
403 cos,
404 sin,
405 q_out,
406 tokens,
407 num_q_heads,
408 head_dim,
409 pos_offset,
410 eps,
411 qk_mode,
412 );
413 B::qk_norm_rope(
414 ctx,
415 k_buf,
416 k_norm_w,
417 cos,
418 sin,
419 k_scratch,
420 tokens,
421 num_kv_heads,
422 head_dim,
423 pos_offset,
424 eps,
425 qk_mode,
426 );
427 B::qk_norm_rope(
428 ctx,
429 v_buf,
430 q_norm_w,
431 cos,
432 sin,
433 v_scratch,
434 tokens,
435 num_kv_heads,
436 head_dim,
437 pos_offset,
438 eps,
439 0,
440 );
441 }
442 B::kv_cache_append_head_major(
443 ctx,
444 &mut layer.k,
445 &mut layer.v,
446 cache_len_before,
447 cache_capacity,
448 k_scratch,
449 v_scratch,
450 tokens,
451 num_kv_heads,
452 head_dim,
453 );
454 Ok(())
455 }
456
457 fn contig_decode_attention(
458 ctx: &mut B::Context,
459 layer: &Self::Layer,
460 q: &B::Buffer,
461 output: &mut B::Buffer,
462 attn_cfg: crate::backend::AttnConfig,
463 tokens: usize,
464 pos_offset: usize,
465 ) -> Result<()> {
466 let kv_len = layer.len;
467 B::flash_attention(
468 ctx, q, &layer.k, &layer.v, output, 1, tokens, kv_len, pos_offset, &attn_cfg,
469 );
470 Ok(())
471 }
472}
473
474impl<B: Backend + BackendInt8KvOps> KvLayer<B> for KvInt8 {
479 type Layer = KvCacheQuant<B, KvInt8>;
480
481 fn alloc_paged(
482 max_blocks_per_seq: usize,
483 block_size: usize,
484 num_kv_heads: usize,
485 head_dim: usize,
486 ) -> Self::Layer {
487 B::alloc_paged_int8_layer(max_blocks_per_seq, block_size, num_kv_heads, head_dim)
488 }
489
490 fn alloc_contig(_capacity: usize, _num_kv_heads: usize, _head_dim: usize) -> Self::Layer {
491 panic!("KvInt8::alloc_contig: INT8 KV is paged-only")
492 }
493
494 fn len(layer: &Self::Layer) -> usize {
495 layer.len
496 }
497 fn set_len(layer: &mut Self::Layer, new_len: usize) {
498 layer.len = new_len;
499 }
500 fn capacity(layer: &Self::Layer) -> usize {
501 layer.capacity
502 }
503 fn block_size(layer: &Self::Layer) -> usize {
504 layer.block_size
505 }
506 fn num_kv_heads(layer: &Self::Layer) -> usize {
507 layer.num_kv_heads
508 }
509 fn head_dim(layer: &Self::Layer) -> usize {
510 layer.head_dim
511 }
512 fn block_table(layer: &Self::Layer) -> Option<&B::Buffer> {
513 layer.block_table.as_ref()
514 }
515 fn block_table_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer> {
516 layer.block_table.as_mut()
517 }
518 fn context_lens(layer: &Self::Layer) -> Option<&B::Buffer> {
519 layer.context_lens.as_ref()
520 }
521 fn context_lens_mut(layer: &mut Self::Layer) -> Option<&mut B::Buffer> {
522 layer.context_lens.as_mut()
523 }
524 fn paged_block_indices(layer: &Self::Layer) -> &[u32] {
525 &layer.paged_block_indices
526 }
527 fn paged_block_indices_mut(layer: &mut Self::Layer) -> &mut Vec<u32> {
528 &mut layer.paged_block_indices
529 }
530
531 fn paged_write(
532 ctx: &mut B::Context,
533 layer: &mut Self::Layer,
534 qkv: &B::Buffer,
535 q_norm_w: &B::Buffer,
536 k_norm_w: &B::Buffer,
537 cos: &B::Buffer,
538 sin: &B::Buffer,
539 q_out: &mut B::Buffer,
540 k_scratch: &mut B::Buffer,
541 v_scratch: &mut B::Buffer,
542 _pool_k: &mut B::Buffer,
543 _pool_v: &mut B::Buffer,
544 tokens: usize,
545 num_q_heads: usize,
546 num_kv_heads: usize,
547 head_dim: usize,
548 pos_offset: usize,
549 eps: f32,
550 qk_mode: i32,
551 ) -> Result<()> {
552 B::split_qkv_norm_rope(
554 ctx,
555 qkv,
556 q_norm_w,
557 k_norm_w,
558 cos,
559 sin,
560 q_out,
561 k_scratch,
562 v_scratch,
563 tokens,
564 num_q_heads,
565 num_kv_heads,
566 head_dim,
567 pos_offset,
568 eps,
569 qk_mode,
570 )?;
571 let cache_len_before = layer.len;
576 let block_size = layer.block_size;
577 let paged_indices: Vec<u32> = layer.paged_block_indices.clone();
580 B::int8_kv_append_paged(
581 ctx,
582 k_scratch,
583 v_scratch,
584 &mut layer.k,
585 &mut layer.v,
586 &mut layer.k_scales,
587 &mut layer.v_scales,
588 &paged_indices,
589 cache_len_before,
590 tokens,
591 block_size,
592 num_kv_heads,
593 head_dim,
594 )
595 }
596
597 fn paged_decode_attention(
598 ctx: &mut B::Context,
599 layer: &mut Self::Layer,
600 q: &B::Buffer,
601 _pool_k: &B::Buffer,
602 _pool_v: &B::Buffer,
603 output: &mut B::Buffer,
604 num_q_heads: usize,
605 num_kv_heads: usize,
606 head_dim: usize,
607 final_kv_len: usize,
608 _tokens: usize,
609 ) -> Result<()> {
610 let block_size = layer.block_size;
611 let cl_buf = layer
612 .context_lens
613 .as_mut()
614 .ok_or_else(|| FerrumError::model("INT8 paged_decode: missing context_lens"))?;
615 B::write_typed::<u32>(ctx, cl_buf, &[final_kv_len as u32]);
616 let bt = layer
617 .block_table
618 .as_ref()
619 .ok_or_else(|| FerrumError::model("INT8 paged_decode: missing block_table"))?;
620 let scale = (head_dim as f32).sqrt().recip();
621 B::int8_paged_decode_attention(
622 ctx,
623 q,
624 &layer.k,
625 &layer.v,
626 &layer.k_scales,
627 &layer.v_scales,
628 bt,
629 output,
630 num_q_heads,
631 num_kv_heads,
632 head_dim,
633 final_kv_len,
634 block_size,
635 scale,
636 )
637 }
638}