//! Request types for the inference scheduler: request identifiers,
//! priorities, lifecycle states, and the pending/running/completed
//! request records.

use crate::backends::GenerateParams;
use serde::{Deserialize, Serialize};
use std::time::Instant;
use uuid::Uuid;
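
/// Unique identifier for a request, backed by a random v4 UUID.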
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct RequestId(pub Uuid);

impl RequestId {
pub fn new() -> Self {
Self(Uuid::new_v4())
}
pub fn from_uuid(uuid: Uuid) -> Self {
Self(uuid)
}
}

impl Default for RequestId {
fn default() -> Self {
Self::new()
}
}

impl std::fmt::Display for RequestId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
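
/// Scheduling priority, ordered `Low < Normal < High < Critical`.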
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum Priority {
Low = 0,
Normal = 1,
High = 2,
Critical = 3,
}

impl Default for Priority {
fn default() -> Self {
Self::Normal
}
}

impl Priority {
pub fn value(&self) -> u8 {
*self as u8
}
}
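
/// Lifecycle state of a request as it moves through the scheduler.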
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RequestState {
Pending,
Running,
Preempted,
Completed,
Failed,
Cancelled,
}
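
/// A queued generation request: prompt tokens plus sampling parameters,
/// tagged with arrival time, priority, and optional session metadata.
///
/// Builder-style construction (ignored doc-test; mirrors the unit test
/// below):
///
/// ```ignore
/// let request = InferenceRequest::new(vec![1, 2, 3], GenerateParams::default())
///     .with_priority(Priority::High)
///     .with_session("session-123");
/// assert_eq!(request.prompt_len(), 3);
/// ```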
#[derive(Debug, Clone)]
pub struct InferenceRequest {
pub id: RequestId,
pub prompt_tokens: Vec<u32>,
pub params: GenerateParams,
pub arrival_time: Instant,
pub priority: Priority,
pub session_id: Option<String>,
    /// Upper bound on total sequence length: prompt length + `max_tokens`.
    pub max_seq_len: usize,
pub metadata: Option<serde_json::Value>,
}

impl InferenceRequest {
pub fn new(prompt_tokens: Vec<u32>, params: GenerateParams) -> Self {
let max_seq_len = prompt_tokens.len() + params.max_tokens;
Self {
id: RequestId::new(),
prompt_tokens,
params,
arrival_time: Instant::now(),
priority: Priority::Normal,
session_id: None,
max_seq_len,
metadata: None,
}
}
pub fn with_priority(mut self, priority: Priority) -> Self {
self.priority = priority;
self
}
pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
self.session_id = Some(session_id.into());
self
}
pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
self.metadata = Some(metadata);
self
}
pub fn prompt_len(&self) -> usize {
self.prompt_tokens.len()
}
pub fn max_new_tokens(&self) -> usize {
self.params.max_tokens
}
    /// Time elapsed since the request arrived. Note this keeps growing while
    /// the request runs; completion records measure queue wait as
    /// arrival -> start instead.
    pub fn waiting_time(&self) -> std::time::Duration {
        self.arrival_time.elapsed()
    }
}
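
/// A request that has been admitted by the scheduler and holds a KV-cache
/// slot while it runs prefill and decode.
///
/// Lifecycle sketch (ignored doc-test; `request`, `slot`, `chunk_len`, and
/// `next_token` are placeholders the scheduler and model would supply):
///
/// ```ignore
/// let mut running = RunningRequest::new(request, slot);
/// while !running.prefill_complete {
///     running.advance_prefill(chunk_len); // prefill the prompt in chunks
/// }
/// while !running.is_complete() {
///     running.add_token(next_token); // one decode step per token
/// }
/// ```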
#[derive(Debug)]
pub struct RunningRequest {
    pub request: InferenceRequest,
    /// Tokens generated so far (decode output only; excludes the prompt).
    pub generated_tokens: Vec<u32>,
    /// KV-cache slot assigned to this request by the scheduler.
    pub kv_cache_slot: usize,
    /// Prompt tokens plus generated tokens.
    pub current_seq_len: usize,
    /// Prompt tokens already run through prefill.
    pub prefill_tokens_processed: usize,
    pub prefill_complete: bool,
    /// When processing started (used as the queue-wait cut-off).
    pub start_time: Instant,
    pub last_step_time: Instant,
    /// Decode steps taken so far (one per generated token).
    pub decode_steps: usize,
    pub state: RequestState,
    /// KV-cache block indices backing this sequence.
    pub block_table: Vec<usize>,
    /// Tokens already materialized in the KV cache.
    pub context_len: usize,
}

impl RunningRequest {
pub fn new(request: InferenceRequest, kv_cache_slot: usize) -> Self {
let now = Instant::now();
let prompt_len = request.prompt_tokens.len();
Self {
request,
generated_tokens: Vec::new(),
kv_cache_slot,
current_seq_len: prompt_len,
prefill_tokens_processed: 0,
prefill_complete: false,
start_time: now,
last_step_time: now,
decode_steps: 0,
state: RequestState::Running,
block_table: Vec::new(),
context_len: 0,
}
}
pub fn id(&self) -> RequestId {
self.request.id
}
pub fn add_token(&mut self, token: u32) {
self.generated_tokens.push(token);
self.current_seq_len += 1;
self.decode_steps += 1;
self.last_step_time = Instant::now();
}
    /// True once the request has generated its full token budget.
    pub fn is_complete(&self) -> bool {
        self.generated_tokens.len() >= self.request.params.max_tokens
    }
    /// Whether generation should stop after this step. Stop-sequence
    /// matching against `_decoded_text` is not implemented yet; only the
    /// token budget is checked.
    pub fn should_stop(&self, _decoded_text: &str) -> bool {
        self.is_complete()
    }
pub fn total_tokens(&self) -> usize {
self.current_seq_len
}
pub fn remaining_tokens(&self) -> usize {
self.request
.params
.max_tokens
.saturating_sub(self.generated_tokens.len())
}
pub fn next_position(&self) -> usize {
self.current_seq_len
}
pub fn processing_time(&self) -> std::time::Duration {
self.start_time.elapsed()
}
pub fn time_since_last_step(&self) -> std::time::Duration {
self.last_step_time.elapsed()
}
pub fn tokens_per_second(&self) -> f64 {
let elapsed = self.processing_time().as_secs_f64();
if elapsed > 0.0 && self.decode_steps > 0 {
self.decode_steps as f64 / elapsed
} else {
0.0
}
}
    /// Mark the entire prompt as processed by prefill in a single step.
    pub fn complete_prefill(&mut self) {
        self.prefill_complete = true;
        self.prefill_tokens_processed = self.request.prompt_tokens.len();
        self.context_len = self.prefill_tokens_processed;
    }
    /// Prompt tokens still awaiting prefill (empty once prefill completes).
    pub fn get_prefill_tokens(&self) -> &[u32] {
        &self.request.prompt_tokens[self.prefill_tokens_processed..]
    }
    /// Advance the prefill cursor by `count` tokens. The cursor is clamped
    /// at the prompt length so an over-sized chunk cannot push it past the
    /// end, which would make `get_prefill_tokens` panic.
    pub fn advance_prefill(&mut self, count: usize) {
        let prompt_len = self.request.prompt_tokens.len();
        self.prefill_tokens_processed = (self.prefill_tokens_processed + count).min(prompt_len);
        self.context_len = self.prefill_tokens_processed;
        if self.prefill_tokens_processed >= prompt_len {
            self.prefill_complete = true;
        }
    }
}
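
/// Final record for a finished request, carrying timing and throughput
/// statistics. All `_ms` fields are wall-clock milliseconds; for failed and
/// cancelled requests the prefill/decode split is not recorded and both
/// fields are zero.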
#[derive(Debug, Clone)]
pub struct CompletedRequest {
pub id: RequestId,
pub prompt_tokens: Vec<u32>,
pub generated_tokens: Vec<u32>,
pub state: RequestState,
pub processing_time_ms: u64,
pub waiting_time_ms: u64,
pub prefill_time_ms: u64,
pub decode_time_ms: u64,
pub decode_steps: usize,
pub tokens_per_second: f64,
pub error: Option<String>,
pub finish_reason: FinishReason,
}
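
/// Why generation stopped.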
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FinishReason {
Length,
Stop,
EndOfSequence,
Cancelled,
Error,
}

impl CompletedRequest {
    pub fn success(running: &RunningRequest, prefill_time_ms: u64) -> Self {
        let processing_time = running.processing_time();
        // Saturate: a prefill time larger than the measured processing time
        // must not underflow (which would panic in debug builds).
        let decode_time_ms = (processing_time.as_millis() as u64).saturating_sub(prefill_time_ms);
Self {
id: running.id(),
prompt_tokens: running.request.prompt_tokens.clone(),
generated_tokens: running.generated_tokens.clone(),
state: RequestState::Completed,
processing_time_ms: processing_time.as_millis() as u64,
            // Queue wait is arrival -> start; sampling `waiting_time()` here
            // would also count the processing time measured above.
            waiting_time_ms: running
                .start_time
                .duration_since(running.request.arrival_time)
                .as_millis() as u64,
prefill_time_ms,
decode_time_ms,
decode_steps: running.decode_steps,
tokens_per_second: running.tokens_per_second(),
error: None,
finish_reason: if running.generated_tokens.len() >= running.request.params.max_tokens {
FinishReason::Length
} else {
FinishReason::EndOfSequence
},
}
}
pub fn failure(running: &RunningRequest, error: impl Into<String>) -> Self {
Self {
id: running.id(),
prompt_tokens: running.request.prompt_tokens.clone(),
generated_tokens: running.generated_tokens.clone(),
state: RequestState::Failed,
processing_time_ms: running.processing_time().as_millis() as u64,
            // Arrival -> start, as in `success`.
            waiting_time_ms: running
                .start_time
                .duration_since(running.request.arrival_time)
                .as_millis() as u64,
prefill_time_ms: 0,
decode_time_ms: 0,
decode_steps: running.decode_steps,
tokens_per_second: running.tokens_per_second(),
error: Some(error.into()),
finish_reason: FinishReason::Error,
}
}
pub fn cancelled(running: &RunningRequest) -> Self {
Self {
id: running.id(),
prompt_tokens: running.request.prompt_tokens.clone(),
generated_tokens: running.generated_tokens.clone(),
state: RequestState::Cancelled,
processing_time_ms: running.processing_time().as_millis() as u64,
            // Arrival -> start, as in `success`.
            waiting_time_ms: running
                .start_time
                .duration_since(running.request.arrival_time)
                .as_millis() as u64,
prefill_time_ms: 0,
decode_time_ms: 0,
decode_steps: running.decode_steps,
tokens_per_second: running.tokens_per_second(),
error: None,
finish_reason: FinishReason::Cancelled,
}
}
pub fn total_tokens(&self) -> usize {
self.prompt_tokens.len() + self.generated_tokens.len()
}
}
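
/// A single token emitted during decode, suitable for streaming to clients;
/// `is_final` marks the last token for a request, with `finish_reason` set.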
#[derive(Debug, Clone)]
pub struct TokenOutput {
pub request_id: RequestId,
pub token_id: u32,
pub token_text: Option<String>,
pub logprob: Option<f32>,
pub is_final: bool,
pub finish_reason: Option<FinishReason>,
pub seq_len: usize,
}

#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_request_id() {
let id1 = RequestId::new();
let id2 = RequestId::new();
assert_ne!(id1, id2);
}
#[test]
fn test_priority_ordering() {
assert!(Priority::Low < Priority::Normal);
assert!(Priority::Normal < Priority::High);
assert!(Priority::High < Priority::Critical);
}
#[test]
fn test_inference_request() {
let params = GenerateParams::default();
let request = InferenceRequest::new(vec![1, 2, 3], params)
.with_priority(Priority::High)
.with_session("session-123");
assert_eq!(request.prompt_len(), 3);
assert_eq!(request.priority, Priority::High);
assert_eq!(request.session_id, Some("session-123".to_string()));
}
#[test]
fn test_running_request() {
let params = GenerateParams::default().with_max_tokens(10);
let request = InferenceRequest::new(vec![1, 2, 3], params);
let mut running = RunningRequest::new(request, 0);
assert!(!running.is_complete());
assert!(!running.prefill_complete);
running.complete_prefill();
assert!(running.prefill_complete);
for i in 0..10 {
running.add_token(i);
}
assert!(running.is_complete());
}
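    // Extra coverage for the prefill-cursor clamp and the completion record;
    // both tests use only APIs defined in this module.
    #[test]
    fn test_chunked_prefill() {
        let params = GenerateParams::default().with_max_tokens(4);
        let request = InferenceRequest::new(vec![1, 2, 3, 4, 5], params);
        let mut running = RunningRequest::new(request, 0);
        running.advance_prefill(2);
        assert!(!running.prefill_complete);
        assert_eq!(running.get_prefill_tokens(), &[3, 4, 5]);
        // An over-sized chunk is clamped at the prompt boundary.
        running.advance_prefill(100);
        assert!(running.prefill_complete);
        assert!(running.get_prefill_tokens().is_empty());
    }
    #[test]
    fn test_completed_request_success() {
        let params = GenerateParams::default().with_max_tokens(2);
        let request = InferenceRequest::new(vec![1, 2, 3], params);
        let mut running = RunningRequest::new(request, 0);
        running.complete_prefill();
        running.add_token(10);
        running.add_token(11);
        let completed = CompletedRequest::success(&running, 0);
        assert_eq!(completed.finish_reason, FinishReason::Length);
        assert_eq!(completed.total_tokens(), 5);
        assert!(completed.error.is_none());
    }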
}