1use crate::{ids::*, models::TokenUsage, FinishReason, Priority, SamplingParams, TokenId};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Metadata key string `"ferrum_prompt_tokens"`.
///
/// NOTE(review): presumably used as a key in the `metadata` maps of requests
/// or responses to carry prompt-token information — confirm against callers.
pub const PROMPT_TOKENS_METADATA_KEY: &str = "ferrum_prompt_tokens";
9
/// A single inference request: the prompt, the target model, generation
/// settings and bookkeeping data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
    /// Identifier of this request.
    pub id: RequestId,
    /// Prompt text.
    pub prompt: String,
    /// Identifier of the model the request targets.
    pub model_id: ModelId,
    /// Sampling parameters used for generation.
    pub sampling_params: SamplingParams,
    /// Streaming flag (presumably selects chunked output — confirm against callers).
    pub stream: bool,
    /// Priority assigned to the request.
    pub priority: Priority,
    /// Optional identifier of the client that issued the request.
    pub client_id: Option<ClientId>,
    /// Optional identifier of the session the request belongs to.
    pub session_id: Option<SessionId>,
    /// Creation timestamp in UTC.
    pub created_at: DateTime<Utc>,
    /// Optional structured API-level request. A missing field deserializes to
    /// `None`, and `None` is omitted when serializing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_request: Option<ApiRequest>,
    /// Free-form JSON metadata keyed by string.
    pub metadata: HashMap<String, serde_json::Value>,
}
39
/// Structured API-level request that can be attached to an inference request.
///
/// Serialized as an internally tagged enum: the `kind` field carries the
/// snake_case variant name (`chat` or `completion`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ApiRequest {
    /// Chat-style request.
    Chat(ApiChatRequest),
    /// Text-completion request.
    Completion(ApiCompletionRequest),
}
46
/// Structured API-level response.
///
/// Serialized as an internally tagged enum: the `kind` field carries the
/// snake_case variant name (`chat` or `completion`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ApiResponse {
    /// Chat-style response.
    Chat(ApiChatResponse),
    /// Text-completion response.
    Completion(ApiCompletionResponse),
}
53
/// Chat request payload.
///
/// Every field except `messages` is optional on the wire: missing fields
/// deserialize to their default (`None` / empty vector), and such values are
/// omitted when serializing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiChatRequest {
    /// Conversation messages.
    pub messages: Vec<ApiChatMessage>,
    /// Tools declared for this request.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ApiTool>,
    /// Tool selection: a mode string or one specific function.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ApiToolChoice>,
    /// Function definitions in the legacy function-calling form.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub legacy_functions: Vec<ApiFunction>,
    /// Legacy function selection: a mode string or one specific function name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub legacy_function_call: Option<ApiFunctionCallChoice>,
    /// Optional requested response format.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ApiResponseFormat>,
    /// Optional streaming options.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<ApiStreamOptions>,
}
70
/// Text-completion request payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiCompletionRequest {
    /// Prompt text.
    pub prompt: String,
    /// Optional requested response format; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ApiResponseFormat>,
}
77
/// Chat response payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiChatResponse {
    /// The response message.
    pub message: ApiChatMessage,
    /// Optional finish reason string (this file produces `"tool_calls"` and
    /// `"function_call"`); omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
84
/// Text-completion response payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiCompletionResponse {
    /// Completion text.
    pub text: String,
    /// Optional finish reason string; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
