ferrum_interfaces/decode_backend.rs
1//! Decode backend abstraction.
2//!
3//! Different backends (CUDA, Metal, CPU/Candle) implement `DecodeBackend`
4//! to provide optimized decode execution. The `GenericDecodeExecutor` uses
5//! a DecodeBackend to execute the decode hot path.
6
7use crate::tensor::TensorRef;
8use ferrum_types::Result;
9
10/// Decode-phase execution backend.
11///
12/// Implements the actual computation for single-token decode steps.
13/// Different backends optimize for different hardware:
14/// - `CudaDecodeBackend`: cuBLAS + custom CUDA kernels, pre-allocated buffers
15/// - `MetalDecodeBackend`: Metal compute shaders
16/// - `CandleDecodeBackend`: candle tensor ops (CPU/fallback)
17///
18/// The backend is initialized with model weights and manages its own
19/// internal state (KV cache, buffers, cuBLAS handles, etc.).
20pub trait DecodeBackend: Send + Sync {
21 /// Execute a single decode step: one token in, logits out.
22 ///
23 /// - `token_id`: the input token
24 /// - `position`: sequence position (for RoPE)
25 /// - `cache_key`: identifies the sequence's KV cache
26 ///
27 /// Returns logits as a TensorRef [1, 1, vocab_size].
28 fn decode_step(&mut self, token_id: u32, position: usize, cache_key: &str)
29 -> Result<TensorRef>;
30
31 /// Initialize KV cache for a new sequence from prefill data.
32 ///
33 /// Called after prefill (which runs through the model's forward pass)
34 /// to hand off the KV cache to the decode backend.
35 ///
36 /// `kv_data`: per-layer (K, V) tensor pairs from the prefill pass.
37 /// `prefill_len`: number of tokens in the prefill.
38 fn init_kv_cache(
39 &mut self,
40 cache_key: &str,
41 kv_data: Vec<(TensorRef, TensorRef)>,
42 prefill_len: usize,
43 ) -> Result<()>;
44
45 /// Check if KV cache exists for a sequence.
46 fn has_kv_cache(&self, cache_key: &str) -> bool;
47
48 /// Release KV cache for a completed sequence.
49 fn release_kv_cache(&mut self, cache_key: &str);
50
51 /// Human-readable backend name (for logging).
52 fn name(&self) -> &str;
53}