//! Request batching for GPU inference: collects incoming generation
//! requests into batches and schedules them across KV-cache slots.
#![allow(clippy::many_single_char_names)]
#![allow(clippy::similar_names)]

#[cfg(feature = "gpu")]
use super::runtime::OwnedQuantizedKVCache;
/// Summary of GPU cache readiness and model dimensions, used to recommend
/// a batch size for generation.
#[derive(Debug, Clone)]
pub struct BatchGenerationStats {
    pub gpu_cache_ready: bool,
    pub cache_memory_gb: f64,
    pub num_layers: usize,
    pub hidden_dim: usize,
    pub intermediate_dim: usize,
    pub recommended_batch_size: usize,
    pub max_batch_size: usize,
}
/// A generation request waiting to be grouped into a batch.
#[cfg(feature = "gpu")]
#[derive(Debug, Clone)]
pub struct PendingRequest {
    pub id: u64,
    pub prompt: Vec<u32>,
    pub max_tokens: usize,
    pub temperature: f32,
    pub top_k: usize,
    pub submitted_at: std::time::Instant,
}
#[cfg(feature = "gpu")]
impl PendingRequest {
    pub fn new(
        id: u64,
        prompt: Vec<u32>,
        max_tokens: usize,
        temperature: f32,
        top_k: usize,
    ) -> Self {
        Self {
            id,
            prompt,
            max_tokens,
            temperature,
            top_k,
            submitted_at: std::time::Instant::now(),
        }
    }

    /// Time elapsed since the request was submitted.
    pub fn wait_time(&self) -> std::time::Duration {
        self.submitted_at.elapsed()
    }
}
/// A group of pending requests collected for batched execution.
#[cfg(feature = "gpu")]
#[derive(Debug)]
pub struct RequestBatch {
    pub requests: Vec<PendingRequest>,
    pub formed_at: std::time::Instant,
}
#[cfg(feature = "gpu")]
impl RequestBatch {
    pub fn new(requests: Vec<PendingRequest>) -> Self {
        Self {
            requests,
            formed_at: std::time::Instant::now(),
        }
    }

    pub fn size(&self) -> usize {
        self.requests.len()
    }

    /// Prompt token ids for every request, in submission order.
    pub fn prompts(&self) -> Vec<Vec<u32>> {
        self.requests.iter().map(|r| r.prompt.clone()).collect()
    }

    /// Mean queue wait across the batch; zero for an empty batch.
    pub fn avg_wait_time(&self) -> std::time::Duration {
        if self.requests.is_empty() {
            return std::time::Duration::ZERO;
        }
        let total: std::time::Duration =
            self.requests.iter().map(PendingRequest::wait_time).sum();
        total / self.requests.len() as u32
    }
}
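
// Hedged usage sketch: `avg_wait_time` is the sum of per-request waits divided
// by the request count, with an explicit zero for the empty batch. These tests
// check the empty-batch edge case and prompt extraction; all values are
// illustrative, not from any real workload.
#[cfg(all(test, feature = "gpu"))]
mod request_batch_example {
    use super::{PendingRequest, RequestBatch};

    #[test]
    fn empty_batch_has_zero_avg_wait() {
        let batch = RequestBatch::new(Vec::new());
        assert_eq!(batch.avg_wait_time(), std::time::Duration::ZERO);
    }

    #[test]
    fn prompts_are_extracted_in_submission_order() {
        let batch = RequestBatch::new(vec![
            PendingRequest::new(0, vec![10, 11], 4, 1.0, 1),
            PendingRequest::new(1, vec![20], 4, 1.0, 1),
        ]);
        assert_eq!(batch.prompts(), vec![vec![10_u32, 11], vec![20]]);
    }
}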
/// Thread-safe collector that accumulates requests until either a size
/// threshold or a timeout makes a batch ready.
#[cfg(feature = "gpu")]
pub struct BatchRequestCollector {
    pending: std::sync::Mutex<Vec<PendingRequest>>,
    next_id: std::sync::atomic::AtomicU64,
    pub batch_threshold: usize,
    pub timeout_ms: u64,
    pub max_batch_size: usize,
}
#[cfg(feature = "gpu")]
impl BatchRequestCollector {
    pub fn new() -> Self {
        Self {
            pending: std::sync::Mutex::new(Vec::new()),
            next_id: std::sync::atomic::AtomicU64::new(0),
            batch_threshold: 32,
            timeout_ms: 50,
            max_batch_size: 64,
        }
    }

    pub fn with_thresholds(batch_threshold: usize, timeout_ms: u64, max_batch_size: usize) -> Self {
        Self {
            pending: std::sync::Mutex::new(Vec::new()),
            next_id: std::sync::atomic::AtomicU64::new(0),
            batch_threshold,
            timeout_ms,
            max_batch_size,
        }
    }

    /// Queues a request and returns its id, assigned from a monotonically
    /// increasing counter.
    pub fn submit(
        &self,
        prompt: Vec<u32>,
        max_tokens: usize,
        temperature: f32,
        top_k: usize,
    ) -> u64 {
        let id = self
            .next_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let request = PendingRequest::new(id, prompt, max_tokens, temperature, top_k);
        let mut pending = self.pending.lock().expect("Mutex poisoned");
        pending.push(request);
        id
    }

    /// A batch is ready once enough requests are pending or the oldest
    /// request has waited past the timeout.
    pub fn is_batch_ready(&self) -> bool {
        let pending = self.pending.lock().expect("Mutex poisoned");
        if pending.is_empty() {
            return false;
        }
        if pending.len() >= self.batch_threshold {
            return true;
        }
        if let Some(oldest) = pending.first() {
            let wait_ms = oldest.wait_time().as_millis() as u64;
            if wait_ms >= self.timeout_ms {
                return true;
            }
        }
        false
    }

    /// Drains up to `max_batch_size` requests into a batch, or returns
    /// `None` if no batch is ready yet.
    pub fn collect_batch(&self) -> Option<RequestBatch> {
        let mut pending = self.pending.lock().expect("Mutex poisoned");
        if pending.is_empty() {
            return None;
        }
        let ready = pending.len() >= self.batch_threshold
            || pending
                .first()
                .is_some_and(|r| r.wait_time().as_millis() as u64 >= self.timeout_ms);
        if !ready {
            return None;
        }
        let batch_size = pending.len().min(self.max_batch_size);
        let requests: Vec<PendingRequest> = pending.drain(..batch_size).collect();
        Some(RequestBatch::new(requests))
    }

    /// Drains all pending requests regardless of readiness.
    pub fn flush(&self) -> Option<RequestBatch> {
        let mut pending = self.pending.lock().expect("Mutex poisoned");
        if pending.is_empty() {
            return None;
        }
        let requests: Vec<PendingRequest> = pending.drain(..).collect();
        Some(RequestBatch::new(requests))
    }

    pub fn pending_count(&self) -> usize {
        self.pending.lock().expect("Mutex poisoned").len()
    }

    pub fn total_submitted(&self) -> u64 {
        self.next_id.load(std::sync::atomic::Ordering::Relaxed)
    }
}
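
// Hedged usage sketch: exercises the submit -> is_batch_ready -> collect_batch
// flow of `BatchRequestCollector`. The token ids, sampling parameters, and
// thresholds below are illustrative only.
#[cfg(all(test, feature = "gpu"))]
mod collector_flow_example {
    use super::BatchRequestCollector;

    #[test]
    fn batch_forms_once_threshold_is_reached() {
        // Threshold of 2 so the batch forms immediately; a 1s timeout keeps
        // the timeout path out of this example; max batch size 4 is arbitrary.
        let collector = BatchRequestCollector::with_thresholds(2, 1_000, 4);
        collector.submit(vec![1, 2, 3], 16, 0.8, 40);
        collector.submit(vec![4, 5], 16, 0.8, 40);
        assert!(collector.is_batch_ready());
        let batch = collector.collect_batch().expect("batch should be ready");
        assert_eq!(batch.size(), 2);
        assert_eq!(collector.pending_count(), 0);
        assert_eq!(collector.total_submitted(), 2);
    }
}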
#[cfg(feature = "gpu")]
impl Default for BatchRequestCollector {
    fn default() -> Self {
        Self::new()
    }
}
/// Tunable knobs for batch formation.
#[cfg(feature = "gpu")]
#[derive(Debug, Clone)]
pub struct BatchingConfig {
    pub batch_threshold: usize,
    pub timeout_ms: u64,
    pub max_batch_size: usize,
    pub prefer_throughput: bool,
}

#[cfg(feature = "gpu")]
impl Default for BatchingConfig {
    fn default() -> Self {
        Self {
            batch_threshold: 32,
            timeout_ms: 50,
            max_batch_size: 64,
            prefer_throughput: true,
        }
    }
}
#[cfg(feature = "gpu")]
impl BatchingConfig {
    /// Smaller batches and a short timeout to minimize per-request latency.
    pub fn latency_optimized() -> Self {
        Self {
            batch_threshold: 8,
            timeout_ms: 10,
            max_batch_size: 32,
            prefer_throughput: false,
        }
    }

    /// Larger batches and a longer timeout to maximize aggregate throughput.
    pub fn throughput_optimized() -> Self {
        Self {
            batch_threshold: 32,
            timeout_ms: 100,
            max_batch_size: 64,
            prefer_throughput: true,
        }
    }
}
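
// Hedged sketch: wiring a `BatchingConfig` preset into a collector. The
// `collector_from_config` helper is hypothetical glue, not part of this
// module's API.
#[cfg(feature = "gpu")]
#[allow(dead_code)]
fn collector_from_config(config: &BatchingConfig) -> BatchRequestCollector {
    // Only the sizing/timing fields map onto the collector;
    // `prefer_throughput` is advisory and has no direct counterpart here.
    BatchRequestCollector::with_thresholds(
        config.batch_threshold,
        config.timeout_ms,
        config.max_batch_size,
    )
}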
/// State of one scheduler slot: empty, actively generating, or holding a
/// finished request's output.
#[cfg(feature = "gpu")]
#[derive(Debug, Clone)]
pub enum SlotState {
    Empty,
    Active {
        request_id: u64,
        prompt_tokens: Vec<u32>,
        generated_tokens: Vec<u32>,
        max_tokens: usize,
        temperature: f32,
        top_k: usize,
    },
    Completed {
        request_id: u64,
        generated_tokens: Vec<u32>,
    },
}
#[cfg(feature = "gpu")]
impl SlotState {
    pub fn is_empty(&self) -> bool {
        matches!(self, Self::Empty)
    }

    pub fn is_active(&self) -> bool {
        matches!(self, Self::Active { .. })
    }

    pub fn is_completed(&self) -> bool {
        matches!(self, Self::Completed { .. })
    }

    /// The owning request's id, if the slot is occupied.
    pub fn request_id(&self) -> Option<u64> {
        match self {
            Self::Empty => None,
            Self::Active { request_id, .. } | Self::Completed { request_id, .. } => {
                Some(*request_id)
            },
        }
    }
}
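
// Hedged sketch: walks one slot through the Empty -> Active -> Completed
// lifecycle that `SlotState` models. The request id and tokens are made up.
#[cfg(all(test, feature = "gpu"))]
mod slot_state_example {
    use super::SlotState;

    #[test]
    fn slot_lifecycle_transitions() {
        let mut slot = SlotState::Empty;
        assert!(slot.is_empty());
        // A request is scheduled into the slot.
        slot = SlotState::Active {
            request_id: 7,
            prompt_tokens: vec![1, 2, 3],
            generated_tokens: Vec::new(),
            max_tokens: 8,
            temperature: 1.0,
            top_k: 50,
        };
        assert!(slot.is_active());
        assert_eq!(slot.request_id(), Some(7));
        // Generation finishes; the slot keeps only the output tokens.
        slot = SlotState::Completed {
            request_id: 7,
            generated_tokens: vec![42, 43],
        };
        assert!(slot.is_completed());
        assert_eq!(slot.request_id(), Some(7));
    }
}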
/// Slot-based continuous batching: holds per-slot request state alongside
/// quantized KV caches. Method implementations live in the included source
/// files below.
#[cfg(feature = "gpu")]
pub struct ContinuousBatchScheduler {
    slots: std::sync::Mutex<Vec<SlotState>>,
    caches: std::sync::Mutex<Vec<OwnedQuantizedKVCache>>,
    pub num_slots: usize,
    completed: std::sync::Mutex<Vec<(u64, Vec<u32>)>>,
    next_id: std::sync::atomic::AtomicU64,
}
include!("batch_scheduler_lock_slots_completed.rs");
include!("batch_scheduler_counter_inc_load.rs");
include!("batch_scheduler_prefix_cache.rs");
include!("batch_scheduler_multi_request.rs");