Skip to main content

ferrum_interfaces/
decode_backend.rs

1//! Decode backend abstraction.
2//!
3//! Different backends (CUDA, Metal, CPU/Candle) implement `DecodeBackend`
4//! to provide optimized decode execution. The `GenericDecodeExecutor` uses
5//! a DecodeBackend to execute the decode hot path.
6
7use crate::tensor::TensorRef;
8use crate::transformer::TransformerWeights;
9use ferrum_types::Result;
10
11/// Decode-phase execution backend.
12///
13/// Implements the actual computation for single-token decode steps.
14/// Different backends optimize for different hardware:
15/// - `CudaDecodeBackend`: cuBLAS + custom CUDA kernels, pre-allocated buffers
16/// - `MetalDecodeBackend`: Metal compute shaders
17/// - `CandleDecodeBackend`: candle tensor ops (CPU/fallback)
18///
19/// The backend is initialized with model weights and manages its own
20/// internal state (KV cache, buffers, cuBLAS handles, etc.).
21pub trait DecodeBackend: Send + Sync {
22    /// Execute a single decode step: one token in, logits out.
23    ///
24    /// - `token_id`: the input token
25    /// - `position`: sequence position (for RoPE)
26    /// - `cache_key`: identifies the sequence's KV cache
27    ///
28    /// Returns logits as a TensorRef [1, 1, vocab_size].
29    fn decode_step(&mut self, token_id: u32, position: usize, cache_key: &str)
30        -> Result<TensorRef>;
31
32    /// Initialize KV cache for a new sequence from prefill data.
33    ///
34    /// Called after prefill (which runs through the model's forward pass)
35    /// to hand off the KV cache to the decode backend.
36    ///
37    /// `kv_data`: per-layer (K, V) tensor pairs from the prefill pass.
38    /// `prefill_len`: number of tokens in the prefill.
39    fn init_kv_cache(
40        &mut self,
41        cache_key: &str,
42        kv_data: Vec<(TensorRef, TensorRef)>,
43        prefill_len: usize,
44    ) -> Result<()>;
45
46    /// Check if KV cache exists for a sequence.
47    fn has_kv_cache(&self, cache_key: &str) -> bool;
48
49    /// Release KV cache for a completed sequence.
50    fn release_kv_cache(&mut self, cache_key: &str);
51
52    /// Human-readable backend name (for logging).
53    fn name(&self) -> &str;
54}