

Trait DecoderOnlyLLM 

pub trait DecoderOnlyLLM: Send + Sync {
    // Required methods
    fn config(&self) -> &LlmRuntimeConfig;
    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32>;
    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32>;
    fn release(&mut self, cache_id: &str);

    // Provided methods
    fn prepare(&mut self, cache_id: &str, max_tokens: usize) { ... }
    fn kv_capacity(&self) -> usize { ... }
    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> { ... }
    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> { ... }
    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) { ... }
    fn reset(&mut self) { ... }
}

A decoder-only language model.

Contract:

  • prefill processes a batch of prompt tokens, initializes whatever KV cache the model maintains internally (keyed by cache_id), and returns logits for the last token.
  • decode processes a single generated token at position pos and returns logits for the next step.
  • release frees the KV cache for a completed sequence.

Today the model owns its KV cache. Integration with ferrum-kv’s paged KV manager is a Phase D concern; the trait is kept minimal so it can evolve then without a full refactor.
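The contract can be illustrated with a toy implementation. This is a minimal sketch, not a real backend: LlmRuntimeConfig is reduced to two hypothetical fields (vocab_size, max_seq_len), the trait is trimmed to its required methods, and the "KV cache" is simply each sequence's token history stored in a HashMap keyed by cache_id.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the crate's config type; only the fields
/// this sketch needs are included.
pub struct LlmRuntimeConfig {
    pub vocab_size: usize,
    pub max_seq_len: usize,
}

/// Trimmed mirror of the trait: required methods only.
pub trait DecoderOnlyLLM: Send + Sync {
    fn config(&self) -> &LlmRuntimeConfig;
    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32>;
    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32>;
    fn release(&mut self, cache_id: &str);
}

/// Toy model: the "KV cache" is the token history per cache_id, and the
/// logits put all their mass on (last_token + 1) % vocab_size.
pub struct ToyModel {
    pub cfg: LlmRuntimeConfig,
    pub kv: HashMap<String, Vec<u32>>,
}

impl ToyModel {
    fn logits_after(&self, last: u32) -> Vec<f32> {
        let mut logits = vec![0.0; self.cfg.vocab_size];
        logits[(last as usize + 1) % self.cfg.vocab_size] = 1.0;
        logits
    }
}

impl DecoderOnlyLLM for ToyModel {
    fn config(&self) -> &LlmRuntimeConfig {
        &self.cfg
    }

    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
        // Initialize (or overwrite) the cache owned by the model for this sequence.
        self.kv.insert(cache_id.to_string(), tokens.to_vec());
        self.logits_after(*tokens.last().expect("empty prompt"))
    }

    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
        let cache = self.kv.get_mut(cache_id).expect("decode before prefill");
        // pos must equal the number of tokens already consumed.
        assert_eq!(cache.len(), pos as usize, "pos out of sync with cache");
        cache.push(token);
        self.logits_after(token)
    }

    fn release(&mut self, cache_id: &str) {
        // Free the per-sequence cache once the sequence is complete.
        self.kv.remove(cache_id);
    }
}
```

Keying the cache by cache_id is what lets one model instance interleave several sequences; release is the only point at which a sequence's state is dropped.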

Required Methods


fn config(&self) -> &LlmRuntimeConfig

Runtime-facing configuration.


fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32>

Prefill the model with a prompt. Returns [vocab_size] logits for the last prompt token.


fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32>

Advance the model by one generated token. pos is the position of token in the sequence, i.e. the number of tokens already consumed. Returns [vocab_size] logits for the next step.
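A minimal greedy generation loop shows how pos is threaded through decode: after prefilling an n-token prompt, the first sampled token is decoded at pos = n, the next at n + 1, and so on. The trimmed trait and the generate / argmax helpers are hypothetical illustrations, not part of the crate.

```rust
/// Trimmed, hypothetical mirror of the three methods the loop needs.
pub trait DecoderOnlyLLM {
    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32>;
    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32>;
    fn release(&mut self, cache_id: &str);
}

/// Index of the largest logit (greedy sampling).
fn argmax(logits: &[f32]) -> u32 {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i as u32)
        .expect("empty logits")
}

/// Greedy generation: prefill the prompt, then feed each sampled token back
/// through decode with pos = number of tokens already consumed.
pub fn generate<M: DecoderOnlyLLM>(m: &mut M, id: &str, prompt: &[u32], max_new: usize) -> Vec<u32> {
    let mut out = Vec::with_capacity(max_new);
    if max_new == 0 {
        return out;
    }
    let mut next = argmax(&m.prefill(id, prompt));
    let mut pos = prompt.len() as u32; // the prompt occupies positions 0..n
    out.push(next);
    while out.len() < max_new {
        let logits = m.decode(id, next, pos);
        pos += 1;
        next = argmax(&logits);
        out.push(next);
    }
    m.release(id); // sequence complete: free its KV cache
    out
}
```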


fn release(&mut self, cache_id: &str)

Release the KV cache for a completed sequence.

Provided Methods


fn prepare(&mut self, cache_id: &str, max_tokens: usize)

Hint that an upcoming prefill / decode sequence on cache_id will have at most max_tokens tokens per call. This lets the model eagerly grow its internal scratch buffers AND allocate the KV cache for cache_id, so the first real prefill doesn’t have to allocate them on the hot path.

Without this, the timer on Qwen3-MoE’s first prefill captures:

  • ~25 scratch MTLBuffers (residual / qkv / head-major / MoE staging / batch-logits): ~80-150 ms of allocation in total
  • ~96 KV-cache MTLBuffers (K and V × 48 layers): another ~100-500 ms of allocation in total

Combined, that’s the ~350 ms fixed overhead that made pp50 numbers look 40% slower than pp512 for the same per-token compute.
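The fixed cost skews short prompts because it is amortized over the prompt length. Writing t_tok for the true per-token compute (a symbol introduced here for illustration) and F ≈ 350 ms for the one-off allocation overhead, the measured per-token prefill time for an n-token prompt is roughly:

```latex
\bar{t}(n) \approx t_{\text{tok}} + \frac{F}{n}, \qquad
\frac{F}{50} = \frac{350\ \text{ms}}{50} = 7\ \text{ms/token}, \qquad
\frac{F}{512} = \frac{350\ \text{ms}}{512} \approx 0.68\ \text{ms/token}
```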

The default is a no-op; backends without resizable buffers ignore it.
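A sketch of what prepare buys, using plain Vec<f32> buffers as stand-ins for MTLBuffers. The Scratch type, its hidden row width, and the prefill stand-in are hypothetical; the point is that after prepare, a prefill of at most max_tokens tokens reuses existing capacity instead of reallocating.

```rust
use std::collections::HashMap;

/// Hypothetical backend state: one reusable scratch buffer plus per-sequence
/// KV storage. `hidden` is an illustrative per-token row width.
pub struct Scratch {
    hidden: usize,
    scratch: Vec<f32>,
    kv: HashMap<String, Vec<f32>>,
}

impl Scratch {
    pub fn new(hidden: usize) -> Self {
        Self { hidden, scratch: Vec::new(), kv: HashMap::new() }
    }

    /// Mirrors the prepare hint: grow scratch and reserve KV storage up front.
    pub fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
        let need = max_tokens * self.hidden;
        if self.scratch.len() < need {
            self.scratch.resize(need, 0.0);
        }
        let kv = self.kv.entry(cache_id.to_string()).or_default();
        kv.reserve(need.saturating_sub(kv.len())); // capacity >= need afterwards
    }

    /// Toy hot path: writes n_tokens rows into scratch and appends them to
    /// the KV store. Returns true if either buffer had to reallocate.
    pub fn prefill(&mut self, cache_id: &str, n_tokens: usize) -> bool {
        let need = n_tokens * self.hidden;
        let scratch_cap = self.scratch.capacity();
        if self.scratch.len() < need {
            self.scratch.resize(need, 0.0);
        }
        self.scratch[..need].fill(1.0); // stand-in for real compute
        let kv = self.kv.entry(cache_id.to_string()).or_default();
        let kv_cap = kv.capacity();
        kv.extend_from_slice(&self.scratch[..need]);
        self.scratch.capacity() != scratch_cap || kv.capacity() != kv_cap
    }
}
```

A cold Scratch reallocates on its first prefill, whereas calling prepare(cache_id, 64) first lets a 50-token prefill run without any buffer reallocation.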


fn kv_capacity(&self) -> usize

Per-cache KV capacity in tokens: the maximum sequence length any single cache_id can grow to before prefill / decode would overflow the pre-allocated K/V buffers.

Honours FERRUM_KV_CAPACITY, clamped to the model’s declared max_seq_len. Callers (REPL, HTTP server, schedulers) should check this before extending a sequence; the model panics on append-side overflow rather than silently corrupting the cache.

The default returns config().max_seq_len. Models that allocate a smaller window (most do, capped by FERRUM_KV_CAPACITY or the 4096-token default in ensure_kv) override it to surface the real budget.
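A caller-side sketch of these capacity rules. resolve_kv_capacity, kv_capacity_from_env, and fits are hypothetical helpers, and two behaviours are assumed rather than documented: an unset or unparsable FERRUM_KV_CAPACITY falls back to the 4096-token default, and the result is clamped to max_seq_len. The pre-check runs before prefill / decode, because the model panics on overflow.

```rust
use std::env;

/// Default KV window mentioned for ensure_kv (4096 tokens).
const DEFAULT_KV_CAPACITY: usize = 4096;

/// Hypothetical helper: resolve the per-cache capacity from an optional
/// FERRUM_KV_CAPACITY value (assumed: invalid or zero values fall back to
/// the default), then clamp to the model's declared max_seq_len.
pub fn resolve_kv_capacity(env_value: Option<&str>, max_seq_len: usize) -> usize {
    let requested = env_value
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or(DEFAULT_KV_CAPACITY);
    requested.min(max_seq_len)
}

/// Hypothetical: same resolution, reading the real environment variable.
pub fn kv_capacity_from_env(max_seq_len: usize) -> usize {
    let v = env::var("FERRUM_KV_CAPACITY").ok();
    resolve_kv_capacity(v.as_deref(), max_seq_len)
}

/// Caller-side pre-check: can `extra` tokens be appended to a sequence that
/// already holds `current_len` tokens without exceeding `capacity`?
pub fn fits(current_len: usize, extra: usize, capacity: usize) -> bool {
    current_len
        .checked_add(extra)
        .map_or(false, |total| total <= capacity)
}
```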


fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>>

Decode multiple concurrent requests in a single forward pass.

Each entry is the per-request state (cache_id, token, pos). Returns one [vocab_size] logits vec per request, in the SAME order as batch.

The default implementation calls decode sequentially for each entry. Backends that implement true batched decode (one GEMM with m = batch size, plus a per-item attention loop) override it to get a concurrency speedup.
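The default decode_batch can be sketched as a map over decode that preserves input order. The trait below is trimmed to the two relevant methods, and the body is an illustrative sketch rather than the crate’s exact source.

```rust
/// Trimmed mirror of the trait, with a sequential decode_batch default.
pub trait DecoderOnlyLLM {
    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32>;

    /// Sequential fallback: one decode per entry, results in input order.
    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
        batch
            .iter()
            .map(|(cache_id, token, pos)| self.decode(cache_id, *token, *pos))
            .collect()
    }
}
```

A backend with true batched decode would replace this body with a single forward pass over all entries; callers see the same Vec<Vec<f32>> either way.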


fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32>

Multi-position decode-verify: run a single forward pass over tokens starting at the current KV end, append their K/V in place, and return seq_len * vocab_size logits, where seq_len = tokens.len() (row-major, position-first).

Used by speculative decoding to collect the N+1 verification logit rows for N draft tokens in one target pass instead of N+1 sequential decodes.

The default falls back to a decode loop: slower but correct, and it spares minor backends from reimplementing the primitive.
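A sketch of the decode-loop fallback for forward_verify. The trait does not expose where the KV cache currently ends, so this sketch assumes a hypothetical kv_len accessor; the returned buffer is row-major and position-first, with row i holding the logits after tokens[i].

```rust
/// Trimmed mirror of the trait. kv_len is NOT part of the documented trait;
/// it is a hypothetical accessor this sketch needs to find the KV end.
pub trait DecoderOnlyLLM {
    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32>;

    /// Hypothetical: number of positions currently stored for cache_id.
    fn kv_len(&self, cache_id: &str) -> usize;

    /// Decode-loop fallback: feed each token at the next KV position and
    /// concatenate the rows into tokens.len() * vocab_size logits.
    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
        let start = self.kv_len(cache_id) as u32;
        let mut out = Vec::new();
        for (i, &tok) in tokens.iter().enumerate() {
            // Each decode appends one K/V entry, so positions advance by one.
            out.extend(self.decode(cache_id, tok, start + i as u32));
        }
        out
    }
}
```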


fn truncate_kv(&mut self, cache_id: &str, new_len: usize)

Truncate the KV cache for cache_id back to new_len positions. Used by speculative decoding on rejection to roll the draft/target KV back to the last accepted position before the next iteration.

The default implementation panics, so backends that don’t support rollback fail loudly; backends that support rollback override it.
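The sketch below shows how forward_verify output and truncate_kv fit together in one verification step. It assumes greedy acceptance (accept draft tokens while they match the target’s argmax; the acceptance rule is not specified here) and assumes the verified tokens were [last_accepted, drafts...], appended starting at KV position kv_len_before. verify_greedy and Verdict are hypothetical names.

```rust
/// Index of the largest logit.
fn argmax(row: &[f32]) -> u32 {
    row.iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i as u32)
        .expect("empty row")
}

/// Outcome of one greedy verification step.
#[derive(Debug, PartialEq)]
pub struct Verdict {
    pub accepted: usize,    // how many draft tokens matched
    pub next_token: u32,    // correction (on mismatch) or bonus token
    pub keep_kv_len: usize, // new_len to pass to truncate_kv
}

/// `logits` holds (drafts.len() + 1) rows of `vocab` logits, row i being
/// the target's prediction for the position that drafts[i] occupies.
pub fn verify_greedy(logits: &[f32], vocab: usize, drafts: &[u32], kv_len_before: usize) -> Verdict {
    assert_eq!(logits.len(), (drafts.len() + 1) * vocab);
    let rows: Vec<&[f32]> = logits.chunks(vocab).collect();
    let accepted = drafts
        .iter()
        .zip(&rows)
        .take_while(|(d, row)| argmax(row) == **d)
        .count();
    Verdict {
        accepted,
        next_token: argmax(rows[accepted]),
        // Keep last_accepted plus the accepted drafts; drop the rejected tail.
        keep_kv_len: kv_len_before + 1 + accepted,
    }
}
```

The caller would then pass verdict.keep_kv_len to truncate_kv on the target cache, discarding the K/V of rejected drafts, and continue from next_token.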


fn reset(&mut self)

Drop all cached state (useful for tests and hot-reload).

Implementors