1use async_trait::async_trait;
4use ferrum_interfaces::{
5 model_executor::{
6 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
7 ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
8 },
9 BlockTable, KvCacheHandle, ModelExecutor, TensorFactory,
10};
11use ferrum_types::{DataType, Device, ModelInfo, ModelType, Result};
12use std::sync::Arc;
13use tracing::debug;
14
15pub struct StubModelExecutor {
19 info: ModelInfo,
20 tensor_factory: Arc<dyn TensorFactory>,
21}
22
23impl StubModelExecutor {
24 pub fn new(
25 model_id: impl Into<ferrum_types::ModelId>,
26 vocab_size: usize,
27 tensor_factory: Arc<dyn TensorFactory>,
28 ) -> Self {
29 let info = ModelInfo {
30 model_id: model_id.into(),
31 model_type: ModelType::Custom("stub".into()),
32 num_parameters: 1_000_000,
33 hidden_size: 768,
34 num_layers: 12,
35 num_heads: 12,
36 num_kv_heads: 12,
37 vocab_size,
38 max_sequence_length: 2048,
39 dtype: DataType::FP16,
40 device: Device::CPU,
41 version: Some("mvp-stub".into()),
42 license: Some("Apache-2.0".into()),
43 metadata: std::collections::HashMap::new(),
44 };
45
46 debug!("Created StubModelExecutor: vocab={}", vocab_size);
47
48 Self {
49 info,
50 tensor_factory,
51 }
52 }
53}
54
55#[async_trait]
56impl ModelExecutor for StubModelExecutor {
57 fn info(&self) -> &ModelInfo {
58 &self.info
59 }
60
61 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
62 let batch_size = input.batch_size();
63 let seq_len = input.sequence_length();
64 let vocab_size = self.info.vocab_size;
65
66 debug!("Stub prefill: batch={}, seq_len={}", batch_size, seq_len);
67
68 let logits = self.tensor_factory.zeros(
70 &[batch_size, seq_len, vocab_size],
71 DataType::FP32,
72 &self.info.device,
73 )?;
74
75 let kv_cache = create_stub_kv_cache(
77 ferrum_types::RequestId::new(),
78 self.info.num_layers,
79 seq_len,
80 );
81
82 Ok(PrefillOutput::new(logits, kv_cache))
83 }
84
85 async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
86 let batch_size = input.batch_size();
87 let vocab_size = self.info.vocab_size;
88
89 debug!("Stub decode: batch={}", batch_size);
90
91 let logits = self.tensor_factory.zeros(
92 &[batch_size, vocab_size],
93 DataType::FP32,
94 &self.info.device,
95 )?;
96
97 Ok(DecodeOutput::new(logits, input.kv_cache.clone()))
98 }
99
100 fn capabilities(&self) -> ExecutorCapabilities {
101 ExecutorCapabilities {
102 max_batch_size: 32,
103 max_sequence_length: 2048,
104 attention_mechanisms: vec![AttentionType::MultiHead],
105 supports_dynamic_batching: false,
106 supports_continuous_batching: false,
107 supports_speculative_decoding: false,
108 supports_tensor_parallelism: false,
109 supports_pipeline_parallelism: false,
110 supported_dtypes: vec![DataType::FP32, DataType::FP16],
111 supported_devices: vec![Device::CPU],
112 memory_requirements: MemoryRequirements {
113 parameter_memory: 4 * 1024 * 1024, activation_memory_per_token: 1024,
115 kv_cache_memory_per_token: 512,
116 overhead_memory: 1024 * 1024,
117 },
118 }
119 }
120
121 fn status(&self) -> ExecutorStatus {
122 ExecutorStatus {
123 state: ExecutorState::Ready,
124 is_ready: true,
125 current_batch_size: 0,
126 prefill_operations: 0,
127 decode_operations: 0,
128 avg_prefill_time_ms: 0.0,
129 avg_decode_time_ms: 0.0,
130 memory_usage: ExecutorMemoryUsage {
131 allocated_bytes: 1024 * 1024,
132 used_bytes: 512 * 1024,
133 peak_bytes: 1024 * 1024,
134 utilization_percent: 50.0,
135 },
136 last_operation: None,
137 }
138 }
139}
140
141fn create_stub_kv_cache(
143 request_id: ferrum_types::RequestId,
144 num_layers: usize,
145 seq_len: usize,
146) -> Arc<dyn KvCacheHandle> {
147 #[derive(Debug)]
148 struct StubKvCache {
149 request_id: ferrum_types::RequestId,
150 block_table: BlockTable,
151 num_layers: usize,
152 }
153
154 impl KvCacheHandle for StubKvCache {
155 fn block_table(&self) -> &BlockTable {
156 &self.block_table
157 }
158
159 fn block_table_mut(&mut self) -> &mut BlockTable {
160 &mut self.block_table
161 }
162
163 fn as_any(&self) -> &dyn std::any::Any {
164 self
165 }
166
167 fn device(&self) -> Device {
168 Device::CPU
169 }
170
171 fn num_layers(&self) -> usize {
172 self.num_layers
173 }
174
175 fn num_heads(&self) -> usize {
176 12
177 }
178
179 fn head_dim(&self) -> usize {
180 64
181 }
182
183 fn key_cache(&self, _layer: usize) -> Result<Option<ferrum_interfaces::TensorRef>> {
184 Ok(None)
185 }
186
187 fn value_cache(&self, _layer: usize) -> Result<Option<ferrum_interfaces::TensorRef>> {
188 Ok(None)
189 }
190
191 fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
192 Err(ferrum_types::FerrumError::unsupported("Stub cache clone"))
193 }
194
195 fn stats(&self) -> ferrum_interfaces::kv_cache::CacheHandleStats {
196 ferrum_interfaces::kv_cache::CacheHandleStats {
197 memory_bytes: 1024,
198 blocks_allocated: 1,
199 tokens_stored: self.block_table.sequence_length,
200 utilization: 0.5,
201 last_access: std::time::Instant::now(),
202 }
203 }
204
205 fn is_valid(&self) -> bool {
206 true
207 }
208
209 fn cache_id(&self) -> String {
210 format!("stub_{}", self.request_id.to_string())
211 }
212 }
213
214 let mut block_table = BlockTable::new(16);
215 block_table.sequence_length = seq_len;
216
217 Arc::new(StubKvCache {
218 request_id,
219 block_table,
220 num_layers,
221 })
222}