/// Size and shape parameters describing a decoder-only model at runtime.
///
/// NOTE(review): unless stated otherwise, the field descriptions below are
/// inferred from the field names — confirm against the code that fills them in.
#[derive(Debug, Clone)]
pub struct LlmRuntimeConfig {
    /// Width of the model's hidden state (presumably `d_model`).
    pub hidden_size: usize,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Number of key/value attention heads.
    pub num_kv_heads: usize,
    /// Dimension of a single attention head.
    pub head_dim: usize,
    /// Vocabulary size. The default `DecoderOnlyLLM::forward_verify` sizes its
    /// output buffer as `tokens.len() * vocab_size`.
    pub vocab_size: usize,
    /// Maximum sequence length. `DecoderOnlyLLM::kv_capacity` returns this by default.
    pub max_seq_len: usize,
}
/// Interface for a decoder-only language model whose per-sequence state is
/// addressed by a `cache_id` string (presumably one KV cache per id — confirm).
///
/// Implementors must provide [`config`](Self::config), [`prefill`](Self::prefill),
/// [`decode`](Self::decode) and [`release`](Self::release). Each remaining method
/// has one of three defaults:
/// - a no-op;
/// - a sequential fallback built on `decode`;
/// - an "unsupported" error or panic.
pub trait DecoderOnlyLLM: Send + Sync {
    /// Returns the model's runtime configuration.
    fn config(&self) -> &LlmRuntimeConfig;

    /// Hook called for `cache_id` before generating up to `max_tokens` tokens.
    ///
    /// The default implementation does nothing.
    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
        // Bind the arguments so the default body raises no unused-variable warnings.
        let _ = (cache_id, max_tokens);
    }

    /// Number of positions the model's KV storage can hold.
    ///
    /// Defaults to `self.config().max_seq_len`.
    fn kv_capacity(&self) -> usize {
        self.config().max_seq_len
    }

    /// Runs the prompt `tokens` through the model for `cache_id`.
    ///
    /// NOTE(review): the returned vector is presumably the logits of the last
    /// prompt position — confirm against implementors.
    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32>;

    /// Runs a single `token` at position `pos` for `cache_id`.
    ///
    /// NOTE(review): presumably returns `vocab_size` logits — confirm.
    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32>;

    /// Decodes one token for each `(cache_id, token, pos)` entry of `batch`.
    ///
    /// The default calls [`decode`](Self::decode) once per entry, in order, and
    /// returns the outputs in the same order. Models with native batching
    /// should override it.
    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
        batch
            .iter()
            .map(|(cid, tok, p)| self.decode(cid, *tok, *p))
            .collect()
    }

    /// Runs every token of `tokens` for `cache_id` and returns all per-token
    /// outputs concatenated in input order.
    ///
    /// The default decodes the tokens one at a time at positions
    /// `0, 1, …, tokens.len() - 1`.
    ///
    /// NOTE(review): the default always starts at position 0, whatever is
    /// already stored for `cache_id`. The trait offers no way to query the
    /// current cache length. This is presumably only correct for an empty
    /// cache, so models that verify tokens after an existing prefix should
    /// override this method.
    fn forward_verify(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
        let mut out = Vec::with_capacity(tokens.len() * self.config().vocab_size);
        let start_pos = 0u32;
        for (i, &tok) in tokens.iter().enumerate() {
            // `i as u32` would only truncate past u32::MAX tokens.
            out.extend_from_slice(&self.decode(cache_id, tok, start_pos + i as u32));
        }
        out
    }

    /// Processes a batch of heterogeneous work items in one call.
    ///
    /// NOTE(review): implementors define what each `(String, Vec<u32>, usize, bool)`
    /// item means (presumably cache id, tokens, start position and a flag), and
    /// what the per-item `Option` results mean. Confirm and document.
    ///
    /// # Errors
    ///
    /// The default implementation always returns
    /// `FerrumError::unsupported("unified_forward not implemented for this model")`.
    // The nested Result/Vec/Option/tuple types are the public signature itself;
    // a local alias would not simplify the contract for implementors.
    #[allow(clippy::type_complexity)]
    fn unified_forward(
        &mut self,
        _items: &[(String, Vec<u32>, usize, bool)],
    ) -> std::result::Result<Vec<Option<Vec<f32>>>, ferrum_types::FerrumError> {
        Err(ferrum_types::FerrumError::unsupported(
            "unified_forward not implemented for this model",
        ))
    }

    /// Releases whatever the model holds for `cache_id`.
    fn release(&mut self, cache_id: &str);

    /// Intended to cut the stored state for `cache_id` back to `new_len`
    /// positions. The exact semantics are defined by overriding models.
    ///
    /// # Panics
    ///
    /// The default implementation always panics. The message names the
    /// `cache_id` and `new_len` that reached it.
    fn truncate_kv(&mut self, cache_id: &str, new_len: usize) {
        // Explicit format arguments (not inline `{name}` captures) keep the
        // message formatted even under 2015/2018-edition `panic!` rules.
        panic!(
            "truncate_kv not implemented for this DecoderOnlyLLM (cache_id={:?}, new_len={})",
            cache_id, new_len
        );
    }

    /// Hook to reset the model. The default implementation does nothing.
    fn reset(&mut self) {}
}