91
/// One chat message.
///
/// `role` and `content` are required on the wire; the remaining fields default
/// to `None` / empty when missing and are omitted from output in that state.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiChatMessage {
    /// Role of the message author.
    pub role: ApiMessageRole,
    /// Message text (left empty by the call responses built in this file).
    pub content: String,
    /// Optional name attached to the message.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls carried by the message.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ApiToolCall>,
    /// Optional tool-call id (presumably the call a tool message answers — confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Optional legacy-style function call carried by the message.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub function_call: Option<ApiFunctionCall>,
}
105
/// Role of a chat message author; serialized as the lowercase variant name
/// (`system`, `user`, `assistant`, `function`, `tool`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ApiMessageRole {
    /// `system`
    System,
    /// `user`
    User,
    /// `assistant`
    Assistant,
    /// `function`
    Function,
    /// `tool`
    Tool,
}
115
/// A tool declared on a chat request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiTool {
    /// Tool type string; serialized under the key `type`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Function definition the tool exposes.
    pub function: ApiFunction,
}
122
/// Definition of a callable function.
///
/// Only `name` is required; the optional fields are omitted from output when
/// `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiFunction {
    /// Function name; generated calls are matched against it.
    pub name: String,
    /// Optional description.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional parameters as an arbitrary JSON value
    /// (presumably a JSON Schema — confirm against callers).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
    /// Optional `strict` flag; not interpreted in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
133
/// Tool selection for a chat request.
///
/// Untagged on the wire: either a bare string (`Mode`) or an object with
/// `type` and `function` fields (`Function`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ApiToolChoice {
    /// A mode string; this file treats `"none"` (case-insensitive) as
    /// disabling tool calls.
    Mode(String),
    /// Selection of one specific function.
    Function {
        /// Tool type string; serialized under the key `type`.
        #[serde(rename = "type")]
        tool_type: String,
        /// The selected function.
        function: ApiToolChoiceFunction,
    },
}
144
/// Names the function selected by [`ApiToolChoice::Function`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiToolChoiceFunction {
    /// Name of the selected function.
    pub name: String,
}
149
/// Legacy function selection for a chat request.
///
/// Untagged on the wire: either a bare string (`Mode`) or an object with a
/// `name` field (`Function`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ApiFunctionCallChoice {
    /// A mode string; this file treats `"none"` (case-insensitive) as
    /// disabling function calls.
    Mode(String),
    /// Selection of one specific function by name.
    Function { name: String },
}
156
/// A tool call carried by a chat message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiToolCall {
    /// Identifier of the call.
    pub id: String,
    /// Tool type string; serialized under the key `type`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function name and arguments of the call.
    pub function: ApiFunctionCall,
}
164
/// A function invocation: the function name plus its arguments as a string.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiFunctionCall {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as a string (the parsers in this file store JSON text here).
    pub arguments: String,
}
170
/// Requested response format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiResponseFormat {
    /// Format type string; serialized under the key `type`.
    #[serde(rename = "type")]
    pub format_type: String,
    /// Optional JSON schema description; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<ApiJsonSchema>,
}
178
/// JSON schema attached to an [`ApiResponseFormat`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiJsonSchema {
    /// Optional schema name; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The schema as an arbitrary JSON value (required).
    pub schema: serde_json::Value,
    /// Optional `strict` flag; not interpreted in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
187
/// Streaming options of a chat request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ApiStreamOptions {
    /// Optional `include_usage` flag (presumably requests usage data in the
    /// stream — confirm against the consumer); omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
