Skip to main content

ferrum_testkit/
paged_executor.rs

1//! Model executor that uses PagedAttention KV cache.
2//!
3//! Unlike MockModelExecutor (which ignores KV cache), this executor:
4//! - Writes K/V vectors to paged blocks during prefill and decode
5//! - Reads K/V through block table indirection for attention
6//! - Produces logits via the paged attention output
7//!
8//! Uses identity projections (Q=K=V=input embedding) for deterministic,
9//! verifiable behavior without model weights.
10
11use crate::tensor::MockTensor;
12use async_trait::async_trait;
13use ferrum_interfaces::{
14    model_executor::{
15        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
16        ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
17    },
18    KvCacheHandle, KvCacheManager, ModelExecutor,
19};
20use ferrum_kv::attention::paged_attention;
21use ferrum_kv::managers::paged::{PagedKvCacheHandle, PagedKvCacheManager};
22use ferrum_types::{DataType, Device, FerrumError, ModelInfo, ModelType, RequestId, Result};
23use std::collections::HashMap;
24use std::sync::atomic::{AtomicU64, Ordering};
25use std::sync::Arc;
26
/// Configuration for the paged attention executor.
///
/// These values describe the synthetic model's geometry; no weights are
/// involved. They are copied into the executor's `ModelInfo` and drive the
/// sizes of the embedding, query, and logits buffers.
#[derive(Debug, Clone)]
pub struct PagedExecutorConfig {
    /// Length of the logits vector produced for every token.
    pub vocab_size: usize,
    /// Number of layers that K/V is written for; attention output is taken
    /// from the last one (`num_layers - 1`).
    pub num_layers: usize,
    /// Number of query heads passed to paged attention.
    pub num_heads: usize,
    /// Number of key/value heads; each one is repeated
    /// `num_heads / num_kv_heads` times when building the query.
    pub num_kv_heads: usize,
    /// Number of elements per head.
    pub head_dim: usize,
    /// Maximum sequence length reported in `ModelInfo` and the executor
    /// capabilities, and requested when this executor allocates its own cache.
    pub max_sequence_length: usize,
}
37
38impl Default for PagedExecutorConfig {
39    fn default() -> Self {
40        Self {
41            vocab_size: 256,
42            num_layers: 2,
43            num_heads: 4,
44            num_kv_heads: 4,
45            head_dim: 8,
46            max_sequence_length: 512,
47        }
48    }
49}
50
/// A model executor that actually uses paged KV cache for attention.
///
/// Uses identity projections: for each token, the embedding is a dense,
/// deterministic vector of length `num_kv_heads * head_dim` whose elements
/// are `sin((token_id + 1) * (i + 1))`.  K = V = embedding, and Q is the same
/// embedding with every KV head repeated `num_heads / num_kv_heads` times.
/// This makes attention outputs deterministic and verifiable.
///
/// Logits are produced by folding the last layer's attention output into
/// `vocab_size` buckets (element `i` is added to bucket `i % vocab_size`),
/// so different attention patterns produce different token predictions.
pub struct PagedAttentionExecutor {
    /// Geometry (heads, layers, vocabulary size) used for every call.
    config: PagedExecutorConfig,
    /// Synthetic model description returned by `info()`.
    info: ModelInfo,
    /// Shared with the engine's KV cache manager.
    kv_manager: Arc<PagedKvCacheManager>,
    /// Number of `prefill` calls made so far.
    prefill_count: AtomicU64,
    /// Number of `decode` calls made so far.
    decode_count: AtomicU64,
}
69
70impl PagedAttentionExecutor {
71    pub fn new(config: PagedExecutorConfig, kv_manager: Arc<PagedKvCacheManager>) -> Self {
72        let info = ModelInfo {
73            model_id: "paged-test-model".into(),
74            model_type: ModelType::Custom("paged-test".into()),
75            num_parameters: 0,
76            hidden_size: config.num_heads * config.head_dim,
77            num_layers: config.num_layers,
78            num_heads: config.num_heads,
79            num_kv_heads: config.num_kv_heads,
80            vocab_size: config.vocab_size,
81            max_sequence_length: config.max_sequence_length,
82            dtype: DataType::FP32,
83            device: Device::CPU,
84            version: Some("test-1.0".into()),
85            license: None,
86            metadata: HashMap::new(),
87        };
88
89        Self {
90            config,
91            info,
92            kv_manager,
93            prefill_count: AtomicU64::new(0),
94            decode_count: AtomicU64::new(0),
95        }
96    }
97
98    pub fn prefill_count(&self) -> u64 {
99        self.prefill_count.load(Ordering::Relaxed)
100    }
101
102    pub fn decode_count(&self) -> u64 {
103        self.decode_count.load(Ordering::Relaxed)
104    }
105
106    /// Create an embedding vector from a token ID.
107    ///
108    /// Returns a vector of length `num_kv_heads * head_dim` where each
109    /// element is derived from the token ID for deterministic behavior.
110    fn token_embedding(&self, token_id: u32) -> Vec<f32> {
111        let kv_size = self.config.num_kv_heads * self.config.head_dim;
112        let mut emb = vec![0.0f32; kv_size];
113        // Spread the token value across dimensions to create distinct patterns
114        for i in 0..kv_size {
115            emb[i] = ((token_id as f32 + 1.0) * (i as f32 + 1.0)).sin();
116        }
117        emb
118    }
119
120    /// Convert attention output to logits [vocab_size].
121    ///
122    /// Simple linear projection: each attention output dimension contributes
123    /// to vocab positions via modular mapping.
124    fn attention_to_logits(&self, attn_output: &[f32]) -> Vec<f32> {
125        let vocab_size = self.config.vocab_size;
126        let mut logits = vec![0.0f32; vocab_size];
127        for (i, &val) in attn_output.iter().enumerate() {
128            let vocab_idx = i % vocab_size;
129            logits[vocab_idx] += val;
130        }
131        logits
132    }
133
134    /// Get the paged handle from an Arc<dyn KvCacheHandle>.
135    fn as_paged_handle<'a>(handle: &'a dyn KvCacheHandle) -> Result<&'a PagedKvCacheHandle> {
136        handle
137            .as_any()
138            .downcast_ref::<PagedKvCacheHandle>()
139            .ok_or_else(|| FerrumError::internal("Expected PagedKvCacheHandle"))
140    }
141}
142
143#[async_trait]
144impl ModelExecutor for PagedAttentionExecutor {
145    fn info(&self) -> &ModelInfo {
146        &self.info
147    }
148
149    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
150        self.prefill_count.fetch_add(1, Ordering::Relaxed);
151
152        let batch_size = input.batch_size();
153        let seq_len = input.sequence_length();
154        let vocab_size = self.config.vocab_size;
155        let num_heads = self.config.num_heads;
156        let num_kv_heads = self.config.num_kv_heads;
157        let head_dim = self.config.head_dim;
158
159        // Extract token IDs from input tensor
160        let token_ids = input.input_ids.to_vec_u32()?;
161
162        // Use pre-allocated KV cache from engine, or allocate if not provided
163        let kv_handle = match &input.kv_cache {
164            Some(handle) => handle.clone(),
165            None => {
166                let alloc_request = ferrum_interfaces::kv_cache::AllocationRequest {
167                    request_id: RequestId::new(),
168                    initial_tokens: seq_len,
169                    max_sequence_length: self.config.max_sequence_length,
170                    num_layers: self.config.num_layers,
171                    num_heads: num_kv_heads,
172                    head_dim,
173                    device: Device::CPU,
174                    dtype: DataType::FP32,
175                    priority: ferrum_types::Priority::Normal,
176                };
177                self.kv_manager.allocate(&alloc_request).await?
178            }
179        };
180        let paged_handle = Self::as_paged_handle(kv_handle.as_ref())?;
181
182        // For each token, compute embedding and write K/V to paged cache
183        for pos in 0..seq_len {
184            let token_id = if pos < token_ids.len() {
185                token_ids[pos]
186            } else {
187                0
188            };
189            let embedding = self.token_embedding(token_id);
190
191            // Identity projection: K = V = embedding
192            for layer in 0..self.config.num_layers {
193                self.kv_manager
194                    .write_kv(paged_handle, layer, pos, &embedding, &embedding)?;
195            }
196        }
197
198        // Run paged attention for all layers, accumulate output for last layer
199        // For logits, we only need the last layer's attention output
200        let last_layer = self.config.num_layers - 1;
201
202        // Build query for all positions (prefill): Q = embeddings
203        let mut query = Vec::with_capacity(seq_len * num_heads * head_dim);
204        for pos in 0..seq_len {
205            let token_id = if pos < token_ids.len() {
206                token_ids[pos]
207            } else {
208                0
209            };
210            let emb = self.token_embedding(token_id);
211            // Expand KV heads to query heads if num_heads > num_kv_heads (GQA)
212            let heads_per_kv = num_heads / num_kv_heads;
213            for kv_h in 0..num_kv_heads {
214                for _ in 0..heads_per_kv {
215                    query.extend_from_slice(&emb[kv_h * head_dim..(kv_h + 1) * head_dim]);
216                }
217            }
218        }
219
220        let attn_output = paged_attention(
221            &query,
222            seq_len,
223            num_heads,
224            num_kv_heads,
225            head_dim,
226            &self.kv_manager,
227            paged_handle,
228            last_layer,
229            seq_len,
230        )?;
231
232        // Build logits [batch, seq_len, vocab_size]
233        let q_head_size = num_heads * head_dim;
234        let mut logits_data = Vec::with_capacity(batch_size * seq_len * vocab_size);
235        for _b in 0..batch_size {
236            for s in 0..seq_len {
237                let attn_slice = &attn_output[s * q_head_size..(s + 1) * q_head_size];
238                let token_logits = self.attention_to_logits(attn_slice);
239                logits_data.extend_from_slice(&token_logits);
240            }
241        }
242
243        let logits =
244            MockTensor::from_f32(logits_data, &[batch_size, seq_len, vocab_size]).into_ref();
245
246        Ok(PrefillOutput::new(logits, kv_handle))
247    }
248
249    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
250        self.decode_count.fetch_add(1, Ordering::Relaxed);
251
252        let batch_size = input.batch_size();
253        let vocab_size = self.config.vocab_size;
254        let num_heads = self.config.num_heads;
255        let num_kv_heads = self.config.num_kv_heads;
256        let head_dim = self.config.head_dim;
257
258        let paged_handle = Self::as_paged_handle(input.kv_cache.as_ref())?;
259
260        // Current position = number of tokens already cached
261        let position = paged_handle.num_tokens();
262
263        // We may need to allocate an additional block
264        let blocks_needed = (position + 1 + self.kv_manager.gpu_pool().block_size() - 1)
265            / self.kv_manager.gpu_pool().block_size();
266        let current_blocks = paged_handle.num_blocks();
267        if blocks_needed > current_blocks {
268            self.kv_manager
269                .allocate_blocks(paged_handle, blocks_needed - current_blocks)?;
270        }
271
272        // Extract the new token ID
273        let token_ids = input.input_ids.to_vec_u32()?;
274        let token_id = token_ids.first().copied().unwrap_or(0);
275
276        let embedding = self.token_embedding(token_id);
277
278        // Write K/V for the new token at `position` in each layer
279        for layer in 0..self.config.num_layers {
280            self.kv_manager
281                .write_kv(paged_handle, layer, position, &embedding, &embedding)?;
282        }
283
284        // Update token count
285        paged_handle.set_num_tokens(position + 1);
286
287        // Run attention for the new token (decode: q_tokens=1, kv_len=position+1)
288        let last_layer = self.config.num_layers - 1;
289
290        // Build query for the single new token
291        let mut query = Vec::with_capacity(num_heads * head_dim);
292        let heads_per_kv = num_heads / num_kv_heads;
293        for kv_h in 0..num_kv_heads {
294            for _ in 0..heads_per_kv {
295                query.extend_from_slice(&embedding[kv_h * head_dim..(kv_h + 1) * head_dim]);
296            }
297        }
298
299        let attn_output = paged_attention(
300            &query,
301            1,
302            num_heads,
303            num_kv_heads,
304            head_dim,
305            &self.kv_manager,
306            paged_handle,
307            last_layer,
308            position + 1,
309        )?;
310
311        // Convert attention output to logits [batch, vocab_size]
312        let mut logits_data = Vec::with_capacity(batch_size * vocab_size);
313        for _b in 0..batch_size {
314            let token_logits = self.attention_to_logits(&attn_output);
315            logits_data.extend_from_slice(&token_logits);
316        }
317
318        let logits = MockTensor::from_f32(logits_data, &[batch_size, vocab_size]).into_ref();
319
320        Ok(DecodeOutput::new(logits, input.kv_cache.clone()))
321    }
322
323    fn capabilities(&self) -> ExecutorCapabilities {
324        ExecutorCapabilities {
325            max_batch_size: 64,
326            max_sequence_length: self.config.max_sequence_length,
327            attention_mechanisms: vec![AttentionType::Paged],
328            supports_dynamic_batching: true,
329            supports_continuous_batching: true,
330            supports_speculative_decoding: false,
331            supports_tensor_parallelism: false,
332            supports_pipeline_parallelism: false,
333            supported_dtypes: vec![DataType::FP32],
334            supported_devices: vec![Device::CPU],
335            memory_requirements: MemoryRequirements {
336                parameter_memory: 0,
337                activation_memory_per_token: 0,
338                kv_cache_memory_per_token: (self.config.num_kv_heads
339                    * self.config.head_dim
340                    * 2
341                    * self.config.num_layers
342                    * 4) as u64 as usize,
343                overhead_memory: 0,
344            },
345        }
346    }
347
348    fn status(&self) -> ExecutorStatus {
349        ExecutorStatus {
350            state: ExecutorState::Ready,
351            is_ready: true,
352            current_batch_size: 0,
353            prefill_operations: self.prefill_count(),
354            decode_operations: self.decode_count(),
355            avg_prefill_time_ms: 0.0,
356            avg_decode_time_ms: 0.0,
357            memory_usage: ExecutorMemoryUsage {
358                allocated_bytes: 0,
359                used_bytes: 0,
360                peak_bytes: 0,
361                utilization_percent: 0.0,
362            },
363            last_operation: None,
364        }
365    }
366}