1use crate::tensor::MockTensor;
12use async_trait::async_trait;
13use ferrum_interfaces::{
14 model_executor::{
15 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
16 ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
17 },
18 KvCacheHandle, KvCacheManager, ModelExecutor,
19};
20use ferrum_kv::attention::paged_attention;
21use ferrum_kv::managers::paged::{PagedKvCacheHandle, PagedKvCacheManager};
22use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, RequestId, Result};
23use std::collections::HashMap;
24use std::sync::atomic::{AtomicU64, Ordering};
25use std::sync::Arc;
26
/// Model dimensions for [`PagedAttentionExecutor`].
///
/// The executor derives its reported `ModelInfo`, its KV-cache allocation
/// request and its output tensor shapes from these values. `num_heads` is
/// expected to be a positive multiple of `num_kv_heads`: query heads are mapped
/// onto KV heads in equal-sized groups.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PagedExecutorConfig {
    /// Number of vocabulary entries; the last dimension of the produced logits.
    pub vocab_size: usize,
    /// Number of layers that receive KV-cache writes.
    pub num_layers: usize,
    /// Number of query attention heads.
    pub num_heads: usize,
    /// Number of key/value heads stored in the paged cache.
    pub num_kv_heads: usize,
    /// Number of elements per attention head.
    pub head_dim: usize,
    /// Maximum sequence length advertised and requested when allocating a cache.
    pub max_sequence_length: usize,
}

