god_graph/transformer/batch/mod.rs
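//! Batched inference utilities for the transformer stack: padded input
//! batches (`BatchData`), an inference-request scheduler, and a simple
//! continuous-batching generation loop over a `LlamaModel`.
//!
//! A minimal usage sketch (model construction is elided and the snippet is
//! illustrative only, hence `ignore`):
//!
//! ```ignore
//! let model: LlamaModel = /* load or build a model */;
//!
//! let mut scheduler = RequestScheduler::new(4);
//! scheduler.add_request(vec![1, 2, 3], GenerationConfig::greedy());
//!
//! let mut engine = BatchInference::new(&model, 4, 512);
//! let outputs = engine.generate_continuous(&mut scheduler);
//! ```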

use crate::tensor::DenseTensor;

use super::model::LlamaModel;
use super::generation::GenerationConfig;
use super::kv_cache::KVCache;

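/// A batch of token sequences padded to a common length, plus the attention
/// mask derived from the original (unpadded) lengths.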
#[derive(Debug, Clone)]
pub struct BatchData {
    /// Token ids, one row per sequence, padded to the batch maximum length.
    pub input_ids: Vec<Vec<usize>>,
    /// Additive mask of shape `[batch, max_len, max_len]`: 0.0 where attention
    /// is allowed, negative infinity at padded positions.
    pub attention_mask: Option<DenseTensor>,
    /// Optional explicit position ids per sequence.
    pub position_ids: Option<Vec<Vec<usize>>>,
    /// Original (unpadded) length of each sequence.
    pub seq_lengths: Vec<usize>,
}

impl BatchData {
    /// Builds a batch from variable-length sequences: pads every sequence to
    /// the longest one with token id 0 and constructs the padding mask.
    pub fn new(input_ids: Vec<Vec<usize>>) -> Self {
        let seq_lengths: Vec<usize> = input_ids.iter().map(|ids| ids.len()).collect();
        let max_len = seq_lengths.iter().max().copied().unwrap_or(0);

        let mut padded_ids = Vec::new();
        for ids in &input_ids {
            let mut padded = ids.clone();
            while padded.len() < max_len {
                padded.push(0);
            }
            padded_ids.push(padded);
        }

        let batch_size = input_ids.len();
        let mut mask_data = Vec::with_capacity(batch_size * max_len * max_len);

        for &seq_len in seq_lengths.iter() {
            for j in 0..max_len {
                for k in 0..max_len {
                    // Only positions within the real (unpadded) sequence may attend.
                    let can_attend = j < seq_len && k < seq_len;
                    mask_data.push(if can_attend { 0.0 } else { f64::NEG_INFINITY });
                }
            }
        }

        let attention_mask = Some(DenseTensor::new(mask_data, vec![batch_size, max_len, max_len]));

        Self {
            input_ids: padded_ids,
            attention_mask,
            position_ids: None,
            seq_lengths,
        }
    }

    pub fn batch_size(&self) -> usize {
        self.input_ids.len()
    }

    pub fn max_seq_len(&self) -> usize {
        self.seq_lengths.iter().max().copied().unwrap_or(0)
    }

    pub fn padded_input_ids(&self) -> &[Vec<usize>] {
        &self.input_ids
    }
}

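/// A single generation request: the prompt, its sampling configuration, and
/// the tokens produced so far.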
#[derive(Debug, Clone)]
pub struct InferenceRequest {
    /// Scheduler-assigned id; also used as the index into result buffers.
    pub id: usize,
    /// The original prompt tokens.
    pub input_ids: Vec<usize>,
    /// Sampling and stopping configuration for this request.
    pub config: GenerationConfig,
    /// Prompt plus every token generated so far.
    pub generated: Vec<usize>,
    /// Set once the request reaches `max_length` or emits the EOS token.
    pub completed: bool,
    /// Scheduling priority (not yet consulted by `RequestScheduler`).
    pub priority: usize,
}

impl InferenceRequest {
    pub fn new(id: usize, input_ids: Vec<usize>, config: GenerationConfig) -> Self {
        Self {
            id,
            input_ids: input_ids.clone(),
            config,
            generated: input_ids,
            completed: false,
            priority: 0,
        }
    }

    /// Appends a generated token and marks the request as completed when the
    /// configured `max_length` is reached or the EOS token is produced.
    pub fn append_token(&mut self, token: usize) {
        self.generated.push(token);

        if self.generated.len() >= self.config.max_length {
            self.completed = true;
        }
        if let Some(eos) = self.config.eos_token_id {
            if token == eos {
                self.completed = true;
            }
        }
    }

    pub fn current_len(&self) -> usize {
        self.generated.len()
    }
}

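/// FIFO scheduler that admits pending requests into the active batch (up to
/// `max_batch_size`) and collects finished requests for later retrieval.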
#[derive(Debug)]
pub struct RequestScheduler {
    pending: Vec<InferenceRequest>,
    active: Vec<InferenceRequest>,
    completed: Vec<InferenceRequest>,
    next_id: usize,
    max_batch_size: usize,
}

impl RequestScheduler {
    pub fn new(max_batch_size: usize) -> Self {
        Self {
            pending: Vec::new(),
            active: Vec::new(),
            completed: Vec::new(),
            next_id: 0,
            max_batch_size,
        }
    }

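    /// Queues a new request and returns the id assigned to it.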
    pub fn add_request(&mut self, input_ids: Vec<usize>, config: GenerationConfig) -> usize {
        let id = self.next_id;
        self.next_id += 1;

        let request = InferenceRequest::new(id, input_ids, config);
        self.pending.push(request);

        id
    }

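    /// Retires finished requests into the completed list, promotes pending
    /// requests into free batch slots, and returns the active set.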
    pub fn schedule(&mut self) -> Vec<&mut InferenceRequest> {
        // Move finished requests into `completed` rather than dropping them,
        // so they remain available through `pop_completed`.
        let (finished, still_active): (Vec<_>, Vec<_>) = std::mem::take(&mut self.active)
            .into_iter()
            .partition(|req| req.completed);
        self.completed.extend(finished);
        self.active = still_active;

        // Admit pending requests into free batch slots in FIFO order.
        while !self.pending.is_empty() && self.active.len() < self.max_batch_size {
            let request = self.pending.remove(0);
            self.active.push(request);
        }

        self.active.iter_mut().collect()
    }

    pub fn num_pending(&self) -> usize {
        self.pending.len()
    }

    pub fn num_active(&self) -> usize {
        self.active.len()
    }

    pub fn num_completed(&self) -> usize {
        self.completed.len()
    }

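    /// Returns all completed requests and clears the internal list.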
    pub fn pop_completed(&mut self) -> Vec<InferenceRequest> {
        std::mem::take(&mut self.completed)
    }
}

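/// Drives batched decoding for a set of requests against a shared `LlamaModel`,
/// holding one `KVCache` per batch slot.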
#[derive(Debug)]
pub struct BatchInference<'a> {
    model: &'a LlamaModel,
    kv_caches: Vec<KVCache>,
    batch_size: usize,
}

impl<'a> BatchInference<'a> {
    pub fn new(model: &'a LlamaModel, max_batch_size: usize, max_seq_len: usize) -> Self {
        let kv_caches = vec![
            KVCache::new(
                model.num_layers(),
                max_seq_len,
                model.hidden_dim(),
                model.config.get_num_key_value_heads(),
            );
            max_batch_size
        ];

        Self {
            model,
            kv_caches,
            batch_size: 0,
        }
    }

    pub fn forward(&mut self, batch: &BatchData) -> DenseTensor {
        let batch_size = batch.batch_size();
        self.batch_size = batch_size;

        self.model.forward(&batch.input_ids, batch.attention_mask.as_ref())
    }

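    /// Runs one decoding step for the given requests and returns the next
    /// token selected for each of them.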
    pub fn step(&mut self, requests: &[&mut InferenceRequest]) -> Vec<usize> {
        // forward() does not yet consult the per-slot KV caches, so each step
        // re-runs the full generated sequence for every request.
        let input_ids: Vec<Vec<usize>> = requests
            .iter()
            .map(|req| req.generated.clone())
            .collect();

        let batch = BatchData::new(input_ids);
        let max_len = batch.max_seq_len();

        let logits = self.forward(&batch);

        let mut tokens = Vec::new();
        for (i, req) in requests.iter().enumerate() {
            // Logits are assumed row-major with shape [batch * max_len, vocab];
            // pick the row of this request's last real (unpadded) token.
            let seq_len = req.current_len();
            let token_logits = logits.get_row(i * max_len + seq_len - 1);

            let mut probs = token_logits.clone();
            if req.config.temperature != 1.0 {
                probs = probs.scale(1.0 / req.config.temperature);
            }

            probs = probs.softmax(-1);

            let token = if req.config.do_sample {
                self.sample_from_probs(probs.data())
            } else {
                self.argmax(probs.data())
            };

            tokens.push(token);
        }

        tokens
    }

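    /// Continuous-batching loop: keeps scheduling and stepping until every
    /// request has completed, then returns the generated sequences (prompt
    /// included) ordered by request id.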
    pub fn generate_continuous(&mut self, scheduler: &mut RequestScheduler) -> Vec<Vec<usize>> {
        let mut results: Vec<Option<Vec<usize>>> = Vec::new();

        for _ in 0..scheduler.next_id {
            results.push(None);
        }

        while scheduler.num_active() > 0 || scheduler.num_pending() > 0 {
            let mut active_requests = scheduler.schedule();

            if active_requests.is_empty() {
                break;
            }

            let tokens = self.step(&active_requests);

            for (req, token) in active_requests.iter_mut().zip(tokens) {
                req.append_token(token);

                if req.completed {
                    results[req.id] = Some(req.generated.clone());
                }
            }
        }

        results.into_iter().flatten().collect()
    }

    pub fn reset(&mut self) {
        for cache in &mut self.kv_caches {
            cache.reset();
        }
    }

    /// Index of the highest-probability entry (greedy decoding).
    fn argmax(&self, probs: &[f64]) -> usize {
        probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Samples an index from a probability distribution by inverse-CDF lookup.
    fn sample_from_probs(&self, probs: &[f64]) -> usize {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let r: f64 = rng.gen();

        let mut cumulative = 0.0;
        for (i, &prob) in probs.iter().enumerate() {
            cumulative += prob;
            if r < cumulative {
                return i;
            }
        }

        probs.len() - 1
    }
}

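/// Standalone helpers for padding sequences and building attention masks.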
pub mod utils {
    use super::*;

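    /// Pads every sequence to the length of the longest one with `pad_token`,
    /// returning the padded sequences and their original lengths.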
    pub fn pad_sequences(sequences: &[Vec<usize>], pad_token: usize) -> (Vec<Vec<usize>>, Vec<usize>) {
        let max_len = sequences.iter().map(|s| s.len()).max().unwrap_or(0);
        let mut padded = Vec::new();
        let mut lengths = Vec::new();

        for seq in sequences {
            lengths.push(seq.len());
            let mut padded_seq = seq.clone();
            while padded_seq.len() < max_len {
                padded_seq.push(pad_token);
            }
            padded.push(padded_seq);
        }

        (padded, lengths)
    }

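    /// Builds an additive `[batch, max_len, max_len]` padding mask: 0.0 where
    /// both positions are real tokens, negative infinity otherwise.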
    pub fn create_attention_mask(lengths: &[usize]) -> DenseTensor {
        let batch_size = lengths.len();
        let max_len = lengths.iter().max().copied().unwrap_or(0);

        let mut data = Vec::with_capacity(batch_size * max_len * max_len);

        for &seq_len in lengths.iter() {
            for j in 0..max_len {
                for k in 0..max_len {
                    // Only positions within the real (unpadded) sequence may attend.
                    let can_attend = j < seq_len && k < seq_len;
                    data.push(if can_attend { 0.0 } else { f64::NEG_INFINITY });
                }
            }
        }

        DenseTensor::new(data, vec![batch_size, max_len, max_len])
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::model::LlamaModel;
    use crate::transformer::layers::{MultiHeadAttention, FeedForward, RMSNorm};
    use crate::transformer::loader::LlamaConfig;
    use crate::tensor::DenseTensor;

    /// Builds a two-layer `LlamaModel` with all-ones weights from the
    /// `llama_7b` config.
    fn create_test_model() -> LlamaModel {
        let config = LlamaConfig::llama_7b();
        let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);

        let hidden_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;

        let w_q = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_k = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_v = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_o = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let self_attn = MultiHeadAttention::standard(w_q, w_k, w_v, w_o, num_heads);

        let gate_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let up_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let down_proj = DenseTensor::ones(vec![config.intermediate_size, hidden_dim]);
        let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);

        let input_layernorm = RMSNorm::default(hidden_dim);
        let post_attention_layernorm = RMSNorm::default(hidden_dim);

        let layer = super::super::model::LlamaDecoderLayer::new(
            self_attn, mlp, input_layernorm, post_attention_layernorm,
        );

        let layers = vec![layer; 2];
        let norm = RMSNorm::default(hidden_dim);

        LlamaModel::new(config, embed_tokens, layers, norm, None)
    }

    #[test]
    fn test_batch_data_creation() {
        let input_ids = vec![
            vec![1, 2, 3],
            vec![4, 5],
            vec![6, 7, 8, 9],
        ];

        let batch = BatchData::new(input_ids.clone());

        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.max_seq_len(), 4);
        assert_eq!(batch.seq_lengths, vec![3, 2, 4]);
    }

    #[test]
    fn test_inference_request() {
        let config = GenerationConfig::greedy();
        let mut request = InferenceRequest::new(0, vec![1, 2, 3], config);

        assert!(!request.completed);
        assert_eq!(request.current_len(), 3);

        request.append_token(4);
        assert_eq!(request.current_len(), 4);
    }

    #[test]
    fn test_request_scheduler() {
        let mut scheduler = RequestScheduler::new(2);

        let _id1 = scheduler.add_request(vec![1, 2, 3], GenerationConfig::greedy());
        let _id2 = scheduler.add_request(vec![4, 5], GenerationConfig::greedy());
        let _id3 = scheduler.add_request(vec![6, 7, 8], GenerationConfig::greedy());

        assert_eq!(scheduler.num_pending(), 3);
        assert_eq!(scheduler.num_active(), 0);

        let active = scheduler.schedule();
        assert_eq!(active.len(), 2);
        assert_eq!(scheduler.num_pending(), 1);
        assert_eq!(scheduler.num_active(), 2);
    }

    #[test]
    fn test_batch_inference_creation() {
        let model = create_test_model();
        let batch_infer = BatchInference::new(&model, 4, 512);

        assert_eq!(batch_infer.kv_caches.len(), 4);
    }

    #[test]
    fn test_pad_sequences() {
        let sequences = vec![
            vec![1, 2],
            vec![3, 4, 5],
            vec![6],
        ];

        let (padded, lengths) = utils::pad_sequences(&sequences, 0);

        assert_eq!(padded, vec![
            vec![1, 2, 0],
            vec![3, 4, 5],
            vec![6, 0, 0],
        ]);
        assert_eq!(lengths, vec![2, 3, 1]);
    }
}