Skip to main content

ferrum_testkit/
executor.rs

1//! Mock model executor with configurable latency for scheduling tests.
2
3use crate::kv_cache::MockKvCacheHandle;
4use crate::tensor::MockTensor;
5use async_trait::async_trait;
6use ferrum_interfaces::{
7    model_executor::{
8        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
9        ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
10    },
11    ModelExecutor,
12};
13use ferrum_types::{DataType, Device, ModelInfo, ModelType, RequestId, Result};
14use std::collections::HashMap;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
19/// Mock model executor that simulates prefill/decode with configurable latency.
20/// No model weights, no GPU — pure async simulation.
21pub struct MockModelExecutor {
22    info: ModelInfo,
23    prefill_latency: Duration,
24    decode_latency: Duration,
25    prefill_count: AtomicU64,
26    decode_count: AtomicU64,
27}
28
29impl MockModelExecutor {
30    pub fn new(vocab_size: usize, prefill_latency: Duration, decode_latency: Duration) -> Self {
31        let info = ModelInfo {
32            model_id: "mock-model".into(),
33            model_type: ModelType::Custom("mock".into()),
34            num_parameters: 1_000_000,
35            hidden_size: 768,
36            num_layers: 12,
37            num_heads: 12,
38            num_kv_heads: 12,
39            vocab_size,
40            max_sequence_length: 4096,
41            dtype: DataType::FP32,
42            device: Device::CPU,
43            version: Some("mock-1.0".into()),
44            license: None,
45            metadata: HashMap::new(),
46        };
47        Self {
48            info,
49            prefill_latency,
50            decode_latency,
51            prefill_count: AtomicU64::new(0),
52            decode_count: AtomicU64::new(0),
53        }
54    }
55
56    /// Create with zero latency (for fast unit tests).
57    pub fn instant(vocab_size: usize) -> Self {
58        Self::new(vocab_size, Duration::ZERO, Duration::ZERO)
59    }
60
61    pub fn prefill_count(&self) -> u64 {
62        self.prefill_count.load(Ordering::Relaxed)
63    }
64
65    pub fn decode_count(&self) -> u64 {
66        self.decode_count.load(Ordering::Relaxed)
67    }
68}
69
70#[async_trait]
71impl ModelExecutor for MockModelExecutor {
72    fn info(&self) -> &ModelInfo {
73        &self.info
74    }
75
76    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
77        if !self.prefill_latency.is_zero() {
78            tokio::time::sleep(self.prefill_latency).await;
79        }
80        self.prefill_count.fetch_add(1, Ordering::Relaxed);
81
82        let batch_size = input.batch_size();
83        let seq_len = input.sequence_length();
84        let vocab_size = self.info.vocab_size;
85
86        // Return synthetic logits [batch, seq_len, vocab_size]
87        // Put a slight bias toward token 42 so greedy sampling is deterministic
88        let mut logits_data = vec![0.0f32; batch_size * seq_len * vocab_size];
89        for b in 0..batch_size {
90            for s in 0..seq_len {
91                let offset = (b * seq_len + s) * vocab_size;
92                if offset + 42 < logits_data.len() {
93                    logits_data[offset + 42] = 1.0;
94                }
95            }
96        }
97        let logits =
98            MockTensor::from_f32(logits_data, &[batch_size, seq_len, vocab_size]).into_ref();
99
100        let kv_cache = Arc::new(MockKvCacheHandle::new(
101            RequestId::new(),
102            self.info.num_layers,
103            seq_len,
104        ));
105
106        Ok(PrefillOutput::new(logits, kv_cache))
107    }
108
109    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
110        if !self.decode_latency.is_zero() {
111            tokio::time::sleep(self.decode_latency).await;
112        }
113        self.decode_count.fetch_add(1, Ordering::Relaxed);
114
115        let batch_size = input.batch_size();
116        let vocab_size = self.info.vocab_size;
117
118        // Return logits [batch, vocab_size] with bias toward token 42
119        let mut logits_data = vec![0.0f32; batch_size * vocab_size];
120        for b in 0..batch_size {
121            let offset = b * vocab_size;
122            if offset + 42 < logits_data.len() {
123                logits_data[offset + 42] = 1.0;
124            }
125        }
126        let logits = MockTensor::from_f32(logits_data, &[batch_size, vocab_size]).into_ref();
127
128        Ok(DecodeOutput::new(logits, input.kv_cache.clone()))
129    }
130
131    fn capabilities(&self) -> ExecutorCapabilities {
132        ExecutorCapabilities {
133            max_batch_size: 256,
134            max_sequence_length: 4096,
135            attention_mechanisms: vec![AttentionType::MultiHead],
136            supports_dynamic_batching: true,
137            supports_continuous_batching: true,
138            supports_speculative_decoding: false,
139            supports_tensor_parallelism: false,
140            supports_pipeline_parallelism: false,
141            supported_dtypes: vec![DataType::FP32],
142            supported_devices: vec![Device::CPU],
143            memory_requirements: MemoryRequirements {
144                parameter_memory: 0,
145                activation_memory_per_token: 0,
146                kv_cache_memory_per_token: 0,
147                overhead_memory: 0,
148            },
149        }
150    }
151
152    fn status(&self) -> ExecutorStatus {
153        ExecutorStatus {
154            state: ExecutorState::Ready,
155            is_ready: true,
156            current_batch_size: 0,
157            prefill_operations: self.prefill_count.load(Ordering::Relaxed),
158            decode_operations: self.decode_count.load(Ordering::Relaxed),
159            avg_prefill_time_ms: self.prefill_latency.as_millis() as f64,
160            avg_decode_time_ms: self.decode_latency.as_millis() as f64,
161            memory_usage: ExecutorMemoryUsage {
162                allocated_bytes: 0,
163                used_bytes: 0,
164                peak_bytes: 0,
165                utilization_percent: 0.0,
166            },
167            last_operation: None,
168        }
169    }
170}