use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
/// Identifier assigned to each generation request by the scheduler
/// (monotonically increasing, starting at 1).
pub type RequestId = u64;
/// A single text-generation request submitted to the [`BatchScheduler`].
///
/// Results are streamed back to the submitting client as
/// [`GenerationEvent`]s over `token_sender`.
#[derive(Debug)]
pub struct GenerationRequest {
    /// Request id. Overwritten by `BatchScheduler::add_request`, so callers
    /// may pass any placeholder (e.g. 0).
    pub id: RequestId,
    /// Prompt tokens the generation is conditioned on.
    pub input_tokens: Vec<u32>,
    /// Maximum number of NEW tokens to generate (prompt excluded).
    pub max_tokens: usize,
    /// Sampling temperature — interpreted by the model runner, not here.
    pub temperature: f32,
    /// Nucleus-sampling threshold — interpreted by the model runner, not here.
    pub top_p: f32,
    /// Token sequences that end generation when one appears as a suffix of
    /// the accumulated token stream (prompt + generated tokens).
    pub stop_sequences: Vec<Vec<u32>>,
    /// Channel on which tokens, completion, and errors are streamed back.
    pub token_sender: mpsc::Sender<GenerationEvent>,
}
/// Events streamed to the client over a request's channel.
#[derive(Debug, Clone)]
pub enum GenerationEvent {
    /// A newly generated token.
    Token(u32),
    /// Generation completed; the scheduler sends no further events
    /// for this request.
    Finished { reason: FinishReason },
    /// Generation failed; the scheduler sends no further events
    /// for this request.
    Error(String),
}
/// Why a generation finished.
#[derive(Debug, Clone)]
pub enum FinishReason {
    /// The per-request `max_tokens` budget was exhausted.
    MaxTokens,
    /// The model emitted the end-of-sequence token.
    EndOfSequence,
    /// One of the request's stop sequences was produced.
    StopSequence,
    /// The request was cancelled via `cancel_request`.
    Cancelled,
}
/// Internal per-request generation state, tracked while the request
/// occupies a slot in the active batch.
#[derive(Debug)]
struct Sequence {
    /// Id of the originating request (key in `BatchScheduler::sequences`).
    request_id: RequestId,
    /// Full token stream: prompt tokens followed by generated tokens.
    tokens: Vec<u32>,
    /// Number of prompt tokens at the front of `tokens`.
    prompt_len: usize,
    /// Count of tokens generated so far (excludes the prompt).
    generated: usize,
    /// Generation budget copied from the request.
    max_tokens: usize,
    /// Sampling temperature copied from the request.
    temperature: f32,
    /// Nucleus-sampling threshold copied from the request.
    top_p: f32,
    /// Stop sequences copied from the request.
    stop_sequences: Vec<Vec<u32>>,
    /// Client-facing event channel copied from the request.
    token_sender: mpsc::Sender<GenerationEvent>,
    /// KV-cache position. Initialized to 0 on admission and never advanced
    /// in this module — presumably maintained by the model runner.
    /// TODO(review): confirm ownership of this field.
    cache_position: usize,
}
/// Tuning knobs for the batch scheduler.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum number of sequences decoded concurrently; requests beyond
    /// this wait in a FIFO queue.
    pub max_batch_size: usize,
    /// Maximum total sequence length (prompt + generated tokens).
    pub max_seq_len: usize,
    /// How long to wait to accumulate a batch, in milliseconds —
    /// NOTE(review): not referenced in this module; presumably consumed by
    /// the driving loop. Confirm against callers.
    pub batch_timeout_ms: u64,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
max_batch_size: 8,
max_seq_len: 4096,
batch_timeout_ms: 10,
}
}
}
/// Continuous-batching scheduler: keeps up to `config.max_batch_size`
/// sequences active at once and queues any overflow in FIFO order.
pub struct BatchScheduler {
    /// Batching limits and timing configuration.
    config: BatchConfig,
    /// Sequences currently being decoded, keyed by request id.
    sequences: HashMap<RequestId, Sequence>,
    /// Next request id to hand out; starts at 1 and only increases.
    next_id: RequestId,
    /// Requests waiting for a batch slot, in arrival (FIFO) order.
    pending: Vec<GenerationRequest>,
}
impl BatchScheduler {
    /// Creates an empty scheduler with the given configuration.
    pub fn new(config: BatchConfig) -> Self {
        Self {
            config,
            sequences: HashMap::new(),
            next_id: 1,
            pending: Vec::new(),
        }
    }

