1use crate::{ids::*, models::TokenUsage, FinishReason, Priority, SamplingParams, TokenId};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Metadata key string `"ferrum_prompt_tokens"`.
///
/// NOTE(review): presumably used as a key in a request/response `metadata` map to carry
/// prompt tokens; no use is visible in this part of the file — confirm against callers.
pub const PROMPT_TOKENS_METADATA_KEY: &str = "ferrum_prompt_tokens";

/// Metadata key string `"ferrum_default_max_tokens"`.
///
/// NOTE(review): presumably marks a default `max_tokens` value in a `metadata` map; no use
/// is visible in this part of the file — confirm against callers.
pub const DEFAULT_MAX_TOKENS_METADATA_KEY: &str = "ferrum_default_max_tokens";
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct InferenceRequest {
14 pub id: RequestId,
16 pub prompt: String,
18 pub model_id: ModelId,
20 pub sampling_params: SamplingParams,
22 pub stream: bool,
24 pub priority: Priority,
26 pub client_id: Option<ClientId>,
28 pub session_id: Option<SessionId>,
30 pub created_at: DateTime<Utc>,
32 #[serde(default, skip_serializing_if = "Option::is_none")]
36 pub api_request: Option<ApiRequest>,
37 pub metadata: HashMap<String, serde_json::Value>,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
42#[serde(tag = "kind", rename_all = "snake_case")]
43pub enum ApiRequest {
44 Chat(ApiChatRequest),
45 Completion(ApiCompletionRequest),
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49#[serde(tag = "kind", rename_all = "snake_case")]
50pub enum ApiResponse {
51 Chat(ApiChatResponse),
52 Completion(ApiCompletionResponse),
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
56pub struct ApiChatRequest {
57 pub messages: Vec<ApiChatMessage>,
58 #[serde(default, skip_serializing_if = "Vec::is_empty")]
59 pub tools: Vec<ApiTool>,
60 #[serde(default, skip_serializing_if = "Option::is_none")]
61 pub tool_choice: Option<ApiToolChoice>,
62 #[serde(default, skip_serializing_if = "Vec::is_empty")]
63 pub legacy_functions: Vec<ApiFunction>,
64 #[serde(default, skip_serializing_if = "Option::is_none")]
65 pub legacy_function_call: Option<ApiFunctionCallChoice>,
66 #[serde(default, skip_serializing_if = "Option::is_none")]
67 pub response_format: Option<ApiResponseFormat>,
68 #[serde(default, skip_serializing_if = "Option::is_none")]
69 pub stream_options: Option<ApiStreamOptions>,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
73pub struct ApiCompletionRequest {
74 pub prompt: String,
75 #[serde(default, skip_serializing_if = "Option::is_none")]
76 pub response_format: Option<ApiResponseFormat>,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80pub struct ApiChatResponse {
81 pub message: ApiChatMessage,
82 #[serde(default, skip_serializing_if = "Option::is_none")]
83 pub finish_reason: Option<String>,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
87pub struct ApiCompletionResponse {
88 pub text: String,
89 #[serde(default, skip_serializing_if = "Option::is_none")]
90 pub finish_reason: Option<String>,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
94pub struct ApiChatMessage {
95 pub role: ApiMessageRole,
96 pub content: String,
97 #[serde(default, skip_serializing_if = "Option::is_none")]
98 pub name: Option<String>,
99 #[serde(default, skip_serializing_if = "Vec::is_empty")]
100 pub tool_calls: Vec<ApiToolCall>,
101 #[serde(default, skip_serializing_if = "Option::is_none")]
102 pub tool_call_id: Option<String>,
103 #[serde(default, skip_serializing_if = "Option::is_none")]
104 pub function_call: Option<ApiFunctionCall>,
105}
106
107#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
108#[serde(rename_all = "lowercase")]
109pub enum ApiMessageRole {
110 System,
111 User,
112 Assistant,
113 Function,
114 Tool,
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
118pub struct ApiTool {
119 #[serde(rename = "type")]
120 pub tool_type: String,
121 pub function: ApiFunction,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
125pub struct ApiFunction {
126 pub name: String,
127 #[serde(default, skip_serializing_if = "Option::is_none")]
128 pub description: Option<String>,
129 #[serde(default, skip_serializing_if = "Option::is_none")]
130 pub parameters: Option<serde_json::Value>,
131 #[serde(default, skip_serializing_if = "Option::is_none")]
132 pub strict: Option<bool>,
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
136#[serde(untagged)]
137pub enum ApiToolChoice {
138 Mode(String),
139 Function {
140 #[serde(rename = "type")]
141 tool_type: String,
142 function: ApiToolChoiceFunction,
143 },
144}
145
146#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
147pub struct ApiToolChoiceFunction {
148 pub name: String,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
152#[serde(untagged)]
153pub enum ApiFunctionCallChoice {
154 Mode(String),
155 Function { name: String },
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
159pub struct ApiToolCall {
160 pub id: String,
161 #[serde(rename = "type")]
162 pub tool_type: String,
163 pub function: ApiFunctionCall,
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
167pub struct ApiFunctionCall {
168 pub name: String,
169 pub arguments: String,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
173pub struct ApiResponseFormat {
174 #[serde(rename = "type")]
175 pub format_type: String,
176 #[serde(default, skip_serializing_if = "Option::is_none")]
177 pub json_schema: Option<ApiJsonSchema>,
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
181pub struct ApiJsonSchema {
182 #[serde(default, skip_serializing_if = "Option::is_none")]
183 pub name: Option<String>,
184 pub schema: serde_json::Value,
185 #[serde(default, skip_serializing_if = "Option::is_none")]
186 pub strict: Option<bool>,
187}
188
189#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
190pub struct ApiStreamOptions {
191 #[serde(default, skip_serializing_if = "Option::is_none")]
192 pub include_usage: Option<bool>,
193}
194
195pub fn api_response_from_generated_text(
196 request: &InferenceRequest,
197 text: &str,
198) -> Option<ApiResponse> {
199 let ApiRequest::Chat(chat_request) = request.api_request.as_ref()? else {
200 return None;
201 };
202 chat_api_response_from_generated_text(chat_request, text).map(ApiResponse::Chat)
203}
204
205pub fn chat_api_may_emit_tool_or_function_call(chat_request: &ApiChatRequest) -> bool {
206 (!chat_request.tools.is_empty() && !api_tool_choice_is_none(chat_request))
207 || (!chat_request.legacy_functions.is_empty()
208 && !api_function_call_choice_is_none(chat_request))
209}
210
211pub fn chat_api_response_from_generated_text(
212 chat_request: &ApiChatRequest,
213 text: &str,
214) -> Option<ApiChatResponse> {
215 if !chat_request.tools.is_empty() && !api_tool_choice_is_none(chat_request) {
216 if let Some(tool_calls) = parse_tool_calls_from_generated_text(text, chat_request) {
217 return Some(ApiChatResponse {
218 message: ApiChatMessage {
219 role: ApiMessageRole::Assistant,
220 content: String::new(),
221 name: None,
222 tool_calls,
223 tool_call_id: None,
224 function_call: None,
225 },
226 finish_reason: Some("tool_calls".to_string()),
227 });
228 }
229 }
230
231 if !chat_request.legacy_functions.is_empty() && !api_function_call_choice_is_none(chat_request)
232 {
233 if let Some(function_call) =
234 parse_legacy_function_call_from_generated_text(text, chat_request)
235 {
236 return Some(ApiChatResponse {
237 message: ApiChatMessage {
238 role: ApiMessageRole::Assistant,
239 content: String::new(),
240 name: None,
241 tool_calls: Vec::new(),
242 tool_call_id: None,
243 function_call: Some(function_call),
244 },
245 finish_reason: Some("function_call".to_string()),
246 });
247 }
248 }
249
250 None
251}
252
253fn api_tool_choice_is_none(chat_request: &ApiChatRequest) -> bool {
254 matches!(
255 chat_request.tool_choice.as_ref(),
256 Some(ApiToolChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none")
257 )
258}
259
260fn api_function_call_choice_is_none(chat_request: &ApiChatRequest) -> bool {
261 matches!(
262 chat_request.legacy_function_call.as_ref(),
263 Some(ApiFunctionCallChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none")
264 )
265}
266
267fn parse_tool_calls_from_generated_text(
268 text: &str,
269 chat_request: &ApiChatRequest,
270) -> Option<Vec<ApiToolCall>> {
271 let value = parse_json_value_from_generated_text(text)?;
272 if let Some(calls) = value.get("tool_calls").and_then(|value| value.as_array()) {
273 let parsed = calls
274 .iter()
275 .enumerate()
276 .filter_map(|(index, value)| parse_tool_call_value(value, index, chat_request))
277 .collect::<Vec<_>>();
278 return (!parsed.is_empty()).then_some(parsed);
279 }
280 if let Some(tool_call) = value.get("tool_call") {
281 return parse_tool_call_value(tool_call, 0, chat_request).map(|call| vec![call]);
282 }
283 parse_tool_call_value(&value, 0, chat_request)
284 .or_else(|| parse_forced_tool_arguments_value(&value, 0, chat_request))
285 .map(|call| vec![call])
286}
287
288fn parse_tool_call_value(
289 value: &serde_json::Value,
290 index: usize,
291 chat_request: &ApiChatRequest,
292) -> Option<ApiToolCall> {
293 let tool_type = value
294 .get("type")
295 .and_then(|value| value.as_str())
296 .unwrap_or("function");
297 if tool_type != "function" {
298 return None;
299 }
300 let function = value.get("function").unwrap_or(value);
301 let name = function.get("name").and_then(|value| value.as_str())?;
302 if !api_tool_name_allowed(chat_request, name) {
303 return None;
304 }
305 let arguments = api_arguments_to_string(function.get("arguments"));
306 let id = value
307 .get("id")
308 .and_then(|value| value.as_str())
309 .map(str::to_string)
310 .unwrap_or_else(|| format!("call_{index}"));
311
312 Some(ApiToolCall {
313 id,
314 tool_type: "function".to_string(),
315 function: ApiFunctionCall {
316 name: name.to_string(),
317 arguments,
318 },
319 })
320}
321
322fn parse_forced_tool_arguments_value(
323 value: &serde_json::Value,
324 index: usize,
325 chat_request: &ApiChatRequest,
326) -> Option<ApiToolCall> {
327 let name = forced_tool_choice_name(chat_request)?;
328 if value.get("tool_calls").is_some()
329 || value.get("tool_call").is_some()
330 || value.get("function").is_some()
331 || value.get("name").is_some()
332 {
333 return None;
334 }
335
336 Some(ApiToolCall {
337 id: format!("call_{index}"),
338 tool_type: "function".to_string(),
339 function: ApiFunctionCall {
340 name: name.to_string(),
341 arguments: serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()),
342 },
343 })
344}
345
346fn forced_tool_choice_name(chat_request: &ApiChatRequest) -> Option<&str> {
347 match chat_request.tool_choice.as_ref() {
348 Some(ApiToolChoice::Function {
349 tool_type,
350 function,
351 }) if tool_type == "function" && api_tool_name_allowed(chat_request, &function.name) => {
352 Some(function.name.as_str())
353 }
354 _ => None,
355 }
356}
357
358fn parse_legacy_function_call_from_generated_text(
359 text: &str,
360 chat_request: &ApiChatRequest,
361) -> Option<ApiFunctionCall> {
362 let value = parse_json_value_from_generated_text(text)?;
363 let function = value.get("function_call").unwrap_or(&value);
364 let name = function.get("name").and_then(|value| value.as_str())?;
365 if !api_function_name_allowed(chat_request, name) {
366 return None;
367 }
368 Some(ApiFunctionCall {
369 name: name.to_string(),
370 arguments: api_arguments_to_string(function.get("arguments")),
371 })
372}
373
374fn api_tool_name_allowed(chat_request: &ApiChatRequest, name: &str) -> bool {
375 match chat_request.tool_choice.as_ref() {
376 Some(ApiToolChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none") => false,
377 Some(ApiToolChoice::Function {
378 tool_type,
379 function,
380 }) => {
381 tool_type == "function"
382 && function.name == name
383 && chat_request
384 .tools
385 .iter()
386 .any(|tool| tool.function.name == name)
387 }
388 _ => chat_request
389 .tools
390 .iter()
391 .any(|tool| tool.function.name == name),
392 }
393}
394
395fn api_function_name_allowed(chat_request: &ApiChatRequest, name: &str) -> bool {
396 match chat_request.legacy_function_call.as_ref() {
397 Some(ApiFunctionCallChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none") => false,
398 Some(ApiFunctionCallChoice::Function { name: selected }) => {
399 selected == name
400 && chat_request
401 .legacy_functions
402 .iter()
403 .any(|function| function.name == name)
404 }
405 _ => chat_request
406 .legacy_functions
407 .iter()
408 .any(|function| function.name == name),
409 }
410}
411
412fn parse_json_value_from_generated_text(text: &str) -> Option<serde_json::Value> {
413 let trimmed = strip_single_json_fence(text.trim());
414 serde_json::from_str(trimmed).ok().or_else(|| {
415 let start = trimmed.find('{')?;
416 let end = trimmed.rfind('}')?;
417 (start <= end)
418 .then(|| serde_json::from_str(&trimmed[start..=end]).ok())
419 .flatten()
420 })
421}
422
/// Removes one surrounding Markdown code fence from `text`.
///
/// When `text` starts with ```` ``` ```` (optionally followed by the lowercase tag `json`)
/// and ends with ```` ``` ````, the trimmed content between the fences is returned.
/// In every other case — no opening fence, or no closing fence — `text` is returned
/// unchanged. The caller is expected to have trimmed `text` already.
fn strip_single_json_fence(text: &str) -> &str {
    match text.strip_prefix("```") {
        None => text,
        Some(after_open) => {
            let untagged = match after_open.strip_prefix("json") {
                Some(tail) => tail,
                None => after_open,
            };
            match untagged.trim_start().strip_suffix("```") {
                Some(inner) => inner.trim(),
                // Unterminated fence: hand back the input untouched.
                None => text,
            }
        }
    }
}
430
431fn api_arguments_to_string(arguments: Option<&serde_json::Value>) -> String {
432 match arguments {
433 Some(serde_json::Value::String(raw)) => raw.clone(),
434 Some(value) => serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()),
435 None => "{}".to_string(),
436 }
437}
438
439impl InferenceRequest {
440 pub fn new(prompt: impl Into<String>, model_id: impl Into<ModelId>) -> Self {
442 Self {
443 id: RequestId::new(),
444 prompt: prompt.into(),
445 model_id: model_id.into(),
446 sampling_params: SamplingParams::default(),
447 stream: false,
448 priority: Priority::default(),
449 client_id: None,
450 session_id: None,
451 created_at: Utc::now(),
452 api_request: None,
453 metadata: HashMap::new(),
454 }
455 }
456
457 pub fn with_sampling_params(mut self, params: SamplingParams) -> Self {
459 self.sampling_params = params;
460 self
461 }
462
463 pub fn with_stream(mut self, stream: bool) -> Self {
465 self.stream = stream;
466 self
467 }
468
469 pub fn with_priority(mut self, priority: Priority) -> Self {
471 self.priority = priority;
472 self
473 }
474
475 pub fn with_client_id(mut self, client_id: impl Into<ClientId>) -> Self {
477 self.client_id = Some(client_id.into());
478 self
479 }
480
481 pub fn with_session_id(mut self, session_id: SessionId) -> Self {
483 self.session_id = Some(session_id);
484 self
485 }
486
487 pub fn with_api_request(mut self, api_request: ApiRequest) -> Self {
489 self.api_request = Some(api_request);
490 self
491 }
492
493 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
495 self.metadata.insert(key.into(), value);
496 self
497 }
498}
499
500#[derive(Debug, Clone, Serialize, Deserialize)]
502pub struct InferenceResponse {
503 pub request_id: RequestId,
505 pub text: String,
507 pub tokens: Vec<TokenId>,
509 pub finish_reason: FinishReason,
511 pub usage: TokenUsage,
513 pub latency_ms: u64,
515 pub created_at: DateTime<Utc>,
517 pub metadata: HashMap<String, serde_json::Value>,
519 #[serde(default, skip_serializing_if = "Option::is_none")]
523 pub api_response: Option<ApiResponse>,
524}
525
526#[derive(Debug, Clone, Serialize, Deserialize)]
528pub struct StreamChunk {
529 pub request_id: RequestId,
531 pub text: String,
533 pub token: Option<TokenId>,
535 pub finish_reason: Option<FinishReason>,
537 pub usage: Option<TokenUsage>,
539 pub created_at: DateTime<Utc>,
541 pub metadata: HashMap<String, serde_json::Value>,
543 #[serde(default, skip_serializing_if = "Option::is_none")]
547 pub api_response: Option<ApiResponse>,
548}
549
550#[derive(Debug, Clone, Serialize, Deserialize)]
552pub struct BatchRequest {
553 pub batch_id: BatchId,
555 pub requests: Vec<InferenceRequest>,
557 pub max_sequence_length: usize,
559 pub created_at: DateTime<Utc>,
561}
562
563impl BatchRequest {
564 pub fn new(requests: Vec<InferenceRequest>) -> Self {
566 let max_sequence_length = requests
567 .iter()
568 .map(|r| r.sampling_params.max_tokens)
569 .max()
570 .unwrap_or(512);
571
572 Self {
573 batch_id: BatchId::new(),
574 requests,
575 max_sequence_length,
576 created_at: Utc::now(),
577 }
578 }
579
580 pub fn size(&self) -> usize {
582 self.requests.len()
583 }
584
585 pub fn is_empty(&self) -> bool {
587 self.requests.is_empty()
588 }
589}
590
591#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
593pub enum RequestState {
594 Waiting,
596 Running,
598 Preempted,
600 Completed,
602 Failed,
604 Cancelled,
606}
607
608#[derive(Debug, Clone)]
610pub struct ScheduledRequest {
611 pub request: InferenceRequest,
613 pub state: RequestState,
615 pub allocated_blocks: Vec<crate::BlockId>,
617 pub tokens_processed: usize,
619 pub estimated_completion: Option<DateTime<Utc>>,
621}
622
623impl ScheduledRequest {
624 pub fn new(request: InferenceRequest) -> Self {
626 Self {
627 request,
628 state: RequestState::Waiting,
629 allocated_blocks: Vec::new(),
630 tokens_processed: 0,
631 estimated_completion: None,
632 }
633 }
634
635 pub fn set_state(&mut self, state: RequestState) {
637 self.state = state;
638 }
639
640 pub fn add_blocks(&mut self, blocks: Vec<crate::BlockId>) {
642 self.allocated_blocks.extend(blocks);
643 }
644
645 pub fn update_progress(&mut self, tokens_processed: usize) {
647 self.tokens_processed = tokens_processed;
648 }
649}