Skip to main content

ferrum_models/common/
llm.rs

//! `DecoderOnlyLLM` trait — the "model family" interface that every
//! decoder-only language model (Qwen3 / Llama / Mistral / DeepSeek / ...)
//! implements, independent of backend and weight format.
//!
//! `LlmExecutor` (living in `ferrum-engine`) holds a `Box<dyn DecoderOnlyLLM>`
//! and adapts it to the `ModelExecutor` trait that the scheduler calls.
7
/// Execution-facing description of a decoder-only LLM.
///
/// Carries only what the surrounding engine needs (KV cache sizing,
/// sampler vocab bounds, scheduler quotas). Architecture details such as
/// `num_heads` or `intermediate_size` are intentionally left out and stay
/// private to the model implementation.
#[derive(Debug, Clone)]
pub struct LlmRuntimeConfig {
    /// Width of the model's hidden state.
    pub hidden_size: usize,
    /// Number of decoder layers.
    pub num_layers: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Dimension of a single attention head.
    pub head_dim: usize,
    /// Vocabulary size (upper bound the sampler works within).
    pub vocab_size: usize,
    /// Maximum sequence length declared by the model.
    pub max_seq_len: usize,
}
23
24/// A decoder-only language model.
25///
26/// Contract:
27/// - `prefill` processes a batch of prompt tokens and returns logits for the
28///   *last* token, along with initializing whatever KV cache the model
29///   maintains internally (keyed by `cache_id`).
30/// - `decode` processes a single generated token at position `pos` and
31///   returns logits for the next step.
32/// - `release` frees the KV cache for a completed sequence.
33///
34/// Today the model owns its KV cache. Integration with `ferrum-kv`'s paged
35/// KV manager is a Phase D concern; the trait is kept minimal so it can
36/// evolve then without a full refactor.
37pub trait DecoderOnlyLLM: Send + Sync {
38    /// Runtime-facing configuration.
39    fn config(&self) -> &LlmRuntimeConfig;
40
41    /// Optional model-level cache metrics.
42    ///
43    /// Models with real paged-KV prefix reuse override this so the executor
44    /// and HTTP server can distinguish true KV reuse from product-level
45    /// prompt observability.
46    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
47        None
48    }
49
50    /// Optional runtime LoRA metrics.
51    fn lora_metrics_snapshot(&self) -> Option<serde_json::Value> {
52        None
53    }
54
55    /// Bind or clear a startup LoRA adapter for a model-side KV cache id.
56    ///
57    /// The executor calls this before prefill/decode based on request
58    /// metadata. Models that implement real LoRA inference override it and
59    /// keep the adapter scoped to `cache_id`; unsupported models return an
60    /// explicit error instead of silently serving the base model.
61    fn set_lora_adapter_for_cache(
62        &mut self,
63        cache_id: &str,
64        adapter: Option<crate::lora::ActiveLoraAdapter>,
65    ) -> std::result::Result<(), ferrum_types::FerrumError> {
66        let _ = cache_id;
67        if let Some(adapter) = adapter {
68            return Err(ferrum_types::FerrumError::unsupported(format!(
69                "LoRA inference is not supported by this model/backend for adapter {} at {}",
70                adapter.name,
71                adapter.path.display()
72            )));
73        }
74        Ok(())
75    }
76
77    /// Hint that an upcoming `prefill` / `decode` sequence on
78    /// `cache_id` will have at most `max_tokens` tokens per call. Lets
79    /// the model eagerly grow its internal scratch buffers AND allocate
80    /// the KV cache for `cache_id` so the first real `prefill` doesn't
81    /// have to allocate them on the hot path.
82    ///
83    /// Without this, on Qwen3-MoE's first prefill the timer captures:
84    ///   • ~25 scratch MTLBuffers (residual / qkv / head-major / MoE
85    ///     staging / batch-logits) — ~80-150 ms total alloc
86    ///   • ~96 KV-cache MTLBuffers (K and V × 48 layers) — another
87    ///     ~100-500 ms total alloc
88    ///
89    /// Combined that's the ~350 ms fixed overhead that made pp50 numbers
90    /// look 40% slower than pp512 for the same per-token compute.
91    ///
92    /// Default no-op — backends without resizable buffers ignore it.
93    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
94        let _ = (cache_id, max_tokens);
95    }
96
97    /// Per-cache KV capacity in tokens — the maximum sequence length any
98    /// single `cache_id` can grow to before `prefill` / `decode` would
99    /// overflow the pre-allocated K/V buffers.
100    ///
101    /// Honours `FERRUM_KV_CAPACITY` and clamps to the model's declared
102    /// `max_seq_len`. Callers (REPL, HTTP server, schedulers) should
103    /// pre-check this before extending a sequence; the model panics on
104    /// append-side overflow rather than silently corrupt the cache.
105    ///
106    /// Default returns `config().max_seq_len`. Models that allocate a
107    /// smaller window (most do, capped by `FERRUM_KV_CAPACITY` or the
108    /// 4096 default in `ensure_kv`) override this to surface the real
109    /// budget.
110    fn kv_capacity(&self) -> usize {
111        self.config().max_seq_len
112    }
113
114    /// Prefill the model with a prompt. Returns `[vocab_size]` logits for
115    /// the last prompt token.
116    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32>;
117
118    /// Advance the model by one generated token. `pos` is the position of
119    /// `token` in the sequence (number of tokens already consumed so far).
120    /// Returns `[vocab_size]` logits for the next step.
121    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32>;
122
123    /// Decode multiple concurrent requests in a single forward pass.
124    ///
125    /// Each entry is `(cache_id, token, pos)` — per-request state. Returns
126    /// one `[vocab_size]` logits vec per request in the SAME order.
127    ///
128    /// Default implementation loops `decode` sequentially. Backends that
129    /// implement true batched decode (one GEMM with m=batch, per-item
130    /// attention loop) override for concurrency speedup.
131    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
132        batch
133            .iter()
134            .map(|(cid, tok, p)| self.decode(cid, *tok, *p))
135            .collect()
136    }
137
138    fn decode_batch_with_full_logits(
139        &mut self,
140        batch: &[(String, u32, u32)],
141        _force_full_logits: bool,
142    ) -> Vec<Vec<f32>> {
143        self.decode_batch(batch)
144    }
145
146    /// Multi-position decode-verify: run a single forward over `tokens`
147    /// starting at the current KV end, append their K/V in place, and
148    /// return `seq_len * vocab_size` logits (row-major, position-first).
149    ///
150    /// Used by speculative decoding to collect N+1 verification logits
151    /// in one target pass instead of N+1 sequential decodes.
152    ///
153    /// Default falls back to a decode loop — slower but correct, lets
154    /// minor backends not reimplement the primitive.
155    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
156        let mut out = Vec::with_capacity(tokens.len() * self.config().vocab_size);
157        // cache.len before any decode in this batch — we can derive per-token
158        // position from it. Backends override this default for real batching.
159        let start_pos = 0u32; // placeholder; real impls know their own state
160        for (i, &tok) in tokens.iter().enumerate() {
161            out.extend_from_slice(&self.decode(cache_id, tok, start_pos + i as u32));
162        }
163        out
164    }
165
166    /// Unified mixed-batch forward (chunked-prefill API).
167    ///
168    /// Accepts a heterogeneous batch where each item is `(cache_id,
169    /// q_tokens, pos_offset, is_final_chunk)`:
170    /// - `q_tokens.len() == 1` & `is_final_chunk == true` → decode step
171    /// - `q_tokens.len() >= 1` & `is_final_chunk == true` → final
172    ///   prefill chunk (returns logits for sampling)
173    /// - `q_tokens.len() >= 1` & `is_final_chunk == false` → intermediate
174    ///   prefill chunk (advances KV state, returns None)
175    ///
176    /// `pos_offset` is the absolute KV position of the first q-token
177    /// for that sequence (0 for fresh prefill, prior `kv_len` for
178    /// continuing chunks or decode steps).
179    ///
180    /// Returns one entry per `items[i]`: `Some(logits)` iff
181    /// `is_final_chunk == true`, else `None`.
182    ///
183    /// Default implementation: returns `Err(unsupported)`. Concrete
184    /// models that support a true unified forward (single forward pass
185    /// over the concatenated `[M_total, hidden]` tensor + varlen
186    /// attention) override this. The engine's caller (`LlmExecutor`)
187    /// recognises the unsupported error and falls back to splitting
188    /// the batch into per-item `prefill()` and a single `decode_batch()`
189    /// — behaviour-preserving but doesn't get the chunked-prefill perf
190    /// win until the model exposes a real unified path.
191    #[allow(clippy::type_complexity)]
192    fn unified_forward(
193        &mut self,
194        _items: &[(String, Vec<u32>, usize, bool)],
195    ) -> std::result::Result<Vec<Option<Vec<f32>>>, ferrum_types::FerrumError> {
196        Err(ferrum_types::FerrumError::unsupported(
197            "unified_forward not implemented for this model",
198        ))
199    }
200
201    /// Release the KV cache for a completed sequence.
202    fn release(&mut self, cache_id: &str);
203
204    /// Truncate the KV cache for `cache_id` back to `new_len` positions.
205    /// Used by speculative decoding on rejection — roll draft/target KV
206    /// back to the last accepted position before the next iteration.
207    ///
208    /// Default implementation is a panic so backends that don't support
209    /// rollback fail loudly; implementations override this.
210    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
211        let _ = (cache_id, new_len);
212        panic!("truncate_kv not implemented for this DecoderOnlyLLM");
213    }
214
215    /// Drop all cached state (useful for tests and hot-reload).
216    fn reset(&mut self) {}
217}