    /// Admits a request and returns its freshly assigned id (any
    /// caller-provided `request.id` is overwritten). The request enters the
    /// active batch immediately if a slot is free; otherwise it waits in
    /// FIFO order until one opens.
    pub fn add_request(&mut self, mut request: GenerationRequest) -> RequestId {
        let id = self.next_id;
        self.next_id += 1;
        request.id = id;
        if self.sequences.len() < self.config.max_batch_size {
            self.add_to_batch(request);
        } else {
            self.pending.push(request);
        }
        id
    }

    /// Converts an admitted request into an active `Sequence`.
    fn add_to_batch(&mut self, request: GenerationRequest) {
        // Take the length first so the prompt vec can be moved, not cloned.
        let prompt_len = request.input_tokens.len();
        let seq = Sequence {
            request_id: request.id,
            tokens: request.input_tokens,
            prompt_len,
            generated: 0,
            max_tokens: request.max_tokens,
            temperature: request.temperature,
            top_p: request.top_p,
            stop_sequences: request.stop_sequences,
            token_sender: request.token_sender,
            cache_position: 0,
        };
        self.sequences.insert(request.id, seq);
    }

    /// Cancels a request, whether it is active or still pending. The client
    /// is notified with `Finished { Cancelled }` in both cases (best-effort:
    /// the event is dropped if the channel is full or closed). Frees a batch
    /// slot and promotes a pending request when the cancelled one was active.
    pub fn cancel_request(&mut self, id: RequestId) {
        if let Some(seq) = self.sequences.remove(&id) {
            let _ = seq.token_sender.try_send(GenerationEvent::Finished {
                reason: FinishReason::Cancelled,
            });
        }
        // Also notify a matching pending request, so its client does not
        // wait forever on a channel that will never produce events.
        // Ids are unique, so at most one entry can match.
        if let Some(pos) = self.pending.iter().position(|r| r.id == id) {
            let request = self.pending.remove(pos);
            let _ = request.token_sender.try_send(GenerationEvent::Finished {
                reason: FinishReason::Cancelled,
            });
        }
        self.promote_pending();
    }

    /// Moves pending requests into the batch while slots are free (FIFO).
    fn promote_pending(&mut self) {
        let free = self
            .config
            .max_batch_size
            .saturating_sub(self.sequences.len());
        let take = free.min(self.pending.len());
        // Single drain instead of repeated `remove(0)`, which shifts the
        // whole queue on every promotion (O(n^2) with many waiters).
        let promoted: Vec<GenerationRequest> = self.pending.drain(..take).collect();
        for request in promoted {
            self.add_to_batch(request);
        }
    }

    /// Ids of all currently active sequences (iteration order is the
    /// HashMap's — unspecified).
    pub fn get_batch(&self) -> Vec<RequestId> {
        self.sequences.keys().copied().collect()
    }

    /// Full token stream (prompt + generated) of an active sequence.
    pub fn get_sequence_tokens(&self, id: RequestId) -> Option<&[u32]> {
        self.sequences.get(&id).map(|s| s.tokens.as_slice())
    }

    /// Sampling parameters `(temperature, top_p)` of an active sequence.
    pub fn get_sequence_info(&self, id: RequestId) -> Option<(f32, f32)> {
        self.sequences.get(&id).map(|s| (s.temperature, s.top_p))
    }

