Skip to main content

ferrum_interfaces/
decode_backend.rs

//! Decode backend abstraction.
//!
//! Backends for different hardware (CUDA, Metal, CPU/Candle) implement
//! `DecodeBackend` to provide optimized decode execution. The
//! `GenericDecodeExecutor` uses a `DecodeBackend` to run the decode hot path.
6
7use crate::tensor::TensorRef;
8use ferrum_types::Result;
9
10/// Decode-phase execution backend.
11///
12/// Implements the actual computation for single-token decode steps.
13/// Different backends optimize for different hardware:
14/// - `CudaDecodeBackend`: cuBLAS + custom CUDA kernels, pre-allocated buffers
15/// - `MetalDecodeBackend`: Metal compute shaders
16/// - `CandleDecodeBackend`: candle tensor ops (CPU/fallback)
17///
18/// The backend is initialized with model weights and manages its own
19/// internal state (KV cache, buffers, cuBLAS handles, etc.).
20pub trait DecodeBackend: Send + Sync {
21    /// Execute a single decode step: one token in, logits out.
22    ///
23    /// - `token_id`: the input token
24    /// - `position`: sequence position (for RoPE)
25    /// - `cache_key`: identifies the sequence's KV cache
26    ///
27    /// Returns logits as a TensorRef [1, 1, vocab_size].
28    fn decode_step(&mut self, token_id: u32, position: usize, cache_key: &str)
29        -> Result<TensorRef>;
30
31    /// Initialize KV cache for a new sequence from prefill data.
32    ///
33    /// Called after prefill (which runs through the model's forward pass)
34    /// to hand off the KV cache to the decode backend.
35    ///
36    /// `kv_data`: per-layer (K, V) tensor pairs from the prefill pass.
37    /// `prefill_len`: number of tokens in the prefill.
38    fn init_kv_cache(
39        &mut self,
40        cache_key: &str,
41        kv_data: Vec<(TensorRef, TensorRef)>,
42        prefill_len: usize,
43    ) -> Result<()>;
44
45    /// Check if KV cache exists for a sequence.
46    fn has_kv_cache(&self, cache_key: &str) -> bool;
47
48    /// Release KV cache for a completed sequence.
49    fn release_kv_cache(&mut self, cache_key: &str);
50
51    /// Human-readable backend name (for logging).
52    fn name(&self) -> &str;
53}