ferrum_interfaces/decode_backend.rs
//! Decode backend abstraction.
//!
//! Each backend (CUDA, Metal, CPU/Candle) implements `DecodeBackend` to
//! provide optimized decode execution. The `GenericDecodeExecutor` drives
//! the decode hot path through a `DecodeBackend`.
6
7use crate::tensor::TensorRef;
8use crate::transformer::TransformerWeights;
9use ferrum_types::Result;
10
11/// Decode-phase execution backend.
12///
13/// Implements the actual computation for single-token decode steps.
14/// Different backends optimize for different hardware:
15/// - `CudaDecodeBackend`: cuBLAS + custom CUDA kernels, pre-allocated buffers
16/// - `MetalDecodeBackend`: Metal compute shaders
17/// - `CandleDecodeBackend`: candle tensor ops (CPU/fallback)
18///
19/// The backend is initialized with model weights and manages its own
20/// internal state (KV cache, buffers, cuBLAS handles, etc.).
21pub trait DecodeBackend: Send + Sync {
22 /// Execute a single decode step: one token in, logits out.
23 ///
24 /// - `token_id`: the input token
25 /// - `position`: sequence position (for RoPE)
26 /// - `cache_key`: identifies the sequence's KV cache
27 ///
28 /// Returns logits as a TensorRef [1, 1, vocab_size].
29 fn decode_step(&mut self, token_id: u32, position: usize, cache_key: &str)
30 -> Result<TensorRef>;
31
32 /// Initialize KV cache for a new sequence from prefill data.
33 ///
34 /// Called after prefill (which runs through the model's forward pass)
35 /// to hand off the KV cache to the decode backend.
36 ///
37 /// `kv_data`: per-layer (K, V) tensor pairs from the prefill pass.
38 /// `prefill_len`: number of tokens in the prefill.
39 fn init_kv_cache(
40 &mut self,
41 cache_key: &str,
42 kv_data: Vec<(TensorRef, TensorRef)>,
43 prefill_len: usize,
44 ) -> Result<()>;
45
46 /// Check if KV cache exists for a sequence.
47 fn has_kv_cache(&self, cache_key: &str) -> bool;
48
49 /// Release KV cache for a completed sequence.
50 fn release_kv_cache(&mut self, cache_key: &str);
51
52 /// Human-readable backend name (for logging).
53 fn name(&self) -> &str;
54}