1use async_trait::async_trait;
4use ferrum_interfaces::{
5 model_executor::{
6 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
7 ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
8 },
9 BlockTable, ComputeBackend, KvCacheHandle, ModelExecutor,
10};
11use ferrum_types::{DataType, Device, ModelInfo, ModelType, Result};
12use std::sync::Arc;
13use tracing::debug;
14
/// A no-op `ModelExecutor` that returns all-zero logits instead of running a
/// real model. Useful for wiring/latency testing of the serving stack without
/// loading weights.
pub struct StubModelExecutor {
    // Fixed, hard-coded model metadata (set once in `new`).
    info: ModelInfo,
    // Backend used only to allocate zero-filled logit tensors.
    compute_backend: Arc<dyn ComputeBackend>,
}
22
23impl StubModelExecutor {
24 pub fn new(
25 model_id: impl Into<ferrum_types::ModelId>,
26 vocab_size: usize,
27 compute_backend: Arc<dyn ComputeBackend>,
28 ) -> Self {
29 let info = ModelInfo {
30 model_id: model_id.into(),
31 model_type: ModelType::Custom("stub".into()),
32 num_parameters: 1_000_000,
33 hidden_size: 768,
34 num_layers: 12,
35 num_heads: 12,
36 num_kv_heads: 12,
37 vocab_size,
38 max_sequence_length: 2048,
39 dtype: DataType::FP16,
40 device: Device::CPU,
41 version: Some("mvp-stub".into()),
42 license: Some("Apache-2.0".into()),
43 metadata: std::collections::HashMap::new(),
44 };
45
46 debug!("Created StubModelExecutor: vocab={}", vocab_size);
47
48 Self {
49 info,
50 compute_backend,
51 }
52 }
53}
54
55#[async_trait]
56impl ModelExecutor for StubModelExecutor {
57 fn info(&self) -> &ModelInfo {
58 &self.info
59 }
60
61 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
62 let batch_size = input.batch_size();
63 let seq_len = input.sequence_length();
64 let vocab_size = self.info.vocab_size;
65
66 debug!("Stub prefill: batch={}, seq_len={}", batch_size, seq_len);
67
68 let factory = self.compute_backend.tensor_factory();
70 let logits = factory.zeros(
71 &[batch_size, seq_len, vocab_size],
72 DataType::FP32,
73 &self.info.device,
74 )?;
75
76 let kv_cache = create_stub_kv_cache(
78 ferrum_types::RequestId::new(),
79 self.info.num_layers,
80 seq_len,
81 );
82
83 Ok(PrefillOutput::new(logits, kv_cache))
84 }
85
86 async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
87 let batch_size = input.batch_size();
88 let vocab_size = self.info.vocab_size;
89
90 debug!("Stub decode: batch={}", batch_size);
91
92 let factory = self.compute_backend.tensor_factory();
93 let logits = factory.zeros(&[batch_size, vocab_size], DataType::FP32, &self.info.device)?;
94
95 Ok(DecodeOutput::new(logits, input.kv_cache.clone()))
96 }
97
98 fn capabilities(&self) -> ExecutorCapabilities {
99 ExecutorCapabilities {
100 max_batch_size: 32,
101 max_sequence_length: 2048,
102 attention_mechanisms: vec![AttentionType::MultiHead],
103 supports_dynamic_batching: false,
104 supports_continuous_batching: false,
105 supports_speculative_decoding: false,
106 supports_tensor_parallelism: false,
107 supports_pipeline_parallelism: false,
108 supported_dtypes: vec![DataType::FP32, DataType::FP16],
109 supported_devices: vec![Device::CPU],
110 memory_requirements: MemoryRequirements {
111 parameter_memory: 4 * 1024 * 1024, activation_memory_per_token: 1024,
113 kv_cache_memory_per_token: 512,
114 overhead_memory: 1024 * 1024,
115 },
116 }
117 }
118
119 fn status(&self) -> ExecutorStatus {
120 ExecutorStatus {
121 state: ExecutorState::Ready,
122 is_ready: true,
123 current_batch_size: 0,
124 prefill_operations: 0,
125 decode_operations: 0,
126 avg_prefill_time_ms: 0.0,
127 avg_decode_time_ms: 0.0,
128 memory_usage: ExecutorMemoryUsage {
129 allocated_bytes: 1024 * 1024,
130 used_bytes: 512 * 1024,
131 peak_bytes: 1024 * 1024,
132 utilization_percent: 50.0,
133 },
134 last_operation: None,
135 }
136 }
137}
138
139fn create_stub_kv_cache(
141 request_id: ferrum_types::RequestId,
142 num_layers: usize,
143 seq_len: usize,
144) -> Arc<dyn KvCacheHandle> {
145 #[derive(Debug)]
146 struct StubKvCache {
147 request_id: ferrum_types::RequestId,
148 block_table: BlockTable,
149 num_layers: usize,
150 }
151
152 impl KvCacheHandle for StubKvCache {
153 fn block_table(&self) -> &BlockTable {
154 &self.block_table
155 }
156
157 fn block_table_mut(&mut self) -> &mut BlockTable {
158 &mut self.block_table
159 }
160
161 fn as_any(&self) -> &dyn std::any::Any {
162 self
163 }
164
165 fn device(&self) -> Device {
166 Device::CPU
167 }
168
169 fn num_layers(&self) -> usize {
170 self.num_layers
171 }
172
173 fn num_heads(&self) -> usize {
174 12
175 }
176
177 fn head_dim(&self) -> usize {
178 64
179 }
180
181 fn key_cache(&self, _layer: usize) -> Result<Option<ferrum_interfaces::TensorRef>> {
182 Ok(None)
183 }
184
185 fn value_cache(&self, _layer: usize) -> Result<Option<ferrum_interfaces::TensorRef>> {
186 Ok(None)
187 }
188
189 fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
190 Err(ferrum_types::FerrumError::unsupported("Stub cache clone"))
191 }
192
193 fn stats(&self) -> ferrum_interfaces::kv_cache::CacheHandleStats {
194 ferrum_interfaces::kv_cache::CacheHandleStats {
195 memory_bytes: 1024,
196 blocks_allocated: 1,
197 tokens_stored: self.block_table.sequence_length,
198 utilization: 0.5,
199 last_access: std::time::Instant::now(),
200 }
201 }
202
203 fn is_valid(&self) -> bool {
204 true
205 }
206
207 fn cache_id(&self) -> String {
208 format!("stub_{}", self.request_id.to_string())
209 }
210 }
211
212 let mut block_table = BlockTable::new(16);
213 block_table.sequence_length = seq_len;
214
215 Arc::new(StubKvCache {
216 request_id,
217 block_table,
218 num_layers,
219 })
220}