Skip to main content

llama_rs/
engine_batched.rs

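//! Batched inference engine for continuous batching.
//!
//! Provides `BatchedEngine`, which serves multiple generation requests
//! concurrently using continuous batching: each scheduler iteration advances
//! every active sequence by one token, and waiting requests join the batch as
//! soon as a slot frees up. Sequences are currently run through the model one
//! after another rather than in a single batched forward pass.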
6
7#![cfg(feature = "server")]
8
9use std::sync::Arc;
10
11use tokio::sync::{mpsc, Mutex};
12
13use crate::backend::Backend;
14use crate::model::{InferenceContext, Model, ModelConfig};
15use crate::sampling::{Sampler, SamplerConfig};
16use crate::tokenizer::Tokenizer;
17
18// ============================================================================
19// Config
20// ============================================================================
21
/// Capacity limits for a [`BatchedEngine`].
#[derive(Debug, Clone)]
pub struct BatchedEngineConfig {
    /// Upper bound on sequences that are generated concurrently.
    pub max_batch_size: usize,
    /// Upper bound on the length of a sequence, in tokens.
    pub max_seq_len: usize,
    /// Upper bound on requests that are waiting or running; `submit` rejects
    /// anything beyond it.
    pub max_queue_depth: usize,
}

impl Default for BatchedEngineConfig {
    /// Eight concurrent sequences, a 4096-token limit and at most 64 queued requests.
    fn default() -> Self {
        let max_batch_size = 8;
        let max_seq_len = 4096;
        let max_queue_depth = 64;
        Self {
            max_batch_size,
            max_seq_len,
            max_queue_depth,
        }
    }
}
41
42// ============================================================================
43// Request / Response types
44// ============================================================================
45
46/// A request submitted to the batched engine
47pub struct BatchRequest {
48    /// Prompt tokens
49    pub tokens: Vec<u32>,
50    /// Maximum tokens to generate
51    pub max_tokens: usize,
52    /// Sampler configuration
53    pub sampler_config: SamplerConfig,
54    /// Channel to receive generated tokens
55    pub token_sender: mpsc::Sender<BatchToken>,
56}
57
58/// Token event from the batched engine
59#[derive(Debug, Clone)]
60pub enum BatchToken {
61    /// A generated token
62    Token { id: u32, text: String },
63    /// Generation finished
64    Done {
65        reason: BatchFinishReason,
66        prompt_tokens: usize,
67        completion_tokens: usize,
68    },
69    /// Error occurred
70    Error(String),
71}
72
/// Why a sequence stopped generating.
///
/// This is a plain fieldless enum, so it is `Copy` and comparable. Callers and
/// tests can use `==` / `assert_eq!` instead of `match`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchFinishReason {
    /// Generation ended before the token budget ran out (e.g. EOS was produced).
    Stop,
    /// The `max_tokens` budget was exhausted.
    MaxTokens,
    /// Generation was aborted by an error.
    Error,
}
80
81// ============================================================================
82// Internal state
83// ============================================================================
84
85/// Internal state for an active sequence in the batch
86struct ActiveSequence {
87    /// All tokens so far (prompt + generated)
88    tokens: Vec<u32>,
89    /// Prompt length
90    prompt_len: usize,
91    /// Number of generated tokens
92    generated: usize,
93    /// Maximum tokens to generate
94    max_tokens: usize,
95    /// Inference context with KV cache
96    ctx: InferenceContext,
97    /// Sampler for this sequence
98    sampler: Sampler,
99    /// Channel to send results
100    sender: mpsc::Sender<BatchToken>,
101}
102
103/// Command for the background loop
104enum BatchCommand {
105    Request(BatchRequest),
106    Shutdown,
107}
108
109// ============================================================================
110// BatchedEngine
111// ============================================================================
112
113/// Batched inference engine using continuous batching
114pub struct BatchedEngine {
115    config: BatchedEngineConfig,
116    /// Channel to submit new requests
117    request_tx: mpsc::Sender<BatchCommand>,
118    /// Queue depth counter (active + pending)
119    queue_count: Arc<Mutex<usize>>,
120    /// Handle to the background processing loop
121    _handle: Option<tokio::task::JoinHandle<()>>,
122}
123
124impl BatchedEngine {
125    /// Create a new batched engine and spawn the background processing loop.
126    pub fn new(
127        model: Arc<dyn Model>,
128        tokenizer: Arc<Tokenizer>,
129        _model_config: ModelConfig,
130        backend: Arc<dyn Backend>,
131        config: BatchedEngineConfig,
132    ) -> Self {
133        let (request_tx, mut request_rx) = mpsc::channel(config.max_queue_depth);
134        let queue_count = Arc::new(Mutex::new(0));
135
136        let model_clone = model.clone();
137        let tokenizer_clone = tokenizer.clone();
138        let backend_clone = backend.clone();
139        let queue_count_clone = queue_count.clone();
140        let max_batch_size = config.max_batch_size;
141        let max_seq_len = config.max_seq_len;
142        let eos_token_id = tokenizer.special_tokens.eos_token_id;
143
144        let handle = tokio::spawn(async move {
145            run_background_loop(
146                model_clone,
147                tokenizer_clone,
148                backend_clone,
149                &mut request_rx,
150                queue_count_clone,
151                max_batch_size,
152                max_seq_len,
153                eos_token_id,
154            )
155            .await;
156        });
157
158        Self {
159            config,
160            request_tx,
161            queue_count,
162            _handle: Some(handle),
163        }
164    }
165
166    /// Submit a request. Returns error if queue is full.
167    pub fn submit(&self, request: BatchRequest) -> Result<(), String> {
168        let mut count = self
169            .queue_count
170            .try_lock()
171            .map_err(|_| "failed to lock queue")?;
172
173        if *count >= self.config.max_queue_depth {
174            return Err("queue full".to_string());
175        }
176
177        *count += 1;
178        drop(count);
179
180        self.request_tx
181            .try_send(BatchCommand::Request(request))
182            .map_err(|e| {
183                // Decrement on send failure
184                if let Ok(mut c) = self.queue_count.try_lock() {
185                    *c = c.saturating_sub(1);
186                }
187                e.to_string()
188            })?;
189
190        Ok(())
191    }
192
193    /// Signal the background loop to stop.
194    pub fn shutdown(&self) {
195        let _ = self.request_tx.try_send(BatchCommand::Shutdown);
196    }
197}
198
199/// Background loop: receive requests, process active sequences, send results.
200async fn run_background_loop(
201    model: Arc<dyn Model>,
202    tokenizer: Arc<Tokenizer>,
203    backend: Arc<dyn Backend>,
204    request_rx: &mut mpsc::Receiver<BatchCommand>,
205    queue_count: Arc<Mutex<usize>>,
206    max_batch_size: usize,
207    max_seq_len: usize,
208    eos_token_id: u32,
209) {
210    let mut active: Vec<ActiveSequence> = Vec::with_capacity(max_batch_size);
211    let mut pending: Vec<BatchRequest> = Vec::new();
212    let mut shutdown = false;
213
214    while !shutdown {
215        // 1. Drain new requests (non-blocking)
216        while let Ok(cmd) = request_rx.try_recv() {
217            match cmd {
218                BatchCommand::Request(req) => {
219                    if active.len() < max_batch_size {
220                        if let Some(seq) = create_active_sequence(
221                            req,
222                            &model,
223                            &tokenizer,
224                            &backend,
225                            max_seq_len,
226                        ) {
227                            active.push(seq);
228                        } else {
229                            decrement_queue(&queue_count).await;
230                        }
231                    } else {
232                        pending.push(req);
233                    }
234                }
235                BatchCommand::Shutdown => shutdown = true,
236            }
237        }
238
239        // 2. Process each active sequence (one token per iteration)
240        let mut i = 0;
241        while i < active.len() {
242            let seq = &mut active[i];
243            let result = step_sequence(seq, &model, &tokenizer, eos_token_id);
244
245            match result {
246                Ok(Some((token_id, text))) => {
247                    let _ = seq
248                        .sender
249                        .send(BatchToken::Token {
250                            id: token_id,
251                            text,
252                        })
253                        .await;
254                }
255                Ok(None) => {
256                    // Sequence finished
257                    let prompt_tokens = seq.prompt_len;
258                    let completion_tokens = seq.generated;
259                    let reason = if seq.generated >= seq.max_tokens {
260                        BatchFinishReason::MaxTokens
261                    } else {
262                        BatchFinishReason::Stop
263                    };
264                    let sender = seq.sender.clone();
265                    active.remove(i);
266                    decrement_queue(&queue_count).await;
267                    let _ = sender
268                        .send(BatchToken::Done {
269                            reason,
270                            prompt_tokens,
271                            completion_tokens,
272                        })
273                        .await;
274                    continue;
275                }
276                Err(e) => {
277                    let sender = seq.sender.clone();
278                    active.remove(i);
279                    decrement_queue(&queue_count).await;
280                    let _ = sender
281                        .send(BatchToken::Error(e.to_string()))
282                        .await;
283                    continue;
284                }
285            }
286            i += 1;
287        }
288
289        // 3. Promote pending to active when we have space
290        while active.len() < max_batch_size {
291            match pending.pop() {
292                Some(req) => {
293                    if let Some(seq) =
294                        create_active_sequence(req, &model, &tokenizer, &backend, max_seq_len)
295                    {
296                        active.push(seq);
297                    } else {
298                        decrement_queue(&queue_count).await;
299                    }
300                }
301                None => break,
302            }
303        }
304
305        if shutdown {
306            break;
307        }
308
309        // 4. Sleep briefly if no work
310        if active.is_empty() {
311            match tokio::time::timeout(
312                std::time::Duration::from_millis(10),
313                request_rx.recv(),
314            )
315            .await
316            {
317                Ok(Some(BatchCommand::Request(req))) => {
318                    if let Some(seq) =
319                        create_active_sequence(req, &model, &tokenizer, &backend, max_seq_len)
320                    {
321                        active.push(seq);
322                    } else {
323                        decrement_queue(&queue_count).await;
324                    }
325                }
326                Ok(Some(BatchCommand::Shutdown)) => break,
327                Ok(None) => break,
328                Err(_) => {}
329            }
330        }
331    }
332}
333
334async fn decrement_queue(queue_count: &Arc<Mutex<usize>>) {
335    let mut c = queue_count.lock().await;
336    *c = c.saturating_sub(1);
337}
338
339fn create_active_sequence(
340    req: BatchRequest,
341    model: &Arc<dyn Model>,
342    _tokenizer: &Arc<Tokenizer>,
343    backend: &Arc<dyn Backend>,
344    max_seq_len: usize,
345) -> Option<ActiveSequence> {
346    if req.tokens.is_empty() {
347        let _ = req.token_sender.try_send(BatchToken::Error(
348            "empty prompt".to_string(),
349        ));
350        return None;
351    }
352
353    let prompt_len = req.tokens.len().min(max_seq_len.saturating_sub(1));
354    let tokens: Vec<u32> = req.tokens.iter().take(prompt_len).copied().collect();
355    let prompt_len = tokens.len();
356
357    let ctx = model.create_context(backend.clone());
358    let sampler = Sampler::new(req.sampler_config.clone(), model.vocab_size());
359
360    Some(ActiveSequence {
361        tokens: tokens.clone(),
362        prompt_len,
363        generated: 0,
364        max_tokens: req.max_tokens,
365        ctx,
366        sampler,
367        sender: req.token_sender,
368    })
369}
370
371/// Step one sequence: prefill or decode one token. Returns Ok(Some((id, text))),
372/// Ok(None) if done, or Err on model error.
373fn step_sequence(
374    seq: &mut ActiveSequence,
375    model: &Arc<dyn Model>,
376    tokenizer: &Arc<Tokenizer>,
377    eos_token_id: u32,
378) -> Result<Option<(u32, String)>, crate::model::ModelError> {
379    // Check EOS from last token
380    if let Some(&last) = seq.tokens.last() {
381        if last == eos_token_id {
382            return Ok(None);
383        }
384    }
385
386    if seq.generated >= seq.max_tokens {
387        return Ok(None);
388    }
389
390    let input_tokens: &[u32] = if seq.ctx.position == 0 {
391        &seq.tokens[..]
392    } else {
393        &seq.tokens[seq.tokens.len().saturating_sub(1)..]
394    };
395
396    let logits = model.forward(input_tokens, &mut seq.ctx)?;
397    let next_token = seq.sampler.sample(&logits, &seq.tokens);
398
399    seq.tokens.push(next_token);
400    seq.generated += 1;
401
402    if next_token == eos_token_id {
403        return Ok(None);
404    }
405
406    let text = tokenizer
407        .decode_token(next_token)
408        .unwrap_or_else(|_| String::new());
409
410    Ok(Some((next_token, text)))
411}
412
413// ============================================================================
414// Tests
415// ============================================================================
416
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batched_engine_config_default() {
        let BatchedEngineConfig {
            max_batch_size,
            max_seq_len,
            max_queue_depth,
        } = BatchedEngineConfig::default();
        assert_eq!(max_batch_size, 8);
        assert_eq!(max_seq_len, 4096);
        assert_eq!(max_queue_depth, 64);
    }

    #[test]
    fn test_batch_request_creation() {
        let (token_sender, _token_receiver) = mpsc::channel(1);
        let request = BatchRequest {
            tokens: vec![1, 2, 3],
            max_tokens: 64,
            sampler_config: SamplerConfig::default(),
            token_sender,
        };
        assert_eq!(request.tokens.len(), 3);
        assert_eq!(request.max_tokens, 64);
    }

    #[test]
    fn test_batch_finish_reason() {
        assert!(
            matches!(BatchFinishReason::Stop, BatchFinishReason::Stop),
            "expected Stop"
        );
        assert!(
            matches!(BatchFinishReason::MaxTokens, BatchFinishReason::MaxTokens),
            "expected MaxTokens"
        );
        assert!(
            matches!(BatchFinishReason::Error, BatchFinishReason::Error),
            "expected Error"
        );
    }
}