god_graph/transformer/batch/mod.rs

//! Batch inference module for efficient throughput
//!
//! This module provides:
//! - Batch forward pass
//! - Continuous batching (vLLM-style)
//! - Request scheduling
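//!
//! A minimal end-to-end usage sketch (illustrative token IDs; assumes a loaded
//! `LlamaModel` value named `model` is in scope):
//!
//! ```ignore
//! let mut scheduler = RequestScheduler::new(8);
//! scheduler.add_request(vec![1, 2, 3], GenerationConfig::greedy());
//!
//! let mut engine = BatchInference::new(&model, 8, 512);
//! let outputs = engine.generate_continuous(&mut scheduler);
//! ```
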
use crate::tensor::DenseTensor;
use super::model::LlamaModel;
use super::generation::GenerationConfig;
use super::kv_cache::KVCache;

/// Batch data for inference
#[derive(Debug, Clone)]
pub struct BatchData {
    /// Input token IDs for each sequence in batch
    pub input_ids: Vec<Vec<usize>>,
    /// Attention mask [batch_size, seq_len, seq_len]
    pub attention_mask: Option<DenseTensor>,
    /// Position IDs [batch_size, seq_len]
    pub position_ids: Option<Vec<Vec<usize>>>,
    /// Sequence lengths
    pub seq_lengths: Vec<usize>,
}

impl BatchData {
    /// Create a new batch from input sequences
    ///
    /// # Arguments
    /// * `input_ids` - List of input sequences
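    ///
    /// A minimal padding sketch (illustrative token IDs):
    ///
    /// ```ignore
    /// let batch = BatchData::new(vec![vec![1, 2, 3], vec![4, 5]]);
    /// assert_eq!(batch.batch_size(), 2);
    /// assert_eq!(batch.max_seq_len(), 3);
    /// // The shorter sequence is right-padded with 0 to the batch max length
    /// assert_eq!(batch.padded_input_ids()[1], vec![4, 5, 0]);
    /// ```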
    pub fn new(input_ids: Vec<Vec<usize>>) -> Self {
        let seq_lengths: Vec<usize> = input_ids.iter().map(|ids| ids.len()).collect();
        let max_len = seq_lengths.iter().max().copied().unwrap_or(0);

        // Pad sequences to max length
        let mut padded_ids = Vec::new();
        for ids in &input_ids {
            let mut padded = ids.clone();
            while padded.len() < max_len {
                padded.push(0); // Pad with 0
            }
            padded_ids.push(padded);
        }

        // Create additive padding mask: 0.0 where both positions are valid,
        // -inf where either position is padding
        let batch_size = input_ids.len();
        let mut mask_data = Vec::with_capacity(batch_size * max_len * max_len);

        for &seq_len in seq_lengths.iter() {
            for j in 0..max_len {
                for k in 0..max_len {
                    let can_attend = j < seq_len && k < seq_len;
                    mask_data.push(if can_attend { 0.0 } else { f64::NEG_INFINITY });
                }
            }
        }

        let attention_mask = Some(DenseTensor::new(mask_data, vec![batch_size, max_len, max_len]));

        Self {
            input_ids: padded_ids,
            attention_mask,
            position_ids: None,
            seq_lengths,
        }
    }

    /// Get batch size
    pub fn batch_size(&self) -> usize {
        self.input_ids.len()
    }

    /// Get maximum sequence length
    pub fn max_seq_len(&self) -> usize {
        self.seq_lengths.iter().max().copied().unwrap_or(0)
    }

    /// Get padded input IDs as 2D vector
    pub fn padded_input_ids(&self) -> &[Vec<usize>] {
        &self.input_ids
    }
}

/// Inference request
#[derive(Debug, Clone)]
pub struct InferenceRequest {
    /// Request ID
    pub id: usize,
    /// Input token IDs
    pub input_ids: Vec<usize>,
    /// Generation configuration
    pub config: GenerationConfig,
    /// Generated tokens so far (includes the prompt)
    pub generated: Vec<usize>,
    /// Whether request is complete
    pub completed: bool,
    /// Priority (lower = higher priority)
    pub priority: usize,
}

impl InferenceRequest {
    /// Create a new inference request
    pub fn new(id: usize, input_ids: Vec<usize>, config: GenerationConfig) -> Self {
        Self {
            id,
            input_ids: input_ids.clone(),
            config,
            generated: input_ids,
            completed: false,
            priority: 0,
        }
    }

    /// Add a generated token and update completion status
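    ///
    /// A small completion-check sketch (assumes `config` is a `GenerationConfig`
    /// with `eos_token_id` set to `Some(2)` and a large `max_length`):
    ///
    /// ```ignore
    /// let mut request = InferenceRequest::new(0, vec![1, 2, 3], config);
    /// request.append_token(7);
    /// assert!(!request.completed);
    /// request.append_token(2); // EOS reached, request marked complete
    /// assert!(request.completed);
    /// ```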
    pub fn append_token(&mut self, token: usize) {
        self.generated.push(token);

        // Check completion
        if self.generated.len() >= self.config.max_length {
            self.completed = true;
        }
        if let Some(eos) = self.config.eos_token_id {
            if token == eos {
                self.completed = true;
            }
        }
    }

    /// Get current sequence length
    pub fn current_len(&self) -> usize {
        self.generated.len()
    }
}

/// Request scheduler for continuous batching
#[derive(Debug)]
pub struct RequestScheduler {
    /// Pending requests
    pending: Vec<InferenceRequest>,
    /// Active requests
    active: Vec<InferenceRequest>,
    /// Completed requests
    completed: Vec<InferenceRequest>,
    /// Next request ID
    next_id: usize,
    /// Maximum batch size
    max_batch_size: usize,
}

impl RequestScheduler {
    /// Create a new scheduler
    ///
    /// # Arguments
    /// * `max_batch_size` - Maximum number of concurrent requests
    pub fn new(max_batch_size: usize) -> Self {
        Self {
            pending: Vec::new(),
            active: Vec::new(),
            completed: Vec::new(),
            next_id: 0,
            max_batch_size,
        }
    }

    /// Add a new request and return its ID
    pub fn add_request(&mut self, input_ids: Vec<usize>, config: GenerationConfig) -> usize {
        let id = self.next_id;
        self.next_id += 1;

        let request = InferenceRequest::new(id, input_ids, config);
        self.pending.push(request);

        id
    }

    /// Schedule requests for the next batch: retire completed requests,
    /// then admit pending requests while capacity allows
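    ///
    /// A minimal driving-loop sketch (`run_one_step` is a hypothetical helper that
    /// produces one token per active request):
    ///
    /// ```ignore
    /// let mut scheduler = RequestScheduler::new(4);
    /// scheduler.add_request(vec![1, 2, 3], GenerationConfig::greedy());
    ///
    /// while scheduler.num_active() > 0 || scheduler.num_pending() > 0 {
    ///     let mut active = scheduler.schedule();
    ///     if active.is_empty() {
    ///         break;
    ///     }
    ///     let tokens = run_one_step(&active); // hypothetical: one token per request
    ///     for (req, token) in active.iter_mut().zip(tokens) {
    ///         req.append_token(token);
    ///     }
    /// }
    /// let finished = scheduler.pop_completed();
    /// ```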
    pub fn schedule(&mut self) -> Vec<&mut InferenceRequest> {
        // Move completed active requests into the completed list
        // (a plain `retain` would drop them instead of retiring them)
        let (done, still_active): (Vec<_>, Vec<_>) = std::mem::take(&mut self.active)
            .into_iter()
            .partition(|req| req.completed);
        self.completed.extend(done);
        self.active = still_active;

        // Move pending to active while there's capacity
        while !self.pending.is_empty() && self.active.len() < self.max_batch_size {
            let request = self.pending.remove(0);
            self.active.push(request);
        }

        // Return mutable references to active requests
        self.active.iter_mut().collect()
    }

    /// Get number of pending requests
    pub fn num_pending(&self) -> usize {
        self.pending.len()
    }

    /// Get number of active requests
    pub fn num_active(&self) -> usize {
        self.active.len()
    }

    /// Get number of completed requests
    pub fn num_completed(&self) -> usize {
        self.completed.len()
    }

    /// Remove and return completed requests
    pub fn pop_completed(&mut self) -> Vec<InferenceRequest> {
        std::mem::take(&mut self.completed)
    }
}

/// Batch inference engine
#[derive(Debug)]
pub struct BatchInference<'a> {
    /// Reference to model
    model: &'a LlamaModel,
    /// KV caches, one per sequence slot in the batch
    kv_caches: Vec<KVCache>,
    /// Current batch size
    batch_size: usize,
}

impl<'a> BatchInference<'a> {
    /// Create a new batch inference engine
    ///
    /// # Arguments
    /// * `model` - Reference to LlamaModel
    /// * `max_batch_size` - Maximum batch size
    /// * `max_seq_len` - Maximum sequence length
    pub fn new(model: &'a LlamaModel, max_batch_size: usize, max_seq_len: usize) -> Self {
        // One KV cache per batch slot, each sized for the full model depth
        let kv_caches = vec![
            KVCache::new(
                model.num_layers(),
                max_seq_len,
                model.hidden_dim(),
                model.config.get_num_key_value_heads(),
            );
            max_batch_size
        ];

        Self {
            model,
            kv_caches,
            batch_size: 0,
        }
    }

    /// Run batch forward pass
    ///
    /// # Arguments
    /// * `batch` - Batch data
    ///
    /// # Returns
    /// Logits for each sequence in the batch [batch_size, seq_len, vocab_size]
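    ///
    /// A minimal call sketch (illustrative token IDs; `engine` is a hypothetical
    /// `BatchInference` built from a loaded model):
    ///
    /// ```ignore
    /// let batch = BatchData::new(vec![vec![1, 2, 3], vec![4, 5]]);
    /// let logits = engine.forward(&batch);
    /// // Expected shape: [2, 3, vocab_size] (batch of 2, padded length 3)
    /// ```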
    pub fn forward(&mut self, batch: &BatchData) -> DenseTensor {
        let batch_size = batch.batch_size();
        self.batch_size = batch_size;

        // Run model forward pass with batched input
        self.model.forward(&batch.input_ids, batch.attention_mask.as_ref())
    }

    /// Run a single generation step for the batch
    ///
    /// # Arguments
    /// * `requests` - Active inference requests
    ///
    /// # Returns
    /// Generated tokens, one per request
    pub fn step(&mut self, requests: &[&mut InferenceRequest]) -> Vec<usize> {
        // Collect the most recent token of each request; each sequence in this
        // step batch therefore has length 1
        let input_ids: Vec<Vec<usize>> = requests
            .iter()
            .map(|req| vec![*req.generated.last().unwrap()])
            .collect();

        let batch = BatchData::new(input_ids);

        // Forward pass
        let logits = self.forward(&batch);

        // Sample tokens
        let mut tokens = Vec::new();
        for (i, req) in requests.iter().enumerate() {
            // With one token per sequence, row `i` holds the logits for the
            // last (and only) position of request `i`
            let token_logits = logits.get_row(i);

            // Apply temperature
            let mut probs = token_logits.clone();
            if req.config.temperature != 1.0 {
                probs = probs.scale(1.0 / req.config.temperature);
            }

            // Softmax
            probs = probs.softmax(-1);

            // Sample or greedy
            let token = if req.config.do_sample {
                self.sample_from_probs(probs.data())
            } else {
                self.argmax(probs.data())
            };

            tokens.push(token);
        }

        tokens
    }

    /// Run continuous batching generation until all requests complete
    ///
    /// # Arguments
    /// * `scheduler` - Request scheduler
    ///
    /// # Returns
    /// Generated sequences for each request
    pub fn generate_continuous(&mut self, scheduler: &mut RequestScheduler) -> Vec<Vec<usize>> {
        // One result slot per request ID issued so far
        let mut results: Vec<Option<Vec<usize>>> = vec![None; scheduler.next_id];

        // Generation loop
        while scheduler.num_active() > 0 || scheduler.num_pending() > 0 {
            // Schedule requests
            let mut active_requests = scheduler.schedule();

            if active_requests.is_empty() {
                break;
            }

            // Generate one token per active request
            let tokens = self.step(&active_requests);

            // Update requests
            for (req, token) in active_requests.iter_mut().zip(tokens) {
                req.append_token(token);

                if req.completed {
                    // Store result
                    results[req.id] = Some(req.generated.clone());
                }
            }
        }

        // Collect results in request-ID order
        results.into_iter().flatten().collect()
    }

    /// Reset KV caches
    pub fn reset(&mut self) {
        for cache in &mut self.kv_caches {
            cache.reset();
        }
    }

    /// Greedy (argmax) selection over a probability vector
    fn argmax(&self, probs: &[f64]) -> usize {
        probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Sample from a probability distribution via inverse-CDF sampling
    fn sample_from_probs(&self, probs: &[f64]) -> usize {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let r: f64 = rng.gen();

        // Walk the cumulative distribution until it exceeds the random draw
        let mut cumulative = 0.0;
        for (i, &prob) in probs.iter().enumerate() {
            cumulative += prob;
            if r < cumulative {
                return i;
            }
        }

        // Fall back to the last index if rounding left the CDF below 1.0
        probs.len() - 1
    }
}

/// Utility functions for batch processing
pub mod utils {
    use super::*;

    /// Pad sequences to the same length, returning the padded sequences and original lengths
    pub fn pad_sequences(sequences: &[Vec<usize>], pad_token: usize) -> (Vec<Vec<usize>>, Vec<usize>) {
        let max_len = sequences.iter().map(|s| s.len()).max().unwrap_or(0);
        let mut padded = Vec::new();
        let mut lengths = Vec::new();

        for seq in sequences {
            lengths.push(seq.len());
            let mut padded_seq = seq.clone();
            while padded_seq.len() < max_len {
                padded_seq.push(pad_token);
            }
            padded.push(padded_seq);
        }

        (padded, lengths)
    }

    /// Create an additive attention mask [batch_size, max_len, max_len] from sequence lengths
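    ///
    /// The mask is additive: 0.0 where both positions are valid, `f64::NEG_INFINITY`
    /// where either position is padding, so it can be added to attention scores
    /// before the softmax. A small sketch (illustrative lengths):
    ///
    /// ```ignore
    /// let mask = create_attention_mask(&[2, 1]);
    /// // Shape [2, 2, 2]; in the second sequence only position 0 is valid,
    /// // so every entry involving position 1 is -inf.
    /// ```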
    pub fn create_attention_mask(lengths: &[usize]) -> DenseTensor {
        let batch_size = lengths.len();
        let max_len = lengths.iter().max().copied().unwrap_or(0);

        let mut data = Vec::with_capacity(batch_size * max_len * max_len);

        for &seq_len in lengths.iter() {
            for j in 0..max_len {
                for k in 0..max_len {
                    let can_attend = j < seq_len && k < seq_len;
                    data.push(if can_attend { 0.0 } else { f64::NEG_INFINITY });
                }
            }
        }

        DenseTensor::new(data, vec![batch_size, max_len, max_len])
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::model::LlamaModel;
    use crate::transformer::layers::{MultiHeadAttention, FeedForward, RMSNorm};
    use crate::transformer::loader::LlamaConfig;
    use crate::tensor::DenseTensor;

    fn create_test_model() -> LlamaModel {
        let config = LlamaConfig::llama_7b();
        let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);

        let hidden_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;

        let w_q = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_k = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_v = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_o = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let self_attn = MultiHeadAttention::standard(w_q, w_k, w_v, w_o, num_heads);

        let gate_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let up_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let down_proj = DenseTensor::ones(vec![config.intermediate_size, hidden_dim]);
        let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);

        let input_layernorm = RMSNorm::default(hidden_dim);
        let post_attention_layernorm = RMSNorm::default(hidden_dim);

        let layer = super::super::model::LlamaDecoderLayer::new(
            self_attn, mlp, input_layernorm, post_attention_layernorm
        );

        let layers = vec![layer; 2];
        let norm = RMSNorm::default(hidden_dim);

        LlamaModel::new(config, embed_tokens, layers, norm, None)
    }

    #[test]
    fn test_batch_data_creation() {
        let input_ids = vec![
            vec![1, 2, 3],
            vec![4, 5],
            vec![6, 7, 8, 9],
        ];

        let batch = BatchData::new(input_ids.clone());

        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.max_seq_len(), 4);
        assert_eq!(batch.seq_lengths, vec![3, 2, 4]);
    }

    #[test]
    fn test_inference_request() {
        let config = GenerationConfig::greedy();
        let mut request = InferenceRequest::new(0, vec![1, 2, 3], config);

        assert!(!request.completed);
        assert_eq!(request.current_len(), 3);

        request.append_token(4);
        assert_eq!(request.current_len(), 4);
    }

    #[test]
    fn test_request_scheduler() {
        let mut scheduler = RequestScheduler::new(2);

        let _id1 = scheduler.add_request(vec![1, 2, 3], GenerationConfig::greedy());
        let _id2 = scheduler.add_request(vec![4, 5], GenerationConfig::greedy());
        let _id3 = scheduler.add_request(vec![6, 7, 8], GenerationConfig::greedy());

        assert_eq!(scheduler.num_pending(), 3);
        assert_eq!(scheduler.num_active(), 0);

        let active = scheduler.schedule();
        assert_eq!(active.len(), 2); // max_batch_size = 2
        assert_eq!(scheduler.num_pending(), 1);
        assert_eq!(scheduler.num_active(), 2);
    }

    #[test]
    fn test_batch_inference_creation() {
        let model = create_test_model();
        let batch_infer = BatchInference::new(&model, 4, 512);

        assert_eq!(batch_infer.kv_caches.len(), 4);
    }

    #[test]
    fn test_pad_sequences() {
        let sequences = vec![
            vec![1, 2],
            vec![3, 4, 5],
            vec![6],
        ];

        let (padded, lengths) = utils::pad_sequences(&sequences, 0);

        assert_eq!(padded, vec![
            vec![1, 2, 0],
            vec![3, 4, 5],
            vec![6, 0, 0],
        ]);
        assert_eq!(lengths, vec![2, 3, 1]);
    }
}