use crate::tensor::MockTensor;
use async_trait::async_trait;
use ferrum_interfaces::{
model_executor::{
AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
},
KvCacheHandle, KvCacheManager, ModelExecutor,
};
use ferrum_kv::attention::paged_attention;
use ferrum_kv::managers::paged::{PagedKvCacheHandle, PagedKvCacheManager};
use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, RequestId, Result};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct PagedExecutorConfig {
pub vocab_size: usize,
pub num_layers: usize,
pub num_heads: usize,
pub num_kv_heads: usize,
pub head_dim: usize,
pub max_sequence_length: usize,
}
impl Default for PagedExecutorConfig {
fn default() -> Self {
Self {
vocab_size: 256,
num_layers: 2,
num_heads: 4,
num_kv_heads: 4,
head_dim: 8,
max_sequence_length: 512,
}
}
}
pub struct PagedAttentionExecutor {
config: PagedExecutorConfig,
info: ModelInfo,
kv_manager: Arc<PagedKvCacheManager>,
prefill_count: AtomicU64,
decode_count: AtomicU64,
}
impl PagedAttentionExecutor {
pub fn new(config: PagedExecutorConfig, kv_manager: Arc<PagedKvCacheManager>) -> Self {
let info = ModelInfo {
model_id: "paged-test-model".into(),
model_type: ModelType::Custom("paged-test".into()),
num_parameters: 0,
hidden_size: config.num_heads * config.head_dim,
num_layers: config.num_layers,
num_heads: config.num_heads,
num_kv_heads: config.num_kv_heads,
vocab_size: config.vocab_size,
max_sequence_length: config.max_sequence_length,
dtype: DataType::FP32,
device: Device::CPU,
version: Some("test-1.0".into()),
license: None,
metadata: HashMap::new(),
};
Self {
config,
info,
kv_manager,
prefill_count: AtomicU64::new(0),
decode_count: AtomicU64::new(0),
}
}
pub fn prefill_count(&self) -> u64 {
self.prefill_count.load(Ordering::Relaxed)
}
pub fn decode_count(&self) -> u64 {
self.decode_count.load(Ordering::Relaxed)
}
fn token_embedding(&self, token_id: u32) -> Vec<f32> {
let kv_size = self.config.num_kv_heads * self.config.head_dim;
let mut emb = vec![0.0f32; kv_size];
for i in 0..kv_size {
emb[i] = ((token_id as f32 + 1.0) * (i as f32 + 1.0)).sin();
}
emb
}
fn attention_to_logits(&self, attn_output: &[f32]) -> Vec<f32> {
let vocab_size = self.config.vocab_size;
let mut logits = vec![0.0f32; vocab_size];
for (i, &val) in attn_output.iter().enumerate() {
let vocab_idx = i % vocab_size;
logits[vocab_idx] += val;
}
logits
}
fn as_paged_handle(handle: &dyn KvCacheHandle) -> Result<&PagedKvCacheHandle> {
handle
.as_any()
.downcast_ref::<PagedKvCacheHandle>()
.ok_or_else(|| FerrumError::internal("Expected PagedKvCacheHandle"))
}
}
#[async_trait]
impl ModelExecutor for PagedAttentionExecutor {
fn info(&self) -> &ModelInfo {
&self.info
}
async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
self.prefill_count.fetch_add(1, Ordering::Relaxed);
let batch_size = input.batch_size();
let seq_len = input.sequence_length();
let vocab_size = self.config.vocab_size;
let num_heads = self.config.num_heads;
let num_kv_heads = self.config.num_kv_heads;
let head_dim = self.config.head_dim;
let token_ids = input.input_ids.to_vec_u32()?;
let kv_handle = match &input.kv_cache {
Some(handle) => handle.clone(),
None => {
let alloc_request = ferrum_interfaces::kv_cache::AllocationRequest {
request_id: RequestId::new(),
initial_tokens: seq_len,
max_sequence_length: self.config.max_sequence_length,
num_layers: self.config.num_layers,
num_heads: num_kv_heads,
head_dim,
device: Device::CPU,
dtype: DataType::FP32,
priority: ferrum_types::Priority::Normal,
};
self.kv_manager.allocate(&alloc_request).await?
}
};
let paged_handle = Self::as_paged_handle(kv_handle.as_ref())?;
for pos in 0..seq_len {
let token_id = if pos < token_ids.len() {
token_ids[pos]
} else {
0
};
let embedding = self.token_embedding(token_id);
for layer in 0..self.config.num_layers {
self.kv_manager
.write_kv(paged_handle, layer, pos, &embedding, &embedding)?;
}
}
let last_layer = self.config.num_layers - 1;
let mut query = Vec::with_capacity(seq_len * num_heads * head_dim);
for pos in 0..seq_len {
let token_id = if pos < token_ids.len() {
token_ids[pos]
} else {
0
};
let emb = self.token_embedding(token_id);
let heads_per_kv = num_heads / num_kv_heads;
for kv_h in 0..num_kv_heads {
for _ in 0..heads_per_kv {
query.extend_from_slice(&emb[kv_h * head_dim..(kv_h + 1) * head_dim]);
}
}
}
let attn_output = paged_attention(
&query,
seq_len,
num_heads,
num_kv_heads,
head_dim,
&self.kv_manager,
paged_handle,
last_layer,
seq_len,
)?;
let q_head_size = num_heads * head_dim;
let mut logits_data = Vec::with_capacity(batch_size * seq_len * vocab_size);
for _b in 0..batch_size {
for s in 0..seq_len {
let attn_slice = &attn_output[s * q_head_size..(s + 1) * q_head_size];
let token_logits = self.attention_to_logits(attn_slice);
logits_data.extend_from_slice(&token_logits);
}
}
let logits =
MockTensor::from_f32(logits_data, &[batch_size, seq_len, vocab_size]).into_ref();
Ok(PrefillOutput::new(logits, kv_handle))
}
async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
self.decode_count.fetch_add(1, Ordering::Relaxed);
let batch_size = input.batch_size();
let vocab_size = self.config.vocab_size;
let num_heads = self.config.num_heads;
let num_kv_heads = self.config.num_kv_heads;
let head_dim = self.config.head_dim;
let paged_handle = Self::as_paged_handle(input.kv_cache.as_ref())?;
let position = paged_handle.num_tokens();
let blocks_needed = (position + 1).div_ceil(self.kv_manager.gpu_pool().block_size());
let current_blocks = paged_handle.num_blocks();
if blocks_needed > current_blocks {
self.kv_manager
.allocate_blocks(paged_handle, blocks_needed - current_blocks)?;
}
let token_ids = input.input_ids.to_vec_u32()?;
let token_id = token_ids.first().copied().unwrap_or(0);
let embedding = self.token_embedding(token_id);
for layer in 0..self.config.num_layers {
self.kv_manager
.write_kv(paged_handle, layer, position, &embedding, &embedding)?;
}
paged_handle.set_num_tokens(position + 1);
let last_layer = self.config.num_layers - 1;
let mut query = Vec::with_capacity(num_heads * head_dim);
let heads_per_kv = num_heads / num_kv_heads;
for kv_h in 0..num_kv_heads {
for _ in 0..heads_per_kv {
query.extend_from_slice(&embedding[kv_h * head_dim..(kv_h + 1) * head_dim]);
}
}
let attn_output = paged_attention(
&query,
1,
num_heads,
num_kv_heads,
head_dim,
&self.kv_manager,
paged_handle,
last_layer,
position + 1,
)?;
let mut logits_data = Vec::with_capacity(batch_size * vocab_size);
for _b in 0..batch_size {
let token_logits = self.attention_to_logits(&attn_output);
logits_data.extend_from_slice(&token_logits);
}
let logits = MockTensor::from_f32(logits_data, &[batch_size, vocab_size]).into_ref();
Ok(DecodeOutput::new(logits, input.kv_cache.clone()))
}
fn capabilities(&self) -> ExecutorCapabilities {
ExecutorCapabilities {
max_batch_size: 64,
max_sequence_length: self.config.max_sequence_length,
attention_mechanisms: vec![AttentionType::Paged],
supports_dynamic_batching: true,
supports_continuous_batching: true,
supports_speculative_decoding: false,
supports_tensor_parallelism: false,
supports_pipeline_parallelism: false,
supported_dtypes: vec![DataType::FP32],
supported_devices: vec![Device::CPU],
memory_requirements: MemoryRequirements {
parameter_memory: 0,
activation_memory_per_token: 0,
kv_cache_memory_per_token: (self.config.num_kv_heads
* self.config.head_dim
* 2
* self.config.num_layers
* 4) as u64 as usize,
overhead_memory: 0,
},
}
}
fn status(&self) -> ExecutorStatus {
ExecutorStatus {
state: ExecutorState::Ready,
is_ready: true,
current_batch_size: 0,
prefill_operations: self.prefill_count(),
decode_operations: self.decode_count(),
avg_prefill_time_ms: 0.0,
avg_decode_time_ms: 0.0,
memory_usage: ExecutorMemoryUsage {
allocated_bytes: 0,
used_bytes: 0,
peak_bytes: 0,
utilization_percent: 0.0,
},
last_operation: None,
}
}
}