// ferrum_models/executor/stub_executor.rs
//! Stub model executor for MVP testing and development

use async_trait::async_trait;
use ferrum_interfaces::{
    model_executor::{
        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
        ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
    },
    BlockTable, ComputeBackend, KvCacheHandle, ModelExecutor,
};
use ferrum_types::{DataType, Device, ModelInfo, ModelType, Result};
use std::sync::Arc;
use tracing::debug;
15/// Stub model executor - MVP implementation
16///
17/// Returns dummy outputs to allow pipeline testing without real models.
18pub struct StubModelExecutor {
19    info: ModelInfo,
20    compute_backend: Arc<dyn ComputeBackend>,
21}
22
23impl StubModelExecutor {
24    pub fn new(
25        model_id: impl Into<ferrum_types::ModelId>,
26        vocab_size: usize,
27        compute_backend: Arc<dyn ComputeBackend>,
28    ) -> Self {
29        let info = ModelInfo {
30            model_id: model_id.into(),
31            model_type: ModelType::Custom("stub".into()),
32            num_parameters: 1_000_000,
33            hidden_size: 768,
34            num_layers: 12,
35            num_heads: 12,
36            num_kv_heads: 12,
37            vocab_size,
38            max_sequence_length: 2048,
39            dtype: DataType::FP16,
40            device: Device::CPU,
41            version: Some("mvp-stub".into()),
42            license: Some("Apache-2.0".into()),
43            metadata: std::collections::HashMap::new(),
44        };
45
46        debug!("Created StubModelExecutor: vocab={}", vocab_size);
47
48        Self {
49            info,
50            compute_backend,
51        }
52    }
53}
54
55#[async_trait]
56impl ModelExecutor for StubModelExecutor {
57    fn info(&self) -> &ModelInfo {
58        &self.info
59    }
60
61    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
62        let batch_size = input.batch_size();
63        let seq_len = input.sequence_length();
64        let vocab_size = self.info.vocab_size;
65
66        debug!("Stub prefill: batch={}, seq_len={}", batch_size, seq_len);
67
68        // Create dummy logits
69        let factory = self.compute_backend.tensor_factory();
70        let logits = factory.zeros(
71            &[batch_size, seq_len, vocab_size],
72            DataType::FP32,
73            &self.info.device,
74        )?;
75
76        // Create stub KV cache
77        let kv_cache = create_stub_kv_cache(
78            ferrum_types::RequestId::new(),
79            self.info.num_layers,
80            seq_len,
81        );
82
83        Ok(PrefillOutput::new(logits, kv_cache))
84    }
85
86    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
87        let batch_size = input.batch_size();
88        let vocab_size = self.info.vocab_size;
89
90        debug!("Stub decode: batch={}", batch_size);
91
92        let factory = self.compute_backend.tensor_factory();
93        let logits = factory.zeros(&[batch_size, vocab_size], DataType::FP32, &self.info.device)?;
94
95        Ok(DecodeOutput::new(logits, input.kv_cache.clone()))
96    }
97
98    fn capabilities(&self) -> ExecutorCapabilities {
99        ExecutorCapabilities {
100            max_batch_size: 32,
101            max_sequence_length: 2048,
102            attention_mechanisms: vec![AttentionType::MultiHead],
103            supports_dynamic_batching: false,
104            supports_continuous_batching: false,
105            supports_speculative_decoding: false,
106            supports_tensor_parallelism: false,
107            supports_pipeline_parallelism: false,
108            supported_dtypes: vec![DataType::FP32, DataType::FP16],
109            supported_devices: vec![Device::CPU],
110            memory_requirements: MemoryRequirements {
111                parameter_memory: 4 * 1024 * 1024, // 4MB
112                activation_memory_per_token: 1024,
113                kv_cache_memory_per_token: 512,
114                overhead_memory: 1024 * 1024,
115            },
116        }
117    }
118
119    fn status(&self) -> ExecutorStatus {
120        ExecutorStatus {
121            state: ExecutorState::Ready,
122            is_ready: true,
123            current_batch_size: 0,
124            prefill_operations: 0,
125            decode_operations: 0,
126            avg_prefill_time_ms: 0.0,
127            avg_decode_time_ms: 0.0,
128            memory_usage: ExecutorMemoryUsage {
129                allocated_bytes: 1024 * 1024,
130                used_bytes: 512 * 1024,
131                peak_bytes: 1024 * 1024,
132                utilization_percent: 50.0,
133            },
134            last_operation: None,
135        }
136    }
137}
138
139/// Create stub KV cache handle
140fn create_stub_kv_cache(
141    request_id: ferrum_types::RequestId,
142    num_layers: usize,
143    seq_len: usize,
144) -> Arc<dyn KvCacheHandle> {
145    #[derive(Debug)]
146    struct StubKvCache {
147        request_id: ferrum_types::RequestId,
148        block_table: BlockTable,
149        num_layers: usize,
150    }
151
152    impl KvCacheHandle for StubKvCache {
153        fn block_table(&self) -> &BlockTable {
154            &self.block_table
155        }
156
157        fn block_table_mut(&mut self) -> &mut BlockTable {
158            &mut self.block_table
159        }
160
161        fn as_any(&self) -> &dyn std::any::Any {
162            self
163        }
164
165        fn device(&self) -> Device {
166            Device::CPU
167        }
168
169        fn num_layers(&self) -> usize {
170            self.num_layers
171        }
172
173        fn num_heads(&self) -> usize {
174            12
175        }
176
177        fn head_dim(&self) -> usize {
178            64
179        }
180
181        fn key_cache(&self, _layer: usize) -> Result<Option<ferrum_interfaces::TensorRef>> {
182            Ok(None)
183        }
184
185        fn value_cache(&self, _layer: usize) -> Result<Option<ferrum_interfaces::TensorRef>> {
186            Ok(None)
187        }
188
189        fn clone_handle(&self) -> Result<Arc<dyn KvCacheHandle>> {
190            Err(ferrum_types::FerrumError::unsupported("Stub cache clone"))
191        }
192
193        fn stats(&self) -> ferrum_interfaces::kv_cache::CacheHandleStats {
194            ferrum_interfaces::kv_cache::CacheHandleStats {
195                memory_bytes: 1024,
196                blocks_allocated: 1,
197                tokens_stored: self.block_table.sequence_length,
198                utilization: 0.5,
199                last_access: std::time::Instant::now(),
200            }
201        }
202
203        fn is_valid(&self) -> bool {
204            true
205        }
206
207        fn cache_id(&self) -> String {
208            format!("stub_{}", self.request_id.to_string())
209        }
210    }
211
212    let mut block_table = BlockTable::new(16);
213    block_table.sequence_length = seq_len;
214
215    Arc::new(StubKvCache {
216        request_id,
217        block_table,
218        num_layers,
219    })
220}