Skip to main content

ferrum_testkit/
configurable_executor.rs

//! Configurable model executor for testing stop sequences, EOS, and specific token patterns.
//!
//! Unlike `MockModelExecutor` (which always biases token 42), this executor can be configured
//! to produce specific token sequences, emit EOS after N decode steps, etc.
5
6use crate::kv_cache::MockKvCacheHandle;
7use crate::tensor::MockTensor;
8use async_trait::async_trait;
9use ferrum_interfaces::{
10    model_executor::{
11        AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
12        ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
13    },
14    ModelExecutor,
15};
16use ferrum_types::{DataType, Device, ModelInfo, ModelType, RequestId, Result};
17use std::collections::HashMap;
18use std::sync::atomic::{AtomicU64, Ordering};
19use std::sync::Arc;
20
21/// Model executor that produces a configurable sequence of tokens.
22pub struct ConfigurableModelExecutor {
23    info: ModelInfo,
24    /// Token sequence to cycle through on decode.
25    token_sequence: Vec<u32>,
26    /// If set, emit this token (EOS) after this many decode steps.
27    eos_after: Option<usize>,
28    /// EOS token ID.
29    eos_token: u32,
30    decode_count: AtomicU64,
31}
32
33impl ConfigurableModelExecutor {
34    /// Create executor that cycles through the given token sequence.
35    pub fn with_token_sequence(vocab_size: usize, tokens: Vec<u32>) -> Self {
36        Self {
37            info: mock_info(vocab_size),
38            token_sequence: tokens,
39            eos_after: None,
40            eos_token: 2, // common EOS
41            decode_count: AtomicU64::new(0),
42        }
43    }
44
45    /// Create executor that emits EOS after `n` decode steps.
46    pub fn with_eos_after(vocab_size: usize, n: usize, eos_token: u32) -> Self {
47        Self {
48            info: mock_info(vocab_size),
49            token_sequence: vec![42], // default token before EOS
50            eos_after: Some(n),
51            eos_token,
52            decode_count: AtomicU64::new(0),
53        }
54    }
55
56    fn next_token_logits(&self) -> Vec<f32> {
57        let step = self.decode_count.load(Ordering::Relaxed) as usize;
58        let vocab_size = self.info.vocab_size;
59        let mut logits = vec![0.0f32; vocab_size];
60
61        // Check if we should emit EOS
62        if let Some(eos_n) = self.eos_after {
63            if step >= eos_n {
64                if (self.eos_token as usize) < vocab_size {
65                    logits[self.eos_token as usize] = 10.0;
66                }
67                return logits;
68            }
69        }
70
71        // Cycle through token sequence
72        let token = self.token_sequence[step % self.token_sequence.len()];
73        if (token as usize) < vocab_size {
74            logits[token as usize] = 10.0;
75        }
76        logits
77    }
78}
79
80#[async_trait]
81impl ModelExecutor for ConfigurableModelExecutor {
82    fn info(&self) -> &ModelInfo {
83        &self.info
84    }
85
86    async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
87        let batch_size = input.batch_size();
88        let seq_len = input.sequence_length();
89        let vocab_size = self.info.vocab_size;
90
91        // Prefill logits: bias first token from sequence
92        let token = self.token_sequence[0];
93        let mut logits_data = vec![0.0f32; batch_size * seq_len * vocab_size];
94        for b in 0..batch_size {
95            for s in 0..seq_len {
96                let offset = (b * seq_len + s) * vocab_size;
97                if offset + token as usize >= logits_data.len() {
98                    continue;
99                }
100                logits_data[offset + token as usize] = 10.0;
101            }
102        }
103        let logits =
104            MockTensor::from_f32(logits_data, &[batch_size, seq_len, vocab_size]).into_ref();
105        let kv_cache = Arc::new(MockKvCacheHandle::new(
106            RequestId::new(),
107            self.info.num_layers,
108            seq_len,
109        ));
110        Ok(PrefillOutput::new(logits, kv_cache))
111    }
112
113    async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
114        let batch_size = input.batch_size();
115        let vocab_size = self.info.vocab_size;
116
117        let single_logits = self.next_token_logits();
118        self.decode_count.fetch_add(1, Ordering::Relaxed);
119
120        // Replicate for batch
121        let mut logits_data = Vec::with_capacity(batch_size * vocab_size);
122        for _ in 0..batch_size {
123            logits_data.extend_from_slice(&single_logits);
124        }
125        let logits = MockTensor::from_f32(logits_data, &[batch_size, vocab_size]).into_ref();
126        Ok(DecodeOutput::new(logits, input.kv_cache.clone()))
127    }
128
129    fn capabilities(&self) -> ExecutorCapabilities {
130        ExecutorCapabilities {
131            max_batch_size: 256,
132            max_sequence_length: 4096,
133            attention_mechanisms: vec![AttentionType::MultiHead],
134            supports_dynamic_batching: true,
135            supports_continuous_batching: true,
136            supports_speculative_decoding: false,
137            supports_tensor_parallelism: false,
138            supports_pipeline_parallelism: false,
139            supported_dtypes: vec![DataType::FP32],
140            supported_devices: vec![Device::CPU],
141            memory_requirements: MemoryRequirements {
142                parameter_memory: 0,
143                activation_memory_per_token: 0,
144                kv_cache_memory_per_token: 0,
145                overhead_memory: 0,
146            },
147        }
148    }
149
150    fn status(&self) -> ExecutorStatus {
151        ExecutorStatus {
152            state: ExecutorState::Ready,
153            is_ready: true,
154            current_batch_size: 0,
155            prefill_operations: 0,
156            decode_operations: self.decode_count.load(Ordering::Relaxed),
157            avg_prefill_time_ms: 0.0,
158            avg_decode_time_ms: 0.0,
159            memory_usage: ExecutorMemoryUsage {
160                allocated_bytes: 0,
161                used_bytes: 0,
162                peak_bytes: 0,
163                utilization_percent: 0.0,
164            },
165            last_operation: None,
166        }
167    }
168}
169
170fn mock_info(vocab_size: usize) -> ModelInfo {
171    ModelInfo {
172        model_id: "configurable-mock".into(),
173        model_type: ModelType::Custom("mock".into()),
174        num_parameters: 1_000_000,
175        hidden_size: 768,
176        num_layers: 12,
177        num_heads: 12,
178        num_kv_heads: 12,
179        vocab_size,
180        max_sequence_length: 4096,
181        dtype: DataType::FP32,
182        device: Device::CPU,
183        version: Some("configurable-1.0".into()),
184        license: None,
185        metadata: HashMap::new(),
186    }
187}