    /// Records a model-sampled token for sequence `id`, streams it to the
    /// client, and evaluates every stop condition. Returns `true` while the
    /// sequence should continue generating; `false` once it has finished
    /// (or was not found). On finish, the sequence is removed, a `Finished`
    /// event is sent (best-effort), and a pending request is promoted into
    /// the freed slot.
    pub fn process_token(&mut self, id: RequestId, token: u32, eos_token: u32) -> bool {
        // Copy out before mutably borrowing the sequence map.
        let max_seq_len = self.config.max_seq_len;
        let finish_reason = match self.sequences.get_mut(&id) {
            None => return false,
            Some(seq) => {
                seq.tokens.push(token);
                seq.generated += 1;
                // Best-effort streaming: a full or closed channel drops the token.
                let _ = seq.token_sender.try_send(GenerationEvent::Token(token));
                if token == eos_token {
                    Some(FinishReason::EndOfSequence)
                } else if seq.generated >= seq.max_tokens || seq.tokens.len() >= max_seq_len {
                    // Enforce the configured context limit (`max_seq_len`) in
                    // addition to the per-request budget; previously the
                    // config field was never checked and sequences could grow
                    // without bound. Both limits report `MaxTokens`.
                    Some(FinishReason::MaxTokens)
                } else if Self::check_stop_sequence_static(&seq.tokens, &seq.stop_sequences) {
                    Some(FinishReason::StopSequence)
                } else {
                    None
                }
            }
        };
        match finish_reason {
            None => true,
            Some(reason) => {
                if let Some(seq) = self.sequences.remove(&id) {
                    let _ = seq
                        .token_sender
                        .try_send(GenerationEvent::Finished { reason });
                }
                self.promote_pending();
                false
            }
        }
    }

    /// True when the token stream ends with any of the stop sequences.
    fn check_stop_sequence_static(tokens: &[u32], stop_sequences: &[Vec<u32>]) -> bool {
        stop_sequences.iter().any(|stop| tokens.ends_with(stop))
    }

    /// Removes sequence `id` (if active) and reports the error to its
    /// client (best-effort), then backfills the freed slot.
    pub fn report_error(&mut self, id: RequestId, error: String) {
        if let Some(seq) = self.sequences.remove(&id) {
            let _ = seq.token_sender.try_send(GenerationEvent::Error(error));
        }
        self.promote_pending();
    }

    /// Number of sequences currently in the active batch.
    pub fn active_count(&self) -> usize {
        self.sequences.len()
    }

    /// Number of requests waiting for a batch slot.
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Whether the batch has active sequences. (With `max_batch_size > 0`,
    /// `pending` can only be non-empty while the batch is full, so checking
    /// the active set alone is sufficient.)
    pub fn has_work(&self) -> bool {
        !self.sequences.is_empty()
    }
}
/// A `BatchScheduler` shareable across async tasks. Uses `tokio::sync::Mutex`
/// so the guard can be held across `.await` points.
pub type SharedBatchScheduler = Arc<Mutex<BatchScheduler>>;
/// Builds a scheduler wrapped for shared, task-safe access.
pub fn new_batch_scheduler(config: BatchConfig) -> SharedBatchScheduler {
    Arc::new(Mutex::new(BatchScheduler::new(config)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a throwaway request; the scheduler reassigns the id on admit,
    /// so the placeholder 0 is never observed.
    fn dummy_request(token_sender: mpsc::Sender<GenerationEvent>) -> GenerationRequest {
        GenerationRequest {
            id: 0,
            input_tokens: vec![1, 2, 3],
            max_tokens: 10,
            temperature: 0.8,
            top_p: 0.9,
            stop_sequences: vec![],
            token_sender,
        }
    }

    #[tokio::test]
    async fn test_batch_scheduler_basic() {
        // A fresh scheduler starts with nothing active and nothing queued.
        let scheduler = BatchScheduler::new(BatchConfig::default());
        assert_eq!(scheduler.active_count(), 0);
        assert_eq!(scheduler.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_batch_scheduler_add_request() {
        // Ids start at 1 and the first request goes straight into the batch.
        let config = BatchConfig {
            max_batch_size: 2,
            ..Default::default()
        };
        let mut scheduler = BatchScheduler::new(config);
        let (tx, _rx) = mpsc::channel(100);
        let id = scheduler.add_request(dummy_request(tx));
        assert_eq!(id, 1);
        assert_eq!(scheduler.active_count(), 1);
    }

    #[tokio::test]
    async fn test_batch_scheduler_overflow() {
        // With a single slot, extra requests queue up in `pending`.
        let mut scheduler = BatchScheduler::new(BatchConfig {
            max_batch_size: 1,
            ..Default::default()
        });
        for _ in 0..3 {
            let (tx, _rx) = mpsc::channel(100);
            scheduler.add_request(dummy_request(tx));
        }
        assert_eq!(scheduler.active_count(), 1);
        assert_eq!(scheduler.pending_count(), 2);
    }
}