ferrum_models/common/llm.rs
1//! `DecoderOnlyLLM` trait — the "model family" interface that every
2//! decoder-only language model (Qwen3 / Llama / Mistral / DeepSeek / ...)
3//! implements, independent of backend and weight format.
4//!
5//! `LlmExecutor` (living in `ferrum-engine`) holds a `Box<dyn DecoderOnlyLLM>`
6//! and adapts it to the `ModelExecutor` trait that the scheduler calls.
7
/// Runtime configuration every decoder-only LLM must expose.
///
/// This is the *execution-facing* config — the bare minimum the surrounding
/// engine needs (KV cache sizing, sampler vocab bounds, scheduler quotas).
/// It deliberately does not include architecture details like `num_heads`
/// or `intermediate_size`; those stay private to the model implementation.
///
/// Derives `PartialEq`/`Eq` so configs can be compared directly (e.g. in
/// tests) instead of field by field.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmRuntimeConfig {
    /// Width of the model's hidden state.
    pub hidden_size: usize,
    /// Number of decoder layers.
    pub num_layers: usize,
    /// Number of key/value heads per layer.
    pub num_kv_heads: usize,
    /// Dimension of each attention head.
    pub head_dim: usize,
    /// Vocabulary size — the length of every logits vector the model returns.
    pub vocab_size: usize,
    /// Maximum sequence length the model declares support for.
    pub max_seq_len: usize,
}
23
24/// A decoder-only language model.
25///
26/// Contract:
27/// - `prefill` processes a batch of prompt tokens and returns logits for the
28/// *last* token, along with initializing whatever KV cache the model
29/// maintains internally (keyed by `cache_id`).
30/// - `decode` processes a single generated token at position `pos` and
31/// returns logits for the next step.
32/// - `release` frees the KV cache for a completed sequence.
33///
34/// Today the model owns its KV cache. Integration with `ferrum-kv`'s paged
35/// KV manager is a Phase D concern; the trait is kept minimal so it can
36/// evolve then without a full refactor.
37pub trait DecoderOnlyLLM: Send + Sync {
38 /// Runtime-facing configuration.
39 fn config(&self) -> &LlmRuntimeConfig;
40
41 /// Hint that an upcoming `prefill` / `decode` sequence on
42 /// `cache_id` will have at most `max_tokens` tokens per call. Lets
43 /// the model eagerly grow its internal scratch buffers AND allocate
44 /// the KV cache for `cache_id` so the first real `prefill` doesn't
45 /// have to allocate them on the hot path.
46 ///
47 /// Without this, on Qwen3-MoE's first prefill the timer captures:
48 /// • ~25 scratch MTLBuffers (residual / qkv / head-major / MoE
49 /// staging / batch-logits) — ~80-150 ms total alloc
50 /// • ~96 KV-cache MTLBuffers (K and V × 48 layers) — another
51 /// ~100-500 ms total alloc
52 ///
53 /// Combined that's the ~350 ms fixed overhead that made pp50 numbers
54 /// look 40% slower than pp512 for the same per-token compute.
55 ///
56 /// Default no-op — backends without resizable buffers ignore it.
57 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
58 let _ = (cache_id, max_tokens);
59 }
60
61 /// Per-cache KV capacity in tokens — the maximum sequence length any
62 /// single `cache_id` can grow to before `prefill` / `decode` would
63 /// overflow the pre-allocated K/V buffers.
64 ///
65 /// Honours `FERRUM_KV_CAPACITY` and clamps to the model's declared
66 /// `max_seq_len`. Callers (REPL, HTTP server, schedulers) should
67 /// pre-check this before extending a sequence; the model panics on
68 /// append-side overflow rather than silently corrupt the cache.
69 ///
70 /// Default returns `config().max_seq_len`. Models that allocate a
71 /// smaller window (most do, capped by `FERRUM_KV_CAPACITY` or the
72 /// 4096 default in `ensure_kv`) override this to surface the real
73 /// budget.
74 fn kv_capacity(&self) -> usize {
75 self.config().max_seq_len
76 }
77
78 /// Prefill the model with a prompt. Returns `[vocab_size]` logits for
79 /// the last prompt token.
80 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32>;
81
82 /// Advance the model by one generated token. `pos` is the position of
83 /// `token` in the sequence (number of tokens already consumed so far).
84 /// Returns `[vocab_size]` logits for the next step.
85 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32>;
86
87 /// Decode multiple concurrent requests in a single forward pass.
88 ///
89 /// Each entry is `(cache_id, token, pos)` — per-request state. Returns
90 /// one `[vocab_size]` logits vec per request in the SAME order.
91 ///
92 /// Default implementation loops `decode` sequentially. Backends that
93 /// implement true batched decode (one GEMM with m=batch, per-item
94 /// attention loop) override for concurrency speedup.
95 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
96 batch
97 .iter()
98 .map(|(cid, tok, p)| self.decode(cid, *tok, *p))
99 .collect()
100 }
101
102 /// Multi-position decode-verify: run a single forward over `tokens`
103 /// starting at the current KV end, append their K/V in place, and
104 /// return `seq_len * vocab_size` logits (row-major, position-first).
105 ///
106 /// Used by speculative decoding to collect N+1 verification logits
107 /// in one target pass instead of N+1 sequential decodes.
108 ///
109 /// Default falls back to a decode loop — slower but correct, lets
110 /// minor backends not reimplement the primitive.
111 fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
112 let mut out = Vec::with_capacity(tokens.len() * self.config().vocab_size);
113 // cache.len before any decode in this batch — we can derive per-token
114 // position from it. Backends override this default for real batching.
115 let start_pos = 0u32; // placeholder; real impls know their own state
116 for (i, &tok) in tokens.iter().enumerate() {
117 out.extend_from_slice(&self.decode(cache_id, tok, start_pos + i as u32));
118 }
119 out
120 }
121
122 /// Release the KV cache for a completed sequence.
123 fn release(&mut self, cache_id: &str);
124
125 /// Truncate the KV cache for `cache_id` back to `new_len` positions.
126 /// Used by speculative decoding on rejection — roll draft/target KV
127 /// back to the last accepted position before the next iteration.
128 ///
129 /// Default implementation is a panic so backends that don't support
130 /// rollback fail loudly; implementations override this.
131 fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
132 let _ = (cache_id, new_len);
133 panic!("truncate_kv not implemented for this DecoderOnlyLLM");
134 }
135
136 /// Drop all cached state (useful for tests and hot-reload).
137 fn reset(&mut self) {}
138}