use crate::{ids::*, models::TokenUsage, FinishReason, Priority, SamplingParams, TokenId};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A single text-generation request as submitted by a client.
///
/// Constructed via [`InferenceRequest::new`], which fills every field with a
/// fresh id / timestamp / defaults; the `with_*` builder methods then override
/// individual fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
/// Unique id, generated with `RequestId::new()` at construction time.
pub id: RequestId,
/// The raw prompt text to run inference on.
pub prompt: String,
/// Which model should serve this request.
pub model_id: ModelId,
/// Decoding parameters; `SamplingParams::default()` unless overridden.
pub sampling_params: SamplingParams,
/// Whether the caller wants incremental `StreamChunk`s; defaults to `false`.
pub stream: bool,
/// Scheduling priority; `Priority::default()` unless overridden.
pub priority: Priority,
/// Optional originating client, set via `with_client_id`.
pub client_id: Option<ClientId>,
/// Optional session grouping, set via `with_session_id`.
pub session_id: Option<SessionId>,
/// Creation timestamp, captured with `Utc::now()` in `new()`.
pub created_at: DateTime<Utc>,
/// Free-form key/value annotations, added via `with_metadata`.
pub metadata: HashMap<String, serde_json::Value>,
}
impl InferenceRequest {
    /// Creates a request for `model_id` with a freshly generated id, the
    /// current UTC timestamp, and defaults everywhere else (non-streaming,
    /// default sampling/priority, no client/session, empty metadata).
    pub fn new(prompt: impl Into<String>, model_id: impl Into<ModelId>) -> Self {
        Self {
            id: RequestId::new(),
            prompt: prompt.into(),
            model_id: model_id.into(),
            sampling_params: SamplingParams::default(),
            stream: false,
            priority: Priority::default(),
            client_id: None,
            session_id: None,
            created_at: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    /// Replaces the sampling parameters.
    pub fn with_sampling_params(self, params: SamplingParams) -> Self {
        Self { sampling_params: params, ..self }
    }

    /// Enables or disables streaming delivery.
    pub fn with_stream(self, stream: bool) -> Self {
        Self { stream, ..self }
    }

    /// Replaces the scheduling priority.
    pub fn with_priority(self, priority: Priority) -> Self {
        Self { priority, ..self }
    }

    /// Attaches the originating client's id.
    pub fn with_client_id(self, client_id: impl Into<ClientId>) -> Self {
        Self { client_id: Some(client_id.into()), ..self }
    }

    /// Attaches a session id so related requests can be grouped.
    pub fn with_session_id(self, session_id: SessionId) -> Self {
        Self { session_id: Some(session_id), ..self }
    }

    /// Inserts one metadata entry; an existing value under `key` is replaced.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }
}
/// The completed (non-streaming) result of an [`InferenceRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
/// Id of the request this response answers.
pub request_id: RequestId,
/// The generated text.
pub text: String,
/// The generated token ids.
pub tokens: Vec<TokenId>,
/// Why generation stopped (length limit, stop token, etc. — see `FinishReason`).
pub finish_reason: FinishReason,
/// Token accounting for this request.
pub usage: TokenUsage,
/// End-to-end latency in milliseconds.
pub latency_ms: u64,
/// When this response was produced.
pub created_at: DateTime<Utc>,
/// Free-form key/value annotations echoed back to the caller.
pub metadata: HashMap<String, serde_json::Value>,
}
/// One incremental piece of a streaming response.
///
/// NOTE(review): the `Option` fields suggest `finish_reason`/`usage` are only
/// populated on the final chunk of a stream — confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
/// Id of the request this chunk belongs to.
pub request_id: RequestId,
/// Text delta carried by this chunk.
pub text: String,
/// Token id carried by this chunk, when there is exactly one.
pub token: Option<TokenId>,
/// Set when generation has finished.
pub finish_reason: Option<FinishReason>,
/// Set when token accounting is available.
pub usage: Option<TokenUsage>,
/// When this chunk was produced.
pub created_at: DateTime<Utc>,
/// Free-form key/value annotations.
pub metadata: HashMap<String, serde_json::Value>,
}
/// A group of requests scheduled to run together as one batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
/// Unique id for the batch, generated in `BatchRequest::new`.
pub batch_id: BatchId,
/// The member requests, in submission order.
pub requests: Vec<InferenceRequest>,
/// Largest `sampling_params.max_tokens` across `requests` (512 if empty);
/// computed in `BatchRequest::new`.
pub max_sequence_length: usize,
/// When the batch was assembled (`Utc::now()` in `new`).
pub created_at: DateTime<Utc>,
}
impl BatchRequest {
    /// Assembles `requests` into a batch with a fresh id and the current UTC
    /// timestamp. The batch-wide sequence budget is the largest `max_tokens`
    /// among the member requests, falling back to 512 for an empty batch.
    ///
    /// NOTE(review): the budget only counts `max_tokens`, not prompt length —
    /// confirm whether callers expect prompt tokens to be included.
    pub fn new(requests: Vec<InferenceRequest>) -> Self {
        let longest = requests
            .iter()
            .map(|req| req.sampling_params.max_tokens)
            .max();
        Self {
            batch_id: BatchId::new(),
            max_sequence_length: match longest {
                Some(n) => n,
                None => 512,
            },
            requests,
            created_at: Utc::now(),
        }
    }

    /// Number of requests in the batch.
    pub fn size(&self) -> usize {
        self.requests.len()
    }

    /// `true` when the batch holds no requests.
    pub fn is_empty(&self) -> bool {
        self.size() == 0
    }
}
/// Lifecycle state of a request inside the scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RequestState {
/// Queued, not yet picked up (initial state in `ScheduledRequest::new`).
Waiting,
/// Currently being executed.
Running,
/// Was running but has been paused by the scheduler.
Preempted,
/// Finished successfully.
Completed,
/// Finished with an error.
Failed,
/// Stopped at the caller's request.
Cancelled,
}
/// Scheduler-side bookkeeping wrapped around an [`InferenceRequest`].
#[derive(Debug, Clone)]
pub struct ScheduledRequest {
/// The underlying request being scheduled.
pub request: InferenceRequest,
/// Current lifecycle state; starts as `Waiting`.
pub state: RequestState,
/// Block ids allocated to this request, appended via `add_blocks`.
pub allocated_blocks: Vec<crate::BlockId>,
/// Absolute token-progress counter, overwritten by `update_progress`.
pub tokens_processed: usize,
/// Estimated completion time, when the scheduler has one.
pub estimated_completion: Option<DateTime<Utc>>,
}
impl ScheduledRequest {
    /// Wraps `request` in its initial scheduling state: `Waiting`, no blocks
    /// allocated, zero tokens processed, no completion estimate.
    pub fn new(request: InferenceRequest) -> Self {
        Self {
            state: RequestState::Waiting,
            allocated_blocks: Vec::new(),
            tokens_processed: 0,
            estimated_completion: None,
            request,
        }
    }

    /// Moves the request to `state`.
    pub fn set_state(&mut self, state: RequestState) {
        self.state = state;
    }

    /// Appends `blocks` to the request's allocated-block list.
    pub fn add_blocks(&mut self, blocks: Vec<crate::BlockId>) {
        for block in blocks {
            self.allocated_blocks.push(block);
        }
    }

    /// Overwrites the absolute count of tokens processed so far.
    pub fn update_progress(&mut self, tokens_processed: usize) {
        self.tokens_processed = tokens_processed;
    }
}