193
194pub fn api_response_from_generated_text(
195 request: &InferenceRequest,
196 text: &str,
197) -> Option<ApiResponse> {
198 let ApiRequest::Chat(chat_request) = request.api_request.as_ref()? else {
199 return None;
200 };
201 chat_api_response_from_generated_text(chat_request, text).map(ApiResponse::Chat)
202}
203
204pub fn chat_api_may_emit_tool_or_function_call(chat_request: &ApiChatRequest) -> bool {
205 (!chat_request.tools.is_empty() && !api_tool_choice_is_none(chat_request))
206 || (!chat_request.legacy_functions.is_empty()
207 && !api_function_call_choice_is_none(chat_request))
208}
209
210pub fn chat_api_response_from_generated_text(
211 chat_request: &ApiChatRequest,
212 text: &str,
213) -> Option<ApiChatResponse> {
214 if !chat_request.tools.is_empty() && !api_tool_choice_is_none(chat_request) {
215 if let Some(tool_calls) = parse_tool_calls_from_generated_text(text, chat_request) {
216 return Some(ApiChatResponse {
217 message: ApiChatMessage {
218 role: ApiMessageRole::Assistant,
219 content: String::new(),
220 name: None,
221 tool_calls,
222 tool_call_id: None,
223 function_call: None,
224 },
225 finish_reason: Some("tool_calls".to_string()),
226 });
227 }
228 }
229
230 if !chat_request.legacy_functions.is_empty() && !api_function_call_choice_is_none(chat_request)
231 {
232 if let Some(function_call) =
233 parse_legacy_function_call_from_generated_text(text, chat_request)
234 {
235 return Some(ApiChatResponse {
236 message: ApiChatMessage {
237 role: ApiMessageRole::Assistant,
238 content: String::new(),
239 name: None,
240 tool_calls: Vec::new(),
241 tool_call_id: None,
242 function_call: Some(function_call),
243 },
244 finish_reason: Some("function_call".to_string()),
245 });
246 }
247 }
248
249 None
250}
251
252fn api_tool_choice_is_none(chat_request: &ApiChatRequest) -> bool {
253 matches!(
254 chat_request.tool_choice.as_ref(),
255 Some(ApiToolChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none")
256 )
257}
258
259fn api_function_call_choice_is_none(chat_request: &ApiChatRequest) -> bool {
260 matches!(
261 chat_request.legacy_function_call.as_ref(),
262 Some(ApiFunctionCallChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none")
263 )
264}
265
266fn parse_tool_calls_from_generated_text(
267 text: &str,
268 chat_request: &ApiChatRequest,
269) -> Option<Vec<ApiToolCall>> {
270 let value = parse_json_value_from_generated_text(text)?;
271 if let Some(calls) = value.get("tool_calls").and_then(|value| value.as_array()) {
272 let parsed = calls
273 .iter()
274 .enumerate()
275 .filter_map(|(index, value)| parse_tool_call_value(value, index, chat_request))
276 .collect::<Vec<_>>();
277 return (!parsed.is_empty()).then_some(parsed);
278 }
279 if let Some(tool_call) = value.get("tool_call") {
280 return parse_tool_call_value(tool_call, 0, chat_request).map(|call| vec![call]);
281 }
282 parse_tool_call_value(&value, 0, chat_request).map(|call| vec![call])
283}
284
285fn parse_tool_call_value(
286 value: &serde_json::Value,
287 index: usize,
288 chat_request: &ApiChatRequest,
289) -> Option<ApiToolCall> {
290 let tool_type = value
291 .get("type")
292 .and_then(|value| value.as_str())
293 .unwrap_or("function");
294 if tool_type != "function" {
295 return None;
296 }
297 let function = value.get("function").unwrap_or(value);
298 let name = function.get("name").and_then(|value| value.as_str())?;
299 if !api_tool_name_allowed(chat_request, name) {
300 return None;
301 }
302 let arguments = api_arguments_to_string(function.get("arguments"));
303 let id = value
304 .get("id")
305 .and_then(|value| value.as_str())
306 .map(str::to_string)
307 .unwrap_or_else(|| format!("call_{index}"));
308
309 Some(ApiToolCall {
310 id,
311 tool_type: "function".to_string(),
312 function: ApiFunctionCall {
313 name: name.to_string(),
314 arguments,
315 },
316 })
317}
318
319fn parse_legacy_function_call_from_generated_text(
320 text: &str,
321 chat_request: &ApiChatRequest,
322) -> Option<ApiFunctionCall> {
323 let value = parse_json_value_from_generated_text(text)?;
324 let function = value.get("function_call").unwrap_or(&value);
325 let name = function.get("name").and_then(|value| value.as_str())?;
326 if !api_function_name_allowed(chat_request, name) {
327 return None;
328 }
329 Some(ApiFunctionCall {
330 name: name.to_string(),
331 arguments: api_arguments_to_string(function.get("arguments")),
332 })
333}
334
335fn api_tool_name_allowed(chat_request: &ApiChatRequest, name: &str) -> bool {
336 match chat_request.tool_choice.as_ref() {
337 Some(ApiToolChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none") => false,
338 Some(ApiToolChoice::Function {
339 tool_type,
340 function,
341 }) => {
342 tool_type == "function"
343 && function.name == name
344 && chat_request
345 .tools
346 .iter()
347 .any(|tool| tool.function.name == name)
348 }
349 _ => chat_request
350 .tools
351 .iter()
352 .any(|tool| tool.function.name == name),
353 }
354}
355
356fn api_function_name_allowed(chat_request: &ApiChatRequest, name: &str) -> bool {
357 match chat_request.legacy_function_call.as_ref() {
358 Some(ApiFunctionCallChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none") => false,
359 Some(ApiFunctionCallChoice::Function { name: selected }) => {
360 selected == name
361 && chat_request
362 .legacy_functions
363 .iter()
364 .any(|function| function.name == name)
365 }
366 _ => chat_request
367 .legacy_functions
368 .iter()
369 .any(|function| function.name == name),
370 }
371}
372
373fn parse_json_value_from_generated_text(text: &str) -> Option<serde_json::Value> {
374 let trimmed = strip_single_json_fence(text.trim());
375 serde_json::from_str(trimmed).ok().or_else(|| {
376 let start = trimmed.find('{')?;
377 let end = trimmed.rfind('}')?;
378 (start <= end)
379 .then(|| serde_json::from_str(&trimmed[start..=end]).ok())
380 .flatten()
381 })
382}
383
/// Removes one surrounding Markdown code fence from `text`.
///
/// When `text` starts with three backticks (optionally followed by `json`) and
/// ends with three backticks, the trimmed content between them is returned.
/// In every other case, including an unclosed fence, `text` is returned
/// unchanged.
fn strip_single_json_fence(text: &str) -> &str {
    const FENCE: &str = "```";

    match text.strip_prefix(FENCE) {
        None => text,
        Some(after_open) => {
            // Drop the optional `json` language tag and any whitespace after it.
            let body = match after_open.strip_prefix("json") {
                Some(untagged) => untagged,
                None => after_open,
            }
            .trim_start();

            match body.strip_suffix(FENCE) {
                Some(inner) => inner.trim(),
                None => text,
            }
        }
    }
}
391
392fn api_arguments_to_string(arguments: Option<&serde_json::Value>) -> String {
393 match arguments {
394 Some(serde_json::Value::String(raw)) => raw.clone(),
395 Some(value) => serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()),
396 None => "{}".to_string(),
397 }
398}
399
400impl InferenceRequest {
401 pub fn new(prompt: impl Into<String>, model_id: impl Into<ModelId>) -> Self {
403 Self {
404 id: RequestId::new(),
405 prompt: prompt.into(),
406 model_id: model_id.into(),
407 sampling_params: SamplingParams::default(),
408 stream: false,
409 priority: Priority::default(),
410 client_id: None,
411 session_id: None,
412 created_at: Utc::now(),
413 api_request: None,
414 metadata: HashMap::new(),
415 }
416 }
417
418 pub fn with_sampling_params(mut self, params: SamplingParams) -> Self {
420 self.sampling_params = params;
421 self
422 }
423
424 pub fn with_stream(mut self, stream: bool) -> Self {
426 self.stream = stream;
427 self
428 }
429
430 pub fn with_priority(mut self, priority: Priority) -> Self {
432 self.priority = priority;
433 self
434 }
435
436 pub fn with_client_id(mut self, client_id: impl Into<ClientId>) -> Self {
438 self.client_id = Some(client_id.into());
439 self
440 }
441
442 pub fn with_session_id(mut self, session_id: SessionId) -> Self {
444 self.session_id = Some(session_id);
445 self
446 }
447
448 pub fn with_api_request(mut self, api_request: ApiRequest) -> Self {
450 self.api_request = Some(api_request);
451 self
452 }
453
454 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
456 self.metadata.insert(key.into(), value);
457 self
458 }
459}
460
/// Result of a completed inference request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// Identifier of the request this response belongs to.
    pub request_id: RequestId,
    /// Generated text.
    pub text: String,
    /// Token ids (presumably the generated tokens — confirm against the producer).
    pub tokens: Vec<TokenId>,
    /// Reason generation finished.
    pub finish_reason: FinishReason,
    /// Token usage for the request.
    pub usage: TokenUsage,
    /// Latency in milliseconds.
    pub latency_ms: u64,
    /// Creation timestamp in UTC.
    pub created_at: DateTime<Utc>,
    /// Free-form JSON metadata keyed by string.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Optional structured API-level response. A missing field deserializes to
    /// `None`, and `None` is omitted when serializing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_response: Option<ApiResponse>,
}
486
/// One chunk of a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// Identifier of the request this chunk belongs to.
    pub request_id: RequestId,
    /// Text carried by this chunk.
    pub text: String,
    /// Optional token id carried by this chunk.
    pub token: Option<TokenId>,
    /// Optional finish reason (presumably set only on the final chunk — confirm).
    pub finish_reason: Option<FinishReason>,
    /// Optional token usage.
    pub usage: Option<TokenUsage>,
    /// Creation timestamp in UTC.
    pub created_at: DateTime<Utc>,
    /// Free-form JSON metadata keyed by string.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Optional structured API-level response. A missing field deserializes to
    /// `None`, and `None` is omitted when serializing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_response: Option<ApiResponse>,
}
510
/// A group of inference requests handled together.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Identifier of the batch.
    pub batch_id: BatchId,
    /// Requests contained in the batch.
    pub requests: Vec<InferenceRequest>,
    /// Sequence-length bound for the batch; `BatchRequest::new` sets it to the
    /// largest `sampling_params.max_tokens` among the requests (512 if empty).
    pub max_sequence_length: usize,
    /// Creation timestamp in UTC.
    pub created_at: DateTime<Utc>,
}
523
524impl BatchRequest {
525 pub fn new(requests: Vec<InferenceRequest>) -> Self {
527 let max_sequence_length = requests
528 .iter()
529 .map(|r| r.sampling_params.max_tokens)
530 .max()
531 .unwrap_or(512);
532
533 Self {
534 batch_id: BatchId::new(),
535 requests,
536 max_sequence_length,
537 created_at: Utc::now(),
538 }
539 }
540
541 pub fn size(&self) -> usize {
543 self.requests.len()
544 }
545
546 pub fn is_empty(&self) -> bool {
548 self.requests.is_empty()
549 }
550}
551
/// Lifecycle state of a scheduled request.
///
/// NOTE(review): the per-variant descriptions below follow the variant names;
/// the transitions between states are not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RequestState {
    /// Not yet running; the state `ScheduledRequest::new` assigns.
    Waiting,
    /// Currently running.
    Running,
    /// Preempted.
    Preempted,
    /// Completed.
    Completed,
    /// Failed.
    Failed,
    /// Cancelled.
    Cancelled,
}
568
/// An inference request together with its scheduling state.
#[derive(Debug, Clone)]
pub struct ScheduledRequest {
    /// The underlying request.
    pub request: InferenceRequest,
    /// Current lifecycle state.
    pub state: RequestState,
    /// Block ids allocated to the request (presumably cache blocks — confirm).
    pub allocated_blocks: Vec<crate::BlockId>,
    /// Number of tokens processed so far, as last reported via `update_progress`.
    pub tokens_processed: usize,
    /// Optional estimated completion time in UTC.
    pub estimated_completion: Option<DateTime<Utc>>,
}
583
584impl ScheduledRequest {
585 pub fn new(request: InferenceRequest) -> Self {
587 Self {
588 request,
589 state: RequestState::Waiting,
590 allocated_blocks: Vec::new(),
591 tokens_processed: 0,
592 estimated_completion: None,
593 }
594 }
595
596 pub fn set_state(&mut self, state: RequestState) {
598 self.state = state;
599 }
600
601 pub fn add_blocks(&mut self, blocks: Vec<crate::BlockId>) {
603 self.allocated_blocks.extend(blocks);
604 }
605
606 pub fn update_progress(&mut self, tokens_processed: usize) {
608 self.tokens_processed = tokens_processed;
609 }
610}