Skip to main content

ferrum_models/common/
llm.rs

//! `DecoderOnlyLLM` trait — the "model family" interface that every
//! decoder-only language model (Qwen3 / Llama / Mistral / DeepSeek / ...)
//! implements, independent of backend and weight format.
//!
//! `LlmExecutor` (defined in `ferrum-engine`) holds a `Box<dyn DecoderOnlyLLM>`
//! and adapts it to the `ModelExecutor` trait that the scheduler calls.
7
/// Runtime configuration every decoder-only LLM must expose.
///
/// This is the *execution-facing* config: the bare minimum the surrounding
/// engine needs (KV cache sizing, sampler vocab bounds, scheduler quotas).
/// Architecture details such as `num_heads` or `intermediate_size` are
/// deliberately left out; they stay private to each model implementation.
///
/// `PartialEq`/`Eq` are derived so configs can be compared directly
/// (e.g. `assert_eq!` in tests). `Copy` is intentionally not derived so that
/// existing `.clone()` call sites do not start tripping `clippy::clone_on_copy`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlmRuntimeConfig {
    /// Hidden-state (model) dimension.
    pub hidden_size: usize,
    /// Number of decoder layers.
    pub num_layers: usize,
    /// Number of key/value attention heads (may be fewer than the query heads
    /// for grouped/multi-query attention).
    pub num_kv_heads: usize,
    /// Dimension of a single attention head.
    pub head_dim: usize,
    /// Vocabulary size, i.e. the length of each `[vocab_size]` logits vector
    /// returned by `DecoderOnlyLLM::prefill` / `DecoderOnlyLLM::decode`.
    pub vocab_size: usize,
    /// Maximum sequence length (in tokens) supported for this model.
    pub max_seq_len: usize,
}
23
24/// A decoder-only language model.
25///
26/// Contract:
27/// - `prefill` processes a batch of prompt tokens and returns logits for the
28///   *last* token, along with initializing whatever KV cache the model
29///   maintains internally (keyed by `cache_id`).
30/// - `decode` processes a single generated token at position `pos` and
31///   returns logits for the next step.
32/// - `release` frees the KV cache for a completed sequence.
33///
34/// Today the model owns its KV cache. Integration with `ferrum-kv`'s paged
35/// KV manager is a Phase D concern; the trait is kept minimal so it can
36/// evolve then without a full refactor.
37pub trait DecoderOnlyLLM: Send + Sync {
38    /// Runtime-facing configuration.
39    fn config(&self) -> &LlmRuntimeConfig;
40
41    /// Prefill the model with a prompt. Returns `[vocab_size]` logits for
42    /// the last prompt token.
43    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32>;
44
45    /// Advance the model by one generated token. `pos` is the position of
46    /// `token` in the sequence (number of tokens already consumed so far).
47    /// Returns `[vocab_size]` logits for the next step.
48    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32>;
49
50    /// Decode multiple concurrent requests in a single forward pass.
51    ///
52    /// Each entry is `(cache_id, token, pos)` — per-request state. Returns
53    /// one `[vocab_size]` logits vec per request in the SAME order.
54    ///
55    /// Default implementation loops `decode` sequentially. Backends that
56    /// implement true batched decode (one GEMM with m=batch, per-item
57    /// attention loop) override for concurrency speedup.
58    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
59        batch
60            .iter()
61            .map(|(cid, tok, p)| self.decode(cid, *tok, *p))
62            .collect()
63    }
64
65    /// Release the KV cache for a completed sequence.
66    fn release(&mut self, cache_id: &str);
67
68    /// Drop all cached state (useful for tests and hot-reload).
69    fn reset(&mut self) {}
70}