1use crate::{ids::*, models::TokenUsage, FinishReason, Priority, SamplingParams, TokenId};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct InferenceRequest {
11 pub id: RequestId,
13 pub prompt: String,
15 pub model_id: ModelId,
17 pub sampling_params: SamplingParams,
19 pub stream: bool,
21 pub priority: Priority,
23 pub client_id: Option<ClientId>,
25 pub session_id: Option<SessionId>,
27 pub created_at: DateTime<Utc>,
29 pub metadata: HashMap<String, serde_json::Value>,
31}
32
33impl InferenceRequest {
34 pub fn new(prompt: impl Into<String>, model_id: impl Into<ModelId>) -> Self {
36 Self {
37 id: RequestId::new(),
38 prompt: prompt.into(),
39 model_id: model_id.into(),
40 sampling_params: SamplingParams::default(),
41 stream: false,
42 priority: Priority::default(),
43 client_id: None,
44 session_id: None,
45 created_at: Utc::now(),
46 metadata: HashMap::new(),
47 }
48 }
49
50 pub fn with_sampling_params(mut self, params: SamplingParams) -> Self {
52 self.sampling_params = params;
53 self
54 }
55
56 pub fn with_stream(mut self, stream: bool) -> Self {
58 self.stream = stream;
59 self
60 }
61
62 pub fn with_priority(mut self, priority: Priority) -> Self {
64 self.priority = priority;
65 self
66 }
67
68 pub fn with_client_id(mut self, client_id: impl Into<ClientId>) -> Self {
70 self.client_id = Some(client_id.into());
71 self
72 }
73
74 pub fn with_session_id(mut self, session_id: SessionId) -> Self {
76 self.session_id = Some(session_id);
77 self
78 }
79
80 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
82 self.metadata.insert(key.into(), value);
83 self
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct InferenceResponse {
90 pub request_id: RequestId,
92 pub text: String,
94 pub tokens: Vec<TokenId>,
96 pub finish_reason: FinishReason,
98 pub usage: TokenUsage,
100 pub latency_ms: u64,
102 pub created_at: DateTime<Utc>,
104 pub metadata: HashMap<String, serde_json::Value>,
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct StreamChunk {
111 pub request_id: RequestId,
113 pub text: String,
115 pub token: Option<TokenId>,
117 pub finish_reason: Option<FinishReason>,
119 pub usage: Option<TokenUsage>,
121 pub created_at: DateTime<Utc>,
123 pub metadata: HashMap<String, serde_json::Value>,
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct BatchRequest {
130 pub batch_id: BatchId,
132 pub requests: Vec<InferenceRequest>,
134 pub max_sequence_length: usize,
136 pub created_at: DateTime<Utc>,
138}
139
140impl BatchRequest {
141 pub fn new(requests: Vec<InferenceRequest>) -> Self {
143 let max_sequence_length = requests
144 .iter()
145 .map(|r| r.sampling_params.max_tokens)
146 .max()
147 .unwrap_or(512);
148
149 Self {
150 batch_id: BatchId::new(),
151 requests,
152 max_sequence_length,
153 created_at: Utc::now(),
154 }
155 }
156
157 pub fn size(&self) -> usize {
159 self.requests.len()
160 }
161
162 pub fn is_empty(&self) -> bool {
164 self.requests.is_empty()
165 }
166}
167
168#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
170pub enum RequestState {
171 Waiting,
173 Running,
175 Preempted,
177 Completed,
179 Failed,
181 Cancelled,
183}
184
185#[derive(Debug, Clone)]
187pub struct ScheduledRequest {
188 pub request: InferenceRequest,
190 pub state: RequestState,
192 pub allocated_blocks: Vec<crate::BlockId>,
194 pub tokens_processed: usize,
196 pub estimated_completion: Option<DateTime<Utc>>,
198}
199
200impl ScheduledRequest {
201 pub fn new(request: InferenceRequest) -> Self {
203 Self {
204 request,
205 state: RequestState::Waiting,
206 allocated_blocks: Vec::new(),
207 tokens_processed: 0,
208 estimated_completion: None,
209 }
210 }
211
212 pub fn set_state(&mut self, state: RequestState) {
214 self.state = state;
215 }
216
217 pub fn add_blocks(&mut self, blocks: Vec<crate::BlockId>) {
219 self.allocated_blocks.extend(blocks);
220 }
221
222 pub fn update_progress(&mut self, tokens_processed: usize) {
224 self.tokens_processed = tokens_processed;
225 }
226}