1#![cfg(feature = "server")]
8
9use std::sync::Arc;
10
11use tokio::sync::{mpsc, Mutex};
12
13use crate::backend::Backend;
14use crate::model::{InferenceContext, Model, ModelConfig};
15use crate::sampling::{Sampler, SamplerConfig};
16use crate::tokenizer::Tokenizer;
17
/// Tunable limits for a [`BatchedEngine`].
#[derive(Debug, Clone)]
pub struct BatchedEngineConfig {
    /// Upper bound on the number of sequences the background loop keeps active at once.
    pub max_batch_size: usize,
    /// Sequence-length limit; prompts are truncated to `max_seq_len - 1` tokens on admission.
    pub max_seq_len: usize,
    /// Capacity of the command channel and the cap on accepted-but-unfinished requests.
    pub max_queue_depth: usize,
}
31
32impl Default for BatchedEngineConfig {
33 fn default() -> Self {
34 Self {
35 max_batch_size: 8,
36 max_seq_len: 4096,
37 max_queue_depth: 64,
38 }
39 }
40}
41
42pub struct BatchRequest {
48 pub tokens: Vec<u32>,
50 pub max_tokens: usize,
52 pub sampler_config: SamplerConfig,
54 pub token_sender: mpsc::Sender<BatchToken>,
56}
57
58#[derive(Debug, Clone)]
60pub enum BatchToken {
61 Token { id: u32, text: String },
63 Done {
65 reason: BatchFinishReason,
66 prompt_tokens: usize,
67 completion_tokens: usize,
68 },
69 Error(String),
71}
72
/// Why a sequence stopped generating.
///
/// The enum is a plain tag, so it is `Copy` and comparable with `==`, which lets
/// callers and tests write `assert_eq!(reason, BatchFinishReason::Stop)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchFinishReason {
    /// Generation ended on the end-of-sequence token.
    Stop,
    /// The request's token budget was used up.
    MaxTokens,
    /// Generation ended because of an error.
    Error,
}
80
81struct ActiveSequence {
87 tokens: Vec<u32>,
89 prompt_len: usize,
91 generated: usize,
93 max_tokens: usize,
95 ctx: InferenceContext,
97 sampler: Sampler,
99 sender: mpsc::Sender<BatchToken>,
101}
102
103enum BatchCommand {
105 Request(BatchRequest),
106 Shutdown,
107}
108
109pub struct BatchedEngine {
115 config: BatchedEngineConfig,
116 request_tx: mpsc::Sender<BatchCommand>,
118 queue_count: Arc<Mutex<usize>>,
120 _handle: Option<tokio::task::JoinHandle<()>>,
122}
123
124impl BatchedEngine {
125 pub fn new(
127 model: Arc<dyn Model>,
128 tokenizer: Arc<Tokenizer>,
129 _model_config: ModelConfig,
130 backend: Arc<dyn Backend>,
131 config: BatchedEngineConfig,
132 ) -> Self {
133 let (request_tx, mut request_rx) = mpsc::channel(config.max_queue_depth);
134 let queue_count = Arc::new(Mutex::new(0));
135
136 let model_clone = model.clone();
137 let tokenizer_clone = tokenizer.clone();
138 let backend_clone = backend.clone();
139 let queue_count_clone = queue_count.clone();
140 let max_batch_size = config.max_batch_size;
141 let max_seq_len = config.max_seq_len;
142 let eos_token_id = tokenizer.special_tokens.eos_token_id;
143
144 let handle = tokio::spawn(async move {
145 run_background_loop(
146 model_clone,
147 tokenizer_clone,
148 backend_clone,
149 &mut request_rx,
150 queue_count_clone,
151 max_batch_size,
152 max_seq_len,
153 eos_token_id,
154 )
155 .await;
156 });
157
158 Self {
159 config,
160 request_tx,
161 queue_count,
162 _handle: Some(handle),
163 }
164 }
165
166 pub fn submit(&self, request: BatchRequest) -> Result<(), String> {
168 let mut count = self
169 .queue_count
170 .try_lock()
171 .map_err(|_| "failed to lock queue")?;
172
173 if *count >= self.config.max_queue_depth {
174 return Err("queue full".to_string());
175 }
176
177 *count += 1;
178 drop(count);
179
180 self.request_tx
181 .try_send(BatchCommand::Request(request))
182 .map_err(|e| {
183 if let Ok(mut c) = self.queue_count.try_lock() {
185 *c = c.saturating_sub(1);
186 }
187 e.to_string()
188 })?;
189
190 Ok(())
191 }
192
193 pub fn shutdown(&self) {
195 let _ = self.request_tx.try_send(BatchCommand::Shutdown);
196 }
197}
198
199async fn run_background_loop(
201 model: Arc<dyn Model>,
202 tokenizer: Arc<Tokenizer>,
203 backend: Arc<dyn Backend>,
204 request_rx: &mut mpsc::Receiver<BatchCommand>,
205 queue_count: Arc<Mutex<usize>>,
206 max_batch_size: usize,
207 max_seq_len: usize,
208 eos_token_id: u32,
209) {
210 let mut active: Vec<ActiveSequence> = Vec::with_capacity(max_batch_size);
211 let mut pending: Vec<BatchRequest> = Vec::new();
212 let mut shutdown = false;
213
214 while !shutdown {
215 while let Ok(cmd) = request_rx.try_recv() {
217 match cmd {
218 BatchCommand::Request(req) => {
219 if active.len() < max_batch_size {
220 if let Some(seq) = create_active_sequence(
221 req,
222 &model,
223 &tokenizer,
224 &backend,
225 max_seq_len,
226 ) {
227 active.push(seq);
228 } else {
229 decrement_queue(&queue_count).await;
230 }
231 } else {
232 pending.push(req);
233 }
234 }
235 BatchCommand::Shutdown => shutdown = true,
236 }
237 }
238
239 let mut i = 0;
241 while i < active.len() {
242 let seq = &mut active[i];
243 let result = step_sequence(seq, &model, &tokenizer, eos_token_id);
244
245 match result {
246 Ok(Some((token_id, text))) => {
247 let _ = seq
248 .sender
249 .send(BatchToken::Token {
250 id: token_id,
251 text,
252 })
253 .await;
254 }
255 Ok(None) => {
256 let prompt_tokens = seq.prompt_len;
258 let completion_tokens = seq.generated;
259 let reason = if seq.generated >= seq.max_tokens {
260 BatchFinishReason::MaxTokens
261 } else {
262 BatchFinishReason::Stop
263 };
264 let sender = seq.sender.clone();
265 active.remove(i);
266 decrement_queue(&queue_count).await;
267 let _ = sender
268 .send(BatchToken::Done {
269 reason,
270 prompt_tokens,
271 completion_tokens,
272 })
273 .await;
274 continue;
275 }
276 Err(e) => {
277 let sender = seq.sender.clone();
278 active.remove(i);
279 decrement_queue(&queue_count).await;
280 let _ = sender
281 .send(BatchToken::Error(e.to_string()))
282 .await;
283 continue;
284 }
285 }
286 i += 1;
287 }
288
289 while active.len() < max_batch_size {
291 match pending.pop() {
292 Some(req) => {
293 if let Some(seq) =
294 create_active_sequence(req, &model, &tokenizer, &backend, max_seq_len)
295 {
296 active.push(seq);
297 } else {
298 decrement_queue(&queue_count).await;
299 }
300 }
301 None => break,
302 }
303 }
304
305 if shutdown {
306 break;
307 }
308
309 if active.is_empty() {
311 match tokio::time::timeout(
312 std::time::Duration::from_millis(10),
313 request_rx.recv(),
314 )
315 .await
316 {
317 Ok(Some(BatchCommand::Request(req))) => {
318 if let Some(seq) =
319 create_active_sequence(req, &model, &tokenizer, &backend, max_seq_len)
320 {
321 active.push(seq);
322 } else {
323 decrement_queue(&queue_count).await;
324 }
325 }
326 Ok(Some(BatchCommand::Shutdown)) => break,
327 Ok(None) => break,
328 Err(_) => {}
329 }
330 }
331 }
332}
333
334async fn decrement_queue(queue_count: &Arc<Mutex<usize>>) {
335 let mut c = queue_count.lock().await;
336 *c = c.saturating_sub(1);
337}
338
339fn create_active_sequence(
340 req: BatchRequest,
341 model: &Arc<dyn Model>,
342 _tokenizer: &Arc<Tokenizer>,
343 backend: &Arc<dyn Backend>,
344 max_seq_len: usize,
345) -> Option<ActiveSequence> {
346 if req.tokens.is_empty() {
347 let _ = req.token_sender.try_send(BatchToken::Error(
348 "empty prompt".to_string(),
349 ));
350 return None;
351 }
352
353 let prompt_len = req.tokens.len().min(max_seq_len.saturating_sub(1));
354 let tokens: Vec<u32> = req.tokens.iter().take(prompt_len).copied().collect();
355 let prompt_len = tokens.len();
356
357 let ctx = model.create_context(backend.clone());
358 let sampler = Sampler::new(req.sampler_config.clone(), model.vocab_size());
359
360 Some(ActiveSequence {
361 tokens: tokens.clone(),
362 prompt_len,
363 generated: 0,
364 max_tokens: req.max_tokens,
365 ctx,
366 sampler,
367 sender: req.token_sender,
368 })
369}
370
371fn step_sequence(
374 seq: &mut ActiveSequence,
375 model: &Arc<dyn Model>,
376 tokenizer: &Arc<Tokenizer>,
377 eos_token_id: u32,
378) -> Result<Option<(u32, String)>, crate::model::ModelError> {
379 if let Some(&last) = seq.tokens.last() {
381 if last == eos_token_id {
382 return Ok(None);
383 }
384 }
385
386 if seq.generated >= seq.max_tokens {
387 return Ok(None);
388 }
389
390 let input_tokens: &[u32] = if seq.ctx.position == 0 {
391 &seq.tokens[..]
392 } else {
393 &seq.tokens[seq.tokens.len().saturating_sub(1)..]
394 };
395
396 let logits = model.forward(input_tokens, &mut seq.ctx)?;
397 let next_token = seq.sampler.sample(&logits, &seq.tokens);
398
399 seq.tokens.push(next_token);
400 seq.generated += 1;
401
402 if next_token == eos_token_id {
403 return Ok(None);
404 }
405
406 let text = tokenizer
407 .decode_token(next_token)
408 .unwrap_or_else(|_| String::new());
409
410 Ok(Some((next_token, text)))
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the stock limits.
    #[test]
    fn test_batched_engine_config_default() {
        let BatchedEngineConfig {
            max_batch_size,
            max_seq_len,
            max_queue_depth,
        } = BatchedEngineConfig::default();

        assert_eq!(max_batch_size, 8);
        assert_eq!(max_seq_len, 4096);
        assert_eq!(max_queue_depth, 64);
    }

    /// A request keeps the prompt and budget it was built with.
    #[test]
    fn test_batch_request_creation() {
        let (token_sender, _token_receiver) = mpsc::channel(1);
        let request = BatchRequest {
            tokens: vec![1, 2, 3],
            max_tokens: 64,
            sampler_config: SamplerConfig::default(),
            token_sender,
        };

        assert_eq!(request.tokens.len(), 3);
        assert_eq!(request.max_tokens, 64);
    }

    /// Every finish reason can be constructed and matched.
    #[test]
    fn test_batch_finish_reason() {
        let stop = BatchFinishReason::Stop;
        let max = BatchFinishReason::MaxTokens;
        let err = BatchFinishReason::Error;

        assert!(matches!(&stop, BatchFinishReason::Stop), "expected Stop");
        assert!(
            matches!(&max, BatchFinishReason::MaxTokens),
            "expected MaxTokens"
        );
        assert!(matches!(&err, BatchFinishReason::Error), "expected Error");
    }
}