#![allow(unreachable_pub)]
use std::convert::Infallible;
use axum::{
extract::State,
http::StatusCode,
response::sse::{Event, Sse},
Json,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use super::{
default_max_tokens, default_top_k, AppState, BatchGenerateRequest, BatchGenerateResponse,
BatchTokenizeRequest, BatchTokenizeResponse, ErrorResponse, GenerateRequest, GenerateResponse,
ModelsResponse, StreamDoneEvent, StreamTokenEvent, TokenizeRequest, TokenizeResponse,
};
use crate::generate::{GenerationConfig, SamplingStrategy};
use crate::registry::ModelInfo;
use crate::tokenizer::BPETokenizer;
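/// Error type shared by the handlers in this module: an HTTP status paired with a JSON error body.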
type ApiErr = (StatusCode, Json<ErrorResponse>);
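/// Builds an `ApiErr` from a status code and any displayable message.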
fn api_err(status: StatusCode, msg: impl std::fmt::Display) -> ApiErr {
(
status,
Json(ErrorResponse {
error: msg.to_string(),
}),
)
}
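/// Returns the shared tokenizer from application state, or a 500 error if none is loaded.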
fn require_tok(state: &AppState) -> Result<std::sync::Arc<BPETokenizer>, ApiErr> {
state
.tokenizer
.clone()
.ok_or_else(|| api_err(StatusCode::INTERNAL_SERVER_ERROR, "No tokenizer available"))
}
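/// Encodes a prompt to token ids, returning a 400 error if it yields no tokens.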
fn tokenize_prompt(tokenizer: &BPETokenizer, prompt: &str) -> Result<Vec<u32>, ApiErr> {
let ids = tokenizer.encode(prompt);
if ids.is_empty() {
return Err(api_err(StatusCode::BAD_REQUEST, "Prompt cannot be empty"));
}
Ok(ids)
}
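/// Resolves the end-of-sequence token id: the model-provided id if any, otherwise the
/// tokenizer's `<|im_end|>` or `<|endoftext|>` token, falling back to 0.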
fn eos_id(tokenizer: &BPETokenizer, model_eos: Option<u32>) -> u32 {
model_eos
.or_else(|| tokenizer.get_token_id("<|im_end|>"))
.or_else(|| tokenizer.get_token_id("<|endoftext|>"))
.unwrap_or(0)
}
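/// Request body for GPU batch generation.
///
/// Example payload (values are illustrative; `max_tokens` and `top_k` fall back to the
/// module defaults when omitted, and `stop` defaults to an empty list):
///
/// ```json
/// { "prompts": ["Hello", "Bonjour"], "max_tokens": 64, "temperature": 0.8, "top_k": 40 }
/// ```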
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuBatchRequest {
pub prompts: Vec<String>,
#[serde(default = "default_max_tokens")]
pub max_tokens: usize,
#[serde(default)]
pub temperature: f32,
#[serde(default = "default_top_k")]
pub top_k: usize,
#[serde(default)]
pub stop: Vec<String>,
}
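/// Response body for GPU batch generation: per-prompt results plus aggregate batch statistics.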
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuBatchResponse {
pub results: Vec<GpuBatchResult>,
pub stats: GpuBatchStats,
}
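/// Generation output for a single prompt in the batch, identified by its input index.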
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuBatchResult {
pub index: usize,
pub token_ids: Vec<u32>,
pub text: String,
pub num_generated: usize,
}
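/// Aggregate statistics for one batch: its size, whether the GPU path was used, the token
/// count, wall-clock processing time, and throughput in tokens per second.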
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuBatchStats {
pub batch_size: usize,
pub gpu_used: bool,
pub total_tokens: usize,
pub processing_time_ms: f64,
pub throughput_tps: f64,
}
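/// Result of a GPU warm-up call: whether it succeeded, the memory and layer counts involved,
/// and a human-readable message.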
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuWarmupResponse {
pub success: bool,
pub memory_bytes: usize,
pub num_layers: usize,
pub message: String,
}
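/// Current GPU batching status: cache readiness and memory footprint, plus the batch-size
/// thresholds in effect.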
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuStatusResponse {
pub cache_ready: bool,
pub cache_memory_bytes: usize,
pub batch_threshold: usize,
pub recommended_min_batch: usize,
}
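/// Tuning knobs for the continuous batching queue: how long a batching window stays open,
/// the batch sizes that trigger processing, the queue capacity, and the batch size at which
/// the GPU path is preferred.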
#[cfg(feature = "gpu")]
#[derive(Debug, Clone)]
pub struct BatchConfig {
pub window_ms: u64,
pub min_batch: usize,
pub optimal_batch: usize,
pub max_batch: usize,
pub queue_size: usize,
pub gpu_threshold: usize,
}
#[cfg(feature = "gpu")]
impl Default for BatchConfig {
fn default() -> Self {
Self {
            window_ms: 50,
            min_batch: 4,
            optimal_batch: 32,
            max_batch: 64,
            queue_size: 1024,
            gpu_threshold: 32,
        }
}
}
#[cfg(feature = "gpu")]
impl BatchConfig {
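    /// Preset that favors latency: a short 5 ms window and small batches.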
pub fn low_latency() -> Self {
Self {
window_ms: 5,
min_batch: 2,
optimal_batch: 8,
max_batch: 16,
queue_size: 512,
            gpu_threshold: 32,
        }
}
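    /// Preset that favors throughput: a 100 ms window with larger batches and a deeper queue.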
pub fn high_throughput() -> Self {
Self {
            window_ms: 100,
            min_batch: 8,
            optimal_batch: 32,
            max_batch: 128,
            queue_size: 2048,
            gpu_threshold: 32,
        }
}
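    /// Returns `true` once the pending batch has reached `optimal_batch`.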
pub fn should_process(&self, batch_size: usize) -> bool {
batch_size >= self.optimal_batch
}
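    /// Returns `true` once the pending batch has reached `min_batch`.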
pub fn meets_minimum(&self, batch_size: usize) -> bool {
batch_size >= self.min_batch
}
}
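/// A generation request queued for continuous batching, carrying its sampling parameters and
/// the oneshot channel on which the result is delivered.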
#[cfg(feature = "gpu")]
pub struct ContinuousBatchRequest {
pub prompt_tokens: Vec<u32>,
pub max_tokens: usize,
pub temperature: f32,
pub top_k: usize,
pub response_tx: tokio::sync::oneshot::Sender<ContinuousBatchResponse>,
pub submitted_at: std::time::Instant,
}
#[cfg(feature = "gpu")]
#[derive(Debug, Clone)]
pub struct ContinuousBatchResponse {
pub token_ids: Vec<u32>,
pub prompt_len: usize,
pub batched: bool,
pub batch_size: usize,
pub latency_ms: f64,
}
#[cfg(feature = "gpu")]
impl ContinuousBatchResponse {
pub fn single(token_ids: Vec<u32>, prompt_len: usize, latency_ms: f64) -> Self {
Self {
token_ids,
prompt_len,
batched: false,
batch_size: 1,
latency_ms,
}
}
pub fn batched(
token_ids: Vec<u32>,
prompt_len: usize,
batch_size: usize,
latency_ms: f64,
) -> Self {
Self {
token_ids,
prompt_len,
batched: true,
batch_size,
latency_ms,
}
}
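    /// Returns only the generated tokens, i.e. the portion of `token_ids` after the prompt.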
pub fn generated_tokens(&self) -> &[u32] {
if self.token_ids.len() > self.prompt_len {
&self.token_ids[self.prompt_len..]
} else {
&[]
}
}
}
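/// Running statistics for the batch queue: totals of queued requests, batched and
/// single-request runs, and the average batch size and wait time.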
#[cfg(feature = "gpu")]
#[derive(Debug, Clone, Default)]
pub struct BatchQueueStats {
pub total_queued: u64,
pub total_batches: u64,
pub total_single: u64,
pub avg_batch_size: f64,
pub avg_wait_ms: f64,
}
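/// Summary of one processing pass: how many requests were handled, whether they were batched,
/// and the total and per-request timings.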
#[cfg(feature = "gpu")]
#[derive(Debug)]
pub struct BatchProcessResult {
pub requests_processed: usize,
pub was_batched: bool,
pub total_time_ms: f64,
pub avg_latency_ms: f64,
}
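/// Spawns the background continuous-batching task and returns the sender used to enqueue
/// requests.
///
/// A minimal usage sketch; `model` and `prompt_tokens` are assumed to come from the
/// surrounding server setup:
///
/// ```ignore
/// let tx = spawn_batch_processor(model.clone(), BatchConfig::default());
/// let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
/// tx.send(ContinuousBatchRequest {
///     prompt_tokens,
///     max_tokens: 64,
///     temperature: 0.8,
///     top_k: 40,
///     response_tx: resp_tx,
///     submitted_at: std::time::Instant::now(),
/// })
/// .await?;
/// let response = resp_rx.await?;
/// ```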
#[cfg(feature = "gpu")]
pub fn spawn_batch_processor(
model: std::sync::Arc<crate::gguf::OwnedQuantizedModelCachedSync>,
config: BatchConfig,
) -> tokio::sync::mpsc::Sender<ContinuousBatchRequest> {
let (tx, rx) = tokio::sync::mpsc::channel(config.queue_size);
tokio::spawn(batch_processor_task(rx, model, config));
tx
}
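/// Core batching loop: accumulates requests for up to `window_ms`, flushing early when the
/// batch reaches `optimal_batch`, on window expiry, or when the request channel closes.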
#[cfg(feature = "gpu")]
async fn batch_processor_task(
mut rx: tokio::sync::mpsc::Receiver<ContinuousBatchRequest>,
model: std::sync::Arc<crate::gguf::OwnedQuantizedModelCachedSync>,
config: BatchConfig,
) {
use std::time::{Duration, Instant};
use tokio::time::timeout;
let mut batch: Vec<ContinuousBatchRequest> = Vec::with_capacity(config.max_batch);
let mut window_start = Instant::now();
loop {
let elapsed = window_start.elapsed();
let remaining = Duration::from_millis(config.window_ms).saturating_sub(elapsed);
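        // Wait for the next request, but no longer than the remainder of the window.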
match timeout(remaining, rx.recv()).await {
Ok(Some(request)) => {
batch.push(request);
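                // Flush early once the batch reaches the optimal size.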
if batch.len() >= config.optimal_batch {
process_batch(&model, &config, &mut batch).await;
window_start = Instant::now();
}
},
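            // Channel closed: flush any pending requests and shut down.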
Ok(None) => {
if !batch.is_empty() {
process_batch(&model, &config, &mut batch).await;
}
break;
},
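            // Window expired without reaching the optimal size: flush whatever has accumulated.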
Err(_) => {
if !batch.is_empty() {
process_batch(&model, &config, &mut batch).await;
}
window_start = Instant::now();
},
}
}
}
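// Handler implementations are split across sibling files and pulled in textually with
// `include!` so they compile as part of this module and share its imports and helpers.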
include!("batch_processing.rs");
include!("batch.rs");
include!("stream_generate.rs");