1use crate::kv_cache::MockKvCacheHandle;
4use crate::tensor::MockTensor;
5use async_trait::async_trait;
6use ferrum_interfaces::{
7 model_executor::{
8 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
9 ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
10 },
11 ModelExecutor,
12};
13use ferrum_types::{DataType, Device, ModelInfo, ModelType, RequestId, Result};
14use std::collections::HashMap;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
19pub struct MockModelExecutor {
22 info: ModelInfo,
23 prefill_latency: Duration,
24 decode_latency: Duration,
25 prefill_count: AtomicU64,
26 decode_count: AtomicU64,
27}
28
29impl MockModelExecutor {
30 pub fn new(vocab_size: usize, prefill_latency: Duration, decode_latency: Duration) -> Self {
31 let info = ModelInfo {
32 model_id: "mock-model".into(),
33 model_type: ModelType::Custom("mock".into()),
34 num_parameters: 1_000_000,
35 hidden_size: 768,
36 num_layers: 12,
37 num_heads: 12,
38 num_kv_heads: 12,
39 vocab_size,
40 max_sequence_length: 4096,
41 dtype: DataType::FP32,
42 device: Device::CPU,
43 version: Some("mock-1.0".into()),
44 license: None,
45 metadata: HashMap::new(),
46 };
47 Self {
48 info,
49 prefill_latency,
50 decode_latency,
51 prefill_count: AtomicU64::new(0),
52 decode_count: AtomicU64::new(0),
53 }
54 }
55
56 pub fn instant(vocab_size: usize) -> Self {
58 Self::new(vocab_size, Duration::ZERO, Duration::ZERO)
59 }
60
61 pub fn prefill_count(&self) -> u64 {
62 self.prefill_count.load(Ordering::Relaxed)
63 }
64
65 pub fn decode_count(&self) -> u64 {
66 self.decode_count.load(Ordering::Relaxed)
67 }
68}
69
70#[async_trait]
71impl ModelExecutor for MockModelExecutor {
72 fn info(&self) -> &ModelInfo {
73 &self.info
74 }
75
76 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
77 if !self.prefill_latency.is_zero() {
78 tokio::time::sleep(self.prefill_latency).await;
79 }
80 self.prefill_count.fetch_add(1, Ordering::Relaxed);
81
82 let batch_size = input.batch_size();
83 let seq_len = input.sequence_length();
84 let vocab_size = self.info.vocab_size;
85
86 let mut logits_data = vec![0.0f32; batch_size * seq_len * vocab_size];
89 for b in 0..batch_size {
90 for s in 0..seq_len {
91 let offset = (b * seq_len + s) * vocab_size;
92 if offset + 42 < logits_data.len() {
93 logits_data[offset + 42] = 1.0;
94 }
95 }
96 }
97 let logits =
98 MockTensor::from_f32(logits_data, &[batch_size, seq_len, vocab_size]).into_ref();
99
100 let kv_cache = Arc::new(MockKvCacheHandle::new(
101 RequestId::new(),
102 self.info.num_layers,
103 seq_len,
104 ));
105
106 Ok(PrefillOutput::new(logits, kv_cache))
107 }
108
109 async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
110 if !self.decode_latency.is_zero() {
111 tokio::time::sleep(self.decode_latency).await;
112 }
113 self.decode_count.fetch_add(1, Ordering::Relaxed);
114
115 let batch_size = input.batch_size();
116 let vocab_size = self.info.vocab_size;
117
118 let mut logits_data = vec![0.0f32; batch_size * vocab_size];
120 for b in 0..batch_size {
121 let offset = b * vocab_size;
122 if offset + 42 < logits_data.len() {
123 logits_data[offset + 42] = 1.0;
124 }
125 }
126 let logits = MockTensor::from_f32(logits_data, &[batch_size, vocab_size]).into_ref();
127
128 Ok(DecodeOutput::new(logits, input.kv_cache.clone()))
129 }
130
131 fn capabilities(&self) -> ExecutorCapabilities {
132 ExecutorCapabilities {
133 max_batch_size: 256,
134 max_sequence_length: 4096,
135 attention_mechanisms: vec![AttentionType::MultiHead],
136 supports_dynamic_batching: true,
137 supports_continuous_batching: true,
138 supports_speculative_decoding: false,
139 supports_tensor_parallelism: false,
140 supports_pipeline_parallelism: false,
141 supported_dtypes: vec![DataType::FP32],
142 supported_devices: vec![Device::CPU],
143 memory_requirements: MemoryRequirements {
144 parameter_memory: 0,
145 activation_memory_per_token: 0,
146 kv_cache_memory_per_token: 0,
147 overhead_memory: 0,
148 },
149 }
150 }
151
152 fn status(&self) -> ExecutorStatus {
153 ExecutorStatus {
154 state: ExecutorState::Ready,
155 is_ready: true,
156 current_batch_size: 0,
157 prefill_operations: self.prefill_count.load(Ordering::Relaxed),
158 decode_operations: self.decode_count.load(Ordering::Relaxed),
159 avg_prefill_time_ms: self.prefill_latency.as_millis() as f64,
160 avg_decode_time_ms: self.decode_latency.as_millis() as f64,
161 memory_usage: ExecutorMemoryUsage {
162 allocated_bytes: 0,
163 used_bytes: 0,
164 peak_bytes: 0,
165 utilization_percent: 0.0,
166 },
167 last_operation: None,
168 }
169 }
170}