1use crate::kv_cache::MockKvCacheHandle;
7use crate::tensor::MockTensor;
8use async_trait::async_trait;
9use ferrum_interfaces::{
10 model_executor::{
11 AttentionType, DecodeInput, DecodeOutput, ExecutorCapabilities, ExecutorMemoryUsage,
12 ExecutorState, ExecutorStatus, MemoryRequirements, PrefillInput, PrefillOutput,
13 },
14 ModelExecutor,
15};
16use ferrum_types::{DataType, Device, ModelInfo, ModelType, RequestId, Result};
17use std::collections::HashMap;
18use std::sync::atomic::{AtomicU64, Ordering};
19use std::sync::Arc;
20
21pub struct ConfigurableModelExecutor {
23 info: ModelInfo,
24 token_sequence: Vec<u32>,
26 eos_after: Option<usize>,
28 eos_token: u32,
30 decode_count: AtomicU64,
31}
32
33impl ConfigurableModelExecutor {
34 pub fn with_token_sequence(vocab_size: usize, tokens: Vec<u32>) -> Self {
36 Self {
37 info: mock_info(vocab_size),
38 token_sequence: tokens,
39 eos_after: None,
40 eos_token: 2, decode_count: AtomicU64::new(0),
42 }
43 }
44
45 pub fn with_eos_after(vocab_size: usize, n: usize, eos_token: u32) -> Self {
47 Self {
48 info: mock_info(vocab_size),
49 token_sequence: vec![42], eos_after: Some(n),
51 eos_token,
52 decode_count: AtomicU64::new(0),
53 }
54 }
55
56 fn next_token_logits(&self) -> Vec<f32> {
57 let step = self.decode_count.load(Ordering::Relaxed) as usize;
58 let vocab_size = self.info.vocab_size;
59 let mut logits = vec![0.0f32; vocab_size];
60
61 if let Some(eos_n) = self.eos_after {
63 if step >= eos_n {
64 if (self.eos_token as usize) < vocab_size {
65 logits[self.eos_token as usize] = 10.0;
66 }
67 return logits;
68 }
69 }
70
71 let token = self.token_sequence[step % self.token_sequence.len()];
73 if (token as usize) < vocab_size {
74 logits[token as usize] = 10.0;
75 }
76 logits
77 }
78}
79
80#[async_trait]
81impl ModelExecutor for ConfigurableModelExecutor {
82 fn info(&self) -> &ModelInfo {
83 &self.info
84 }
85
86 async fn prefill(&self, input: &PrefillInput) -> Result<PrefillOutput> {
87 let batch_size = input.batch_size();
88 let seq_len = input.sequence_length();
89 let vocab_size = self.info.vocab_size;
90
91 let token = self.token_sequence[0];
93 let mut logits_data = vec![0.0f32; batch_size * seq_len * vocab_size];
94 for b in 0..batch_size {
95 for s in 0..seq_len {
96 let offset = (b * seq_len + s) * vocab_size;
97 if offset + token as usize >= logits_data.len() {
98 continue;
99 }
100 logits_data[offset + token as usize] = 10.0;
101 }
102 }
103 let logits =
104 MockTensor::from_f32(logits_data, &[batch_size, seq_len, vocab_size]).into_ref();
105 let kv_cache = Arc::new(MockKvCacheHandle::new(
106 RequestId::new(),
107 self.info.num_layers,
108 seq_len,
109 ));
110 Ok(PrefillOutput::new(logits, kv_cache))
111 }
112
113 async fn decode(&self, input: &DecodeInput) -> Result<DecodeOutput> {
114 let batch_size = input.batch_size();
115 let vocab_size = self.info.vocab_size;
116
117 let single_logits = self.next_token_logits();
118 self.decode_count.fetch_add(1, Ordering::Relaxed);
119
120 let mut logits_data = Vec::with_capacity(batch_size * vocab_size);
122 for _ in 0..batch_size {
123 logits_data.extend_from_slice(&single_logits);
124 }
125 let logits = MockTensor::from_f32(logits_data, &[batch_size, vocab_size]).into_ref();
126 Ok(DecodeOutput::new(logits, input.kv_cache.clone()))
127 }
128
129 fn capabilities(&self) -> ExecutorCapabilities {
130 ExecutorCapabilities {
131 max_batch_size: 256,
132 max_sequence_length: 4096,
133 attention_mechanisms: vec![AttentionType::MultiHead],
134 supports_dynamic_batching: true,
135 supports_continuous_batching: true,
136 supports_speculative_decoding: false,
137 supports_tensor_parallelism: false,
138 supports_pipeline_parallelism: false,
139 supported_dtypes: vec![DataType::FP32],
140 supported_devices: vec![Device::CPU],
141 memory_requirements: MemoryRequirements {
142 parameter_memory: 0,
143 activation_memory_per_token: 0,
144 kv_cache_memory_per_token: 0,
145 overhead_memory: 0,
146 },
147 }
148 }
149
150 fn status(&self) -> ExecutorStatus {
151 ExecutorStatus {
152 state: ExecutorState::Ready,
153 is_ready: true,
154 current_batch_size: 0,
155 prefill_operations: 0,
156 decode_operations: self.decode_count.load(Ordering::Relaxed),
157 avg_prefill_time_ms: 0.0,
158 avg_decode_time_ms: 0.0,
159 memory_usage: ExecutorMemoryUsage {
160 allocated_bytes: 0,
161 used_bytes: 0,
162 peak_bytes: 0,
163 utilization_percent: 0.0,
164 },
165 last_operation: None,
166 }
167 }
168}
169
170fn mock_info(vocab_size: usize) -> ModelInfo {
171 ModelInfo {
172 model_id: "configurable-mock".into(),
173 model_type: ModelType::Custom("mock".into()),
174 num_parameters: 1_000_000,
175 hidden_size: 768,
176 num_layers: 12,
177 num_heads: 12,
178 num_kv_heads: 12,
179 vocab_size,
180 max_sequence_length: 4096,
181 dtype: DataType::FP32,
182 device: Device::CPU,
183 version: Some("configurable-1.0".into()),
184 license: None,
185 metadata: HashMap::new(),
186 }
187}