Skip to main content

ferrum_models/executor/
stub_executor.rs

1//! Stub model executor for MVP testing and development
2
3use async_trait::async_trait;
4use ferrum_interfaces::{
5    model_executor::{
6        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
7        ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
8    },
9    BlockTable, KvCacheHandle, ModelExecutor, TensorFactory,
10};
11use ferrum_types::{DataType, Device, ModelInfo, ModelType, Result};
12use std::sync::Arc;
13use tracing::debug;
14
15/// Stub model executor - MVP implementation
16///
17/// Returns dummy outputs to allow pipeline testing without real models.
18pub struct StubModelExecutor {
19    info: ModelInfo,
20    tensor_factory: Arc<dyn TensorFactory>,
21}
22
23impl StubModelExecutor {
24    pub fn new(
25        model_id: impl Into<ferrum_types::ModelId>,
26        vocab_size: usize,
27        tensor_factory: Arc<dyn TensorFactory>,
28    ) -> Self {
29        let info = ModelInfo {
30            model_id: model_id.into(),
31            model_type: ModelType::Custom("stub".into()),
32            num_parameters: 1_000_000,
33            hidden_size: 768,
34            num_layers: 12,
35            num_heads: 12,
36            num_kv_heads: 12,
37            vocab_size,
38            max_sequence_length: 2048,
39            dtype: DataType::FP16,
40            device: Device::CPU,
41            version: Some("mvp-stub".into()),
42            license: Some("Apache-2.0".into()),
43            metadata: std::collections::HashMap::new(),
44        };
45
46        debug!("Created StubModelExecutor: vocab={}", vocab_size);
47
48        Self {
49            info,
50            tensor_factory,
51        }
52    }
53}
54
55#[async_trait]
56impl ModelExecutor for StubModelExecutor {
57    fn info(&self) -> &ModelInfo {
58        &self.info
59    }
60
61    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
62        let batch_size = input.batch_size();
63        let seq_len = input.sequence_length();
64        let vocab_size = self.info.vocab_size;
65
66        debug!("Stub prefill: batch={}, seq_len={}", batch_size, seq_len);
67
68        // Create dummy logits
69        let logits = self.tensor_factory.zeros(
70            &[batch_size, seq_len, vocab_size],
71            DataType::FP32,
72            &self.info.device,
73        )?;
74
75        // Create stub KV cache
76        let kv_cache = create_stub_kv_cache(
77            ferrum_types::RequestId::new(),
78            self.info.num_layers,
79            seq_len,
80        );
81
82        Ok(PrefillOutput::new(logits, kv_cache))
83    }
84
85    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
86        let batch_size = input.batch_size();
87        let vocab_size = self.info.vocab_size;
88
89        debug!("Stub decode: batch={}", batch_size);
90
91        let logits = self.tensor_factory.zeros(
92            &[batch_size, vocab_size],
93            DataType::FP32,
94            &self.info.device,
95        )?;
96
97        Ok(DecodeOutput::new(logits, input.kv_cache.clone()))
98    }
99
100    fn capabilities(&self) -> ExecutorCapabilities {
101        ExecutorCapabilities {
102            max_batch_size: 32,
103            max_sequence_length: 2048,
104            attention_mechanisms: vec![AttentionType::MultiHead],
105            supports_dynamic_batching: false,
106            supports_continuous_batching: false,
107            supports_speculative_decoding: false,
108            supports_tensor_parallelism: false,
109            supports_pipeline_parallelism: false,
110            supported_dtypes: vec![DataType::FP32, DataType::FP16],
111            supported_devices: vec![Device::CPU],
112            memory_requirements: MemoryRequirements {
113                parameter_memory: 4 * 1024 * 1024, // 4MB
114                activation_memory_per_token: 1024,
115                kv_cache_memory_per_token: 512,
116                overhead_memory: 1024 * 1024,
117            },
118        }
119    }
120
121    fn status(&self) -> ExecutorStatus {
122        ExecutorStatus {
123            state: ExecutorState::Ready,
124            is_ready: true,
125            current_batch_size: 0,
126            prefill_operations: 0,
127            decode_operations: 0,
128            avg_prefill_time_ms: 0.0,
129            avg_decode_time_ms: 0.0,
130            memory_usage: ExecutorMemoryUsage {
131                allocated_bytes: 1024 * 1024,
132                used_bytes: 512 * 1024,
133                peak_bytes: 1024 * 1024,
134                utilization_percent: 50.0,
135            },
136            last_operation: None,
137        }
138    }
139}
140
141/// Create stub KV cache handle
142fn create_stub_kv_cache(
143    request_id: ferrum_types::RequestId,
144    num_layers: usize,
145    seq_len: usize,
146) -> Arc<dyn KvCacheHandle> {
147    #[derive(Debug)]
148    struct StubKvCache {
149        request_id: ferrum_types::RequestId,
150        block_table: BlockTable,
151        num_layers: usize,
152    }
153
154    impl KvCacheHandle for StubKvCache {
155        fn block_table(&self) -> &BlockTable {
156            &self.block_table
157        }
158
159        fn block_table_mut(&mut self) -> &mut BlockTable {
160            &mut self.block_table
161        }
162
163        fn as_any(&self) -> &dyn std::any::Any {
164            self
165        }
166
167        fn device(&self) -> Device {
168            Device::CPU
169        }
170
171        fn num_layers(&self) -> usize {
172            self.num_layers
173        }
174
175        fn num_heads(&self) -> usize {
176            12
177        }
178
179        fn head_dim(&self) -> usize {
180            64
181        }
182
183        fn key_cache(&self, _layer: usize) -> Result<Option<ferrum_interfaces::TensorRef>> {
184            Ok(None)
185        }
186
187        fn value_cache(&self, _layer: usize) -> Result<Option<ferrum_interfaces::TensorRef>> {
188            Ok(None)
189        }
190
191        fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
192            Err(ferrum_types::FerrumError::unsupported("Stub cache clone"))
193        }
194
195        fn stats(&self) -> ferrum_interfaces::kv_cache::CacheHandleStats {
196            ferrum_interfaces::kv_cache::CacheHandleStats {
197                memory_bytes: 1024,
198                blocks_allocated: 1,
199                tokens_stored: self.block_table.sequence_length,
200                utilization: 0.5,
201                last_access: std::time::Instant::now(),
202            }
203        }
204
205        fn is_valid(&self) -> bool {
206            true
207        }
208
209        fn cache_id(&self) -> String {
210            format!("stub_{}", self.request_id.to_string())
211        }
212    }
213
214    let mut block_table = BlockTable::new(16);
215    block_table.sequence_length = seq_len;
216
217    Arc::new(StubKvCache {
218        request_id,
219        block_table,
220        num_layers,
221    })
222}