impl Default for PagedExecutorConfig {
    /// A deliberately tiny configuration: 256-token vocabulary, 2 layers,
    /// 4 query heads sharing 4 KV heads of dimension 8, and a 512-token limit.
    fn default() -> Self {
        Self {
            vocab_size: 256,
            num_layers: 2,
            num_heads: 4,
            num_kv_heads: 4,
            head_dim: 8,
            max_sequence_length: 512,
        }
    }
}
50
51pub struct PagedAttentionExecutor {
62 config: PagedExecutorConfig,
63 info: ModelInfo,
64 kv_manager: Arc<PagedKvCacheManager>,
66 prefill_count: AtomicU64,
67 decode_count: AtomicU64,
68}
69
70impl PagedAttentionExecutor {
71 pub fn new(config: PagedExecutorConfig, kv_manager: Arc<PagedKvCacheManager>) -> Self {
72 let info = ModelInfo {
73 model_id: "paged-test-model".into(),
74 model_type: ModelType::Custom("paged-test".into()),
75 num_parameters: 0,
76 hidden_size: config.num_heads * config.head_dim,
77 num_layers: config.num_layers,
78 num_heads: config.num_heads,
79 num_kv_heads: config.num_kv_heads,
80 vocab_size: config.vocab_size,
81 max_sequence_length: config.max_sequence_length,
82 dtype: DataType::FP32,
83 device: Device::CPU,
84 version: Some("test-1.0".into()),
85 license: None,
86 metadata: HashMap::new(),
87 };
88
89 Self {
90 config,
91 info,
92 kv_manager,
93 prefill_count: AtomicU64::new(0),
94 decode_count: AtomicU64::new(0),
95 }
96 }
97
98 pub fn prefill_count(&self) -> u64 {
99 self.prefill_count.load(Ordering::Relaxed)
100 }
101
102 pub fn decode_count(&self) -> u64 {
103 self.decode_count.load(Ordering::Relaxed)
104 }
105
106 fn token_embedding(&self, token_id: u32) -> Vec<f32> {
111 let kv_size = self.config.num_kv_heads * self.config.head_dim;
112 let mut emb = vec![0.0f32; kv_size];
113 for i in 0..kv_size {
115 emb[i] = ((token_id as f32 + 1.0) * (i as f32 + 1.0)).sin();
116 }
117 emb
118 }
119
120 fn attention_to_logits(&self, attn_output: &[f32]) -> Vec<f32> {
125 let vocab_size = self.config.vocab_size;
126 let mut logits = vec![0.0f32; vocab_size];
127 for (i, &val) in attn_output.iter().enumerate() {
128 let vocab_idx = i % vocab_size;
129 logits[vocab_idx] += val;
130 }
131 logits
132 }
133
134 fn as_paged_handle<'a>(handle: &'a dyn KvCacheHandle) -> Result<&'a PagedKvCacheHandle> {
136 handle
137 .as_any()
138 .downcast_ref::<PagedKvCacheHandle>()
139 .ok_or_else(|| FerrumError::internal("Expected PagedKvCacheHandle"))
140 }
141}
142
143#[async_trait]
144impl ModelExecutor for PagedAttentionExecutor {
145 fn info(&self) -> &ModelInfo {
146 &self.info
147 }
148
149 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
150 self.prefill_count.fetch_add(1, Ordering::Relaxed);
151
152 let batch_size = input.batch_size();
153 let seq_len = input.sequence_length();
154 let vocab_size = self.config.vocab_size;
155 let num_heads = self.config.num_heads;
156 let num_kv_heads = self.config.num_kv_heads;
157 let head_dim = self.config.head_dim;
158
159 let token_ids = input.input_ids.to_vec_u32()?;
161
162 let kv_handle = match &input.kv_cache {
164 Some(handle) => handle.clone(),
165 None => {
166 let alloc_request = ferrum_interfaces::kv_cache::AllocationRequest {
167 request_id: RequestId::new(),
168 initial_tokens: seq_len,
169 max_sequence_length: self.config.max_sequence_length,
170 num_layers: self.config.num_layers,
171 num_heads: num_kv_heads,
172 head_dim,
173 device: Device::CPU,
174 dtype: DataType::FP32,
175 priority: ferrum_types::Priority::Normal,
176 };
177 self.kv_manager.allocate(&alloc_request).await?
178 }
179 };
180 let paged_handle = Self::as_paged_handle(kv_handle.as_ref())?;
181
182 for pos in 0..seq_len {
184 let token_id = if pos < token_ids.len() {
185 token_ids[pos]
186 } else {
187 0
188 };
189 let embedding = self.token_embedding(token_id);
190
191 for layer in 0..self.config.num_layers {
193 self.kv_manager
194 .write_kv(paged_handle, layer, pos, &embedding, &embedding)?;
195 }
196 }
197
198 let last_layer = self.config.num_layers - 1;
201
202 let mut query = Vec::with_capacity(seq_len * num_heads * head_dim);
204 for pos in 0..seq_len {
205 let token_id = if pos < token_ids.len() {
206 token_ids[pos]
207 } else {
208 0
209 };
210 let emb = self.token_embedding(token_id);
211 let heads_per_kv = num_heads / num_kv_heads;
213 for kv_h in 0..num_kv_heads {
214 for _ in 0..heads_per_kv {
215 query.extend_from_slice(&emb[kv_h * head_dim..(kv_h + 1) * head_dim]);
216 }
217 }
218 }
219
220 let attn_output = paged_attention(
221 &query,
222 seq_len,
223 num_heads,
224 num_kv_heads,
225 head_dim,
226 &self.kv_manager,
227 paged_handle,
228 last_layer,
229 seq_len,
230 )?;
231
232 let q_head_size = num_heads * head_dim;
234 let mut logits_data = Vec::with_capacity(batch_size * seq_len * vocab_size);
235 for _b in 0..batch_size {
236 for s in 0..seq_len {
237 let attn_slice = &attn_output[s * q_head_size..(s + 1) * q_head_size];
238 let token_logits = self.attention_to_logits(attn_slice);
239 logits_data.extend_from_slice(&token_logits);
240 }
241 }
242
243 let logits =
244 MockTensor::from_f32(logits_data, &[batch_size, seq_len, vocab_size]).into_ref();
245
246 Ok(PrefillOutput::new(logits, kv_handle))
247 }
248
249 async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
250 self.decode_count.fetch_add(1, Ordering::Relaxed);
251
252 let batch_size = input.batch_size();
253 let vocab_size = self.config.vocab_size;
254 let num_heads = self.config.num_heads;
255 let num_kv_heads = self.config.num_kv_heads;
256 let head_dim = self.config.head_dim;
257
258 let paged_handle = Self::as_paged_handle(input.kv_cache.as_ref())?;
259
260 let position = paged_handle.num_tokens();
262
263 let blocks_needed = (position + 1 + self.kv_manager.gpu_pool().block_size() - 1)
265 / self.kv_manager.gpu_pool().block_size();
266 let current_blocks = paged_handle.num_blocks();
267 if blocks_needed > current_blocks {
268 self.kv_manager
269 .allocate_blocks(paged_handle, blocks_needed - current_blocks)?;
270 }
271
272 let token_ids = input.input_ids.to_vec_u32()?;
274 let token_id = token_ids.first().copied().unwrap_or(0);
275
276 let embedding = self.token_embedding(token_id);
277
278 for layer in 0..self.config.num_layers {
280 self.kv_manager
281 .write_kv(paged_handle, layer, position, &embedding, &embedding)?;
282 }
283
284 paged_handle.set_num_tokens(position + 1);
286
287 let last_layer = self.config.num_layers - 1;
289
290 let mut query = Vec::with_capacity(num_heads * head_dim);
292 let heads_per_kv = num_heads / num_kv_heads;
293 for kv_h in 0..num_kv_heads {
294 for _ in 0..heads_per_kv {
295 query.extend_from_slice(&embedding[kv_h * head_dim..(kv_h + 1) * head_dim]);
296 }
297 }
298
299 let attn_output = paged_attention(
300 &query,
301 1,
302 num_heads,
303 num_kv_heads,
304 head_dim,
305 &self.kv_manager,
306 paged_handle,
307 last_layer,
308 position + 1,
309 )?;
310
311 let mut logits_data = Vec::with_capacity(batch_size * vocab_size);
313 for _b in 0..batch_size {
314 let token_logits = self.attention_to_logits(&attn_output);
315 logits_data.extend_from_slice(&token_logits);
316 }
317
318 let logits = MockTensor::from_f32(logits_data, &[batch_size, vocab_size]).into_ref();
319
320 Ok(DecodeOutput::new(logits, input.kv_cache.clone()))
321 }
322
323 fn capabilities(&self) -> ExecutorCapabilities {
324 ExecutorCapabilities {
325 max_batch_size: 64,
326 max_sequence_length: self.config.max_sequence_length,
327 attention_mechanisms: vec![AttentionType::Paged],
328 supports_dynamic_batching: true,
329 supports_continuous_batching: true,
330 supports_speculative_decoding: false,
331 supports_tensor_parallelism: false,
332 supports_pipeline_parallelism: false,
333 supported_dtypes: vec![DataType::FP32],
334 supported_devices: vec![Device::CPU],
335 memory_requirements: MemoryRequirements {
336 parameter_memory: 0,
337 activation_memory_per_token: 0,
338 kv_cache_memory_per_token: (self.config.num_kv_heads
339 * self.config.head_dim
340 * 2
341 * self.config.num_layers
342 * 4) as u64 as usize,
343 overhead_memory: 0,
344 },
345 }
346 }
347
348 fn status(&self) -> ExecutorStatus {
349 ExecutorStatus {
350 state: ExecutorState::Ready,
351 is_ready: true,
352 current_batch_size: 0,
353 prefill_operations: self.prefill_count(),
354 decode_operations: self.decode_count(),
355 avg_prefill_time_ms: 0.0,
356 avg_decode_time_ms: 0.0,
357 memory_usage: ExecutorMemoryUsage {
358 allocated_bytes: 0,
359 used_bytes: 0,
360 peak_bytes: 0,
361 utilization_percent: 0.0,
362 },
363 last_operation: None,
364 }
365 }
366}