1use crate::{ids::*, models::TokenUsage, FinishReason, Priority, SamplingParams, TokenId};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Metadata key string `"ferrum_prompt_tokens"`.
///
/// NOTE(review): presumably used as a key in a request/response `metadata` map to
/// carry the prompt token count; the usage is not visible in this file — confirm.
pub const PROMPT_TOKENS_METADATA_KEY: &str = "ferrum_prompt_tokens";
/// Metadata key string `"ferrum_default_max_tokens"`.
///
/// NOTE(review): presumably used as a key in a `metadata` map to carry a default
/// `max_tokens` value; the usage is not visible in this file — confirm.
pub const DEFAULT_MAX_TOKENS_METADATA_KEY: &str = "ferrum_default_max_tokens";
10
/// A single inference request: the prompt, the target model, sampling settings,
/// and bookkeeping data.
///
/// Build one with [`InferenceRequest::new`] and the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
    /// Identifier of this request.
    pub id: RequestId,
    /// Prompt text.
    pub prompt: String,
    /// Identifier of the model the request targets.
    pub model_id: ModelId,
    /// Sampling configuration for generation.
    pub sampling_params: SamplingParams,
    /// Streaming flag (`false` by default, see [`InferenceRequest::new`]).
    pub stream: bool,
    /// Priority attached to the request.
    pub priority: Priority,
    /// Optional identifier of the issuing client.
    pub client_id: Option<ClientId>,
    /// Optional identifier of the session this request belongs to.
    pub session_id: Option<SessionId>,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Structured API-level request, if one was attached.
    ///
    /// Defaults to `None` when absent on deserialization and is omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_request: Option<ApiRequest>,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, serde_json::Value>,
}
40
/// Structured API request attached to an [`InferenceRequest`].
///
/// Serialized as an internally tagged enum: the `kind` field holds the
/// snake_case variant name (`"chat"` or `"completion"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ApiRequest {
    /// Chat-style request.
    Chat(ApiChatRequest),
    /// Plain completion request.
    Completion(ApiCompletionRequest),
}
47
/// Structured API response counterpart of [`ApiRequest`].
///
/// Serialized as an internally tagged enum: the `kind` field holds the
/// snake_case variant name (`"chat"` or `"completion"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ApiResponse {
    /// Chat-style response.
    Chat(ApiChatResponse),
    /// Plain completion response.
    Completion(ApiCompletionResponse),
}
54
/// Chat request payload: the conversation plus optional tool / function
/// declarations and output options.
///
/// All fields except `messages` default when absent on deserialization and are
/// skipped on serialization when empty / `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiChatRequest {
    /// Conversation messages.
    pub messages: Vec<ApiChatMessage>,
    /// Declared tools.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ApiTool>,
    /// Tool selection: a mode string or one specific function.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ApiToolChoice>,
    /// Functions declared through the legacy (pre-tools) interface.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub legacy_functions: Vec<ApiFunction>,
    /// Legacy function selection: a mode string or one specific function.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub legacy_function_call: Option<ApiFunctionCallChoice>,
    /// Requested response format, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ApiResponseFormat>,
    /// Streaming options, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<ApiStreamOptions>,
}
71
/// Plain completion request payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiCompletionRequest {
    /// Prompt text.
    pub prompt: String,
    /// Requested response format, if any (omitted from output when `None`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ApiResponseFormat>,
}
78
/// Chat response payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiChatResponse {
    /// The response message.
    pub message: ApiChatMessage,
    /// Finish reason as a string; this file produces `"tool_calls"` and
    /// `"function_call"`. Omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
85
/// Plain completion response payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiCompletionResponse {
    /// Generated text.
    pub text: String,
    /// Finish reason as a string, if any (omitted from output when `None`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
92
/// One chat message, either part of a request conversation or the message of a
/// response.
///
/// Optional fields default when absent on deserialization and are skipped on
/// serialization when empty / `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiChatMessage {
    /// Role of the message author.
    pub role: ApiMessageRole,
    /// Text content; left empty by this file when the message carries a tool
    /// or function call.
    pub content: String,
    /// Optional name associated with the message.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls carried by the message.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ApiToolCall>,
    /// Optional tool-call identifier.
    /// NOTE(review): presumably the call a `tool` message responds to — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Legacy-style function call carried by the message.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub function_call: Option<ApiFunctionCall>,
}
106
/// Role of a chat message author; (de)serialized as the lowercase variant name.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ApiMessageRole {
    /// `"system"`
    System,
    /// `"user"`
    User,
    /// `"assistant"`
    Assistant,
    /// `"function"`
    Function,
    /// `"tool"`
    Tool,
}
116
/// A tool declared on a chat request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiTool {
    /// Tool kind, (de)serialized as the `type` field. The helpers in this file
    /// only treat `"function"` tools as call targets.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Function definition of the tool.
    pub function: ApiFunction,
}
123
/// Function definition, used both by tools and by the legacy functions list.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiFunction {
    /// Function name; generated calls are matched against it.
    pub name: String,
    /// Optional human-readable description.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional parameter schema as arbitrary JSON. The helpers in this file
    /// read its `properties` object when guessing whether bare JSON is an
    /// argument set for this function.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
    /// Optional `strict` flag (not interpreted in this file).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
134
/// Tool selection on a chat request.
///
/// Untagged: a bare JSON string deserializes as [`ApiToolChoice::Mode`], an
/// object with `type` and `function` as [`ApiToolChoice::Function`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ApiToolChoice {
    /// Mode string; this file recognizes `"none"`, `"auto"` and `"required"`
    /// (compared case-insensitively).
    Mode(String),
    /// Selection of one specific function tool.
    Function {
        /// Tool kind, (de)serialized as the `type` field.
        #[serde(rename = "type")]
        tool_type: String,
        /// The selected function.
        function: ApiToolChoiceFunction,
    },
}
145
/// Function reference inside [`ApiToolChoice::Function`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiToolChoiceFunction {
    /// Name of the selected function.
    pub name: String,
}
150
/// Legacy function selection on a chat request.
///
/// Untagged: a bare JSON string deserializes as
/// [`ApiFunctionCallChoice::Mode`], an object with `name` as
/// [`ApiFunctionCallChoice::Function`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ApiFunctionCallChoice {
    /// Mode string; this file only special-cases `"none"` (case-insensitive).
    Mode(String),
    /// Selection of one specific function by name.
    Function { name: String },
}
157
/// A tool call carried by a chat message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiToolCall {
    /// Call identifier; the parsers in this file fall back to `call_{index}`
    /// when the generated JSON has no string `id`.
    pub id: String,
    /// Tool kind, (de)serialized as the `type` field.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function being called and its arguments.
    pub function: ApiFunctionCall,
}
165
/// A function name plus its arguments.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiFunctionCall {
    /// Name of the function being called.
    pub name: String,
    /// Arguments as a string; the parsers in this file fill it with serialized
    /// JSON (or pass a string value through unchanged).
    pub arguments: String,
}
171
/// Requested response format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiResponseFormat {
    /// Format kind, (de)serialized as the `type` field.
    #[serde(rename = "type")]
    pub format_type: String,
    /// Optional JSON schema description (omitted from output when `None`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<ApiJsonSchema>,
}
179
/// JSON schema description attached to an [`ApiResponseFormat`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiJsonSchema {
    /// Optional schema name (omitted from output when `None`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The schema itself, as arbitrary JSON (required).
    pub schema: serde_json::Value,
    /// Optional `strict` flag (not interpreted in this file).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
188
/// Streaming options of a chat request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ApiStreamOptions {
    /// Optional `include_usage` flag (omitted from output when `None`).
    /// NOTE(review): presumably asks for token usage in the stream — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
194
195pub fn api_response_from_generated_text(
196 request: &InferenceRequest,
197 text: &str,
198) -> Option<ApiResponse> {
199 let ApiRequest::Chat(chat_request) = request.api_request.as_ref()? else {
200 return None;
201 };
202 chat_api_response_from_generated_text(chat_request, text).map(ApiResponse::Chat)
203}
204
205pub fn chat_api_may_emit_tool_or_function_call(chat_request: &ApiChatRequest) -> bool {
206 (!chat_request.tools.is_empty() && !api_tool_choice_is_none(chat_request))
207 || (!chat_request.legacy_functions.is_empty()
208 && !api_function_call_choice_is_none(chat_request))
209}
210
211pub fn chat_api_response_from_generated_text(
212 chat_request: &ApiChatRequest,
213 text: &str,
214) -> Option<ApiChatResponse> {
215 if !chat_request.tools.is_empty() && !api_tool_choice_is_none(chat_request) {
216 if let Some(tool_calls) = parse_tool_calls_from_generated_text(text, chat_request) {
217 return Some(ApiChatResponse {
218 message: ApiChatMessage {
219 role: ApiMessageRole::Assistant,
220 content: String::new(),
221 name: None,
222 tool_calls,
223 tool_call_id: None,
224 function_call: None,
225 },
226 finish_reason: Some("tool_calls".to_string()),
227 });
228 }
229 }
230
231 if !chat_request.legacy_functions.is_empty() && !api_function_call_choice_is_none(chat_request)
232 {
233 if let Some(function_call) =
234 parse_legacy_function_call_from_generated_text(text, chat_request)
235 {
236 return Some(ApiChatResponse {
237 message: ApiChatMessage {
238 role: ApiMessageRole::Assistant,
239 content: String::new(),
240 name: None,
241 tool_calls: Vec::new(),
242 tool_call_id: None,
243 function_call: Some(function_call),
244 },
245 finish_reason: Some("function_call".to_string()),
246 });
247 }
248 }
249
250 None
251}
252
253fn api_tool_choice_is_none(chat_request: &ApiChatRequest) -> bool {
254 matches!(
255 chat_request.tool_choice.as_ref(),
256 Some(ApiToolChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none")
257 )
258}
259
260fn api_function_call_choice_is_none(chat_request: &ApiChatRequest) -> bool {
261 matches!(
262 chat_request.legacy_function_call.as_ref(),
263 Some(ApiFunctionCallChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none")
264 )
265}
266
267fn parse_tool_calls_from_generated_text(
268 text: &str,
269 chat_request: &ApiChatRequest,
270) -> Option<Vec<ApiToolCall>> {
271 let value = parse_json_value_from_generated_text(text)?;
272 if let Some(calls) = value.get("tool_calls").and_then(|value| value.as_array()) {
273 let parsed = calls
274 .iter()
275 .enumerate()
276 .filter_map(|(index, value)| parse_tool_call_value(value, index, chat_request))
277 .collect::<Vec<_>>();
278 return (!parsed.is_empty()).then_some(parsed);
279 }
280 if let Some(tool_call) = value.get("tool_call") {
281 return parse_tool_call_value(tool_call, 0, chat_request).map(|call| vec![call]);
282 }
283 if let Some(tool_call) = parse_wrapped_tool_call_value(&value, 0, chat_request) {
284 return Some(vec![tool_call]);
285 }
286 parse_tool_call_value(&value, 0, chat_request)
287 .or_else(|| parse_forced_tool_arguments_value(&value, 0, chat_request))
288 .map(|call| vec![call])
289}
290
291fn parse_wrapped_tool_call_value(
292 value: &serde_json::Value,
293 index: usize,
294 chat_request: &ApiChatRequest,
295) -> Option<ApiToolCall> {
296 for key in ["auto", "tool", "tool_call", "auto_tool_response"] {
297 if let Some(wrapped) = value.get(key) {
298 if let Some(call) = parse_tool_call_value(wrapped, index, chat_request) {
299 return Some(call);
300 }
301 }
302 }
303 None
304}
305
306fn parse_tool_call_value(
307 value: &serde_json::Value,
308 index: usize,
309 chat_request: &ApiChatRequest,
310) -> Option<ApiToolCall> {
311 let tool_type = value
312 .get("type")
313 .and_then(|value| value.as_str())
314 .unwrap_or("function");
315 if tool_type != "function" {
316 return None;
317 }
318 let function = value.get("function").unwrap_or(value);
319 let name = function
320 .as_str()
321 .or_else(|| function.get("name").and_then(|value| value.as_str()))
322 .or_else(|| function.get("tool").and_then(|value| value.as_str()))
323 .or_else(|| value.get("name").and_then(|value| value.as_str()))?;
324 if !api_tool_name_allowed(chat_request, name) {
325 return None;
326 }
327 let arguments = api_arguments_to_string(
328 function
329 .get("arguments")
330 .or_else(|| function.get("parameters"))
331 .or_else(|| value.get("arguments"))
332 .or_else(|| value.get("parameters")),
333 );
334 let id = value
335 .get("id")
336 .and_then(|value| value.as_str())
337 .map(str::to_string)
338 .unwrap_or_else(|| format!("call_{index}"));
339
340 Some(ApiToolCall {
341 id,
342 tool_type: "function".to_string(),
343 function: ApiFunctionCall {
344 name: name.to_string(),
345 arguments,
346 },
347 })
348}
349
350fn parse_forced_tool_arguments_value(
351 value: &serde_json::Value,
352 index: usize,
353 chat_request: &ApiChatRequest,
354) -> Option<ApiToolCall> {
355 let tool = unwrapped_tool_arguments_target(chat_request, value)?;
356 if value.get("tool_calls").is_some()
357 || value.get("tool_call").is_some()
358 || value.get("function").is_some()
359 || value.get("name").is_some()
360 {
361 return None;
362 }
363
364 Some(ApiToolCall {
365 id: format!("call_{index}"),
366 tool_type: "function".to_string(),
367 function: ApiFunctionCall {
368 name: tool.function.name.clone(),
369 arguments: serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()),
370 },
371 })
372}
373
374fn unwrapped_tool_arguments_target<'a>(
375 chat_request: &'a ApiChatRequest,
376 value: &serde_json::Value,
377) -> Option<&'a ApiTool> {
378 if let Some(name) = forced_tool_choice_name(chat_request) {
379 return chat_request
380 .tools
381 .iter()
382 .find(|tool| tool.tool_type == "function" && tool.function.name == name);
383 }
384
385 if matches!(
386 chat_request.tool_choice.as_ref(),
387 Some(ApiToolChoice::Mode(mode)) if !mode.eq_ignore_ascii_case("auto")
388 ) {
389 return None;
390 }
391
392 let mut function_tools = chat_request
393 .tools
394 .iter()
395 .filter(|tool| tool.tool_type == "function");
396 let tool = function_tools.next()?;
397 if function_tools.next().is_some() || !value_looks_like_tool_arguments(value, tool) {
398 return None;
399 }
400 Some(tool)
401}
402
403fn value_looks_like_tool_arguments(value: &serde_json::Value, tool: &ApiTool) -> bool {
404 let Some(arguments) = value.as_object() else {
405 return false;
406 };
407 if arguments.is_empty() {
408 return false;
409 }
410 let Some(properties) = tool
411 .function
412 .parameters
413 .as_ref()
414 .and_then(|parameters| parameters.get("properties"))
415 .and_then(|properties| properties.as_object())
416 else {
417 return false;
418 };
419 arguments.keys().all(|key| properties.contains_key(key))
420}
421
422fn forced_tool_choice_name(chat_request: &ApiChatRequest) -> Option<&str> {
423 match chat_request.tool_choice.as_ref() {
424 Some(ApiToolChoice::Function {
425 tool_type,
426 function,
427 }) if tool_type == "function" && api_tool_name_allowed(chat_request, &function.name) => {
428 Some(function.name.as_str())
429 }
430 Some(ApiToolChoice::Mode(mode)) if mode.eq_ignore_ascii_case("required") => chat_request
431 .tools
432 .first()
433 .map(|tool| tool.function.name.as_str()),
434 _ => None,
435 }
436}
437
438fn parse_legacy_function_call_from_generated_text(
439 text: &str,
440 chat_request: &ApiChatRequest,
441) -> Option<ApiFunctionCall> {
442 let value = parse_json_value_from_generated_text(text)?;
443 let function = value.get("function_call").unwrap_or(&value);
444 let name = function.get("name").and_then(|value| value.as_str())?;
445 if !api_function_name_allowed(chat_request, name) {
446 return None;
447 }
448 Some(ApiFunctionCall {
449 name: name.to_string(),
450 arguments: api_arguments_to_string(function.get("arguments")),
451 })
452}
453
454fn api_tool_name_allowed(chat_request: &ApiChatRequest, name: &str) -> bool {
455 match chat_request.tool_choice.as_ref() {
456 Some(ApiToolChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none") => false,
457 Some(ApiToolChoice::Function {
458 tool_type,
459 function,
460 }) => {
461 tool_type == "function"
462 && function.name == name
463 && chat_request
464 .tools
465 .iter()
466 .any(|tool| tool.function.name == name)
467 }
468 _ => chat_request
469 .tools
470 .iter()
471 .any(|tool| tool.function.name == name),
472 }
473}
474
475fn api_function_name_allowed(chat_request: &ApiChatRequest, name: &str) -> bool {
476 match chat_request.legacy_function_call.as_ref() {
477 Some(ApiFunctionCallChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none") => false,
478 Some(ApiFunctionCallChoice::Function { name: selected }) => {
479 selected == name
480 && chat_request
481 .legacy_functions
482 .iter()
483 .any(|function| function.name == name)
484 }
485 _ => chat_request
486 .legacy_functions
487 .iter()
488 .any(|function| function.name == name),
489 }
490}
491
492fn parse_json_value_from_generated_text(text: &str) -> Option<serde_json::Value> {
493 let trimmed = strip_single_json_fence(text.trim());
494 serde_json::from_str(trimmed).ok().or_else(|| {
495 let start = trimmed.find('{')?;
496 let end = trimmed.rfind('}')?;
497 (start <= end)
498 .then(|| serde_json::from_str(&trimmed[start..=end]).ok())
499 .flatten()
500 })
501}
502
/// Removes one enclosing Markdown code fence from `text`.
///
/// If `text` starts with three backticks (optionally followed directly by
/// `json`) and ends with three backticks, the trimmed inner content is
/// returned. Otherwise `text` is returned unchanged.
fn strip_single_json_fence(text: &str) -> &str {
    let body = match text.strip_prefix("```") {
        Some(after_open) => after_open,
        None => return text,
    };
    let body = match body.strip_prefix("json") {
        Some(after_tag) => after_tag,
        None => body,
    };
    match body.trim_start().strip_suffix("```") {
        Some(inner) => inner.trim(),
        // No closing fence: leave the input untouched.
        None => text,
    }
}
510
511fn api_arguments_to_string(arguments: Option<&serde_json::Value>) -> String {
512 match arguments {
513 Some(serde_json::Value::String(raw)) => raw.clone(),
514 Some(value) => serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()),
515 None => "{}".to_string(),
516 }
517}
518
519impl InferenceRequest {
520 pub fn new(prompt: impl Into<String>, model_id: impl Into<ModelId>) -> Self {
522 Self {
523 id: RequestId::new(),
524 prompt: prompt.into(),
525 model_id: model_id.into(),
526 sampling_params: SamplingParams::default(),
527 stream: false,
528 priority: Priority::default(),
529 client_id: None,
530 session_id: None,
531 created_at: Utc::now(),
532 api_request: None,
533 metadata: HashMap::new(),
534 }
535 }
536
537 pub fn with_sampling_params(mut self, params: SamplingParams) -> Self {
539 self.sampling_params = params;
540 self
541 }
542
543 pub fn with_stream(mut self, stream: bool) -> Self {
545 self.stream = stream;
546 self
547 }
548
549 pub fn with_priority(mut self, priority: Priority) -> Self {
551 self.priority = priority;
552 self
553 }
554
555 pub fn with_client_id(mut self, client_id: impl Into<ClientId>) -> Self {
557 self.client_id = Some(client_id.into());
558 self
559 }
560
561 pub fn with_session_id(mut self, session_id: SessionId) -> Self {
563 self.session_id = Some(session_id);
564 self
565 }
566
567 pub fn with_api_request(mut self, api_request: ApiRequest) -> Self {
569 self.api_request = Some(api_request);
570 self
571 }
572
573 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
575 self.metadata.insert(key.into(), value);
576 self
577 }
578}
579
/// Complete (non-streamed) result of an inference request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// Identifier of the request this response answers.
    pub request_id: RequestId,
    /// Generated text.
    pub text: String,
    /// Token ids (presumably the generated tokens — confirm against the producer).
    pub tokens: Vec<TokenId>,
    /// Finish reason of the generation.
    pub finish_reason: FinishReason,
    /// Token usage accounting.
    pub usage: TokenUsage,
    /// Latency in milliseconds.
    pub latency_ms: u64,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Structured API-level response, if any.
    ///
    /// Defaults to `None` when absent on deserialization and is omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_response: Option<ApiResponse>,
}
605
/// One chunk of a streamed inference response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// Identifier of the request this chunk belongs to.
    pub request_id: RequestId,
    /// Text carried by this chunk.
    pub text: String,
    /// Token id associated with this chunk, if any.
    pub token: Option<TokenId>,
    /// Finish reason, if set on this chunk.
    /// NOTE(review): presumably only set on the final chunk — confirm.
    pub finish_reason: Option<FinishReason>,
    /// Token usage, if reported with this chunk.
    pub usage: Option<TokenUsage>,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Structured API-level response, if any.
    ///
    /// Defaults to `None` when absent on deserialization and is omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_response: Option<ApiResponse>,
}
629
/// A group of inference requests handled together.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Identifier of the batch.
    pub batch_id: BatchId,
    /// Requests contained in the batch.
    pub requests: Vec<InferenceRequest>,
    /// Largest `sampling_params.max_tokens` among the requests, as computed by
    /// [`BatchRequest::new`] (512 for an empty batch).
    pub max_sequence_length: usize,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
}
642
643impl BatchRequest {
644 pub fn new(requests: Vec<InferenceRequest>) -> Self {
646 let max_sequence_length = requests
647 .iter()
648 .map(|r| r.sampling_params.max_tokens)
649 .max()
650 .unwrap_or(512);
651
652 Self {
653 batch_id: BatchId::new(),
654 requests,
655 max_sequence_length,
656 created_at: Utc::now(),
657 }
658 }
659
660 pub fn size(&self) -> usize {
662 self.requests.len()
663 }
664
665 pub fn is_empty(&self) -> bool {
667 self.requests.is_empty()
668 }
669}
670
/// Lifecycle state of a [`ScheduledRequest`].
///
/// NOTE(review): the transitions between states are not defined in this file;
/// the variant descriptions below are inferred from the names — confirm
/// against the scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RequestState {
    /// Initial state set by [`ScheduledRequest::new`].
    Waiting,
    /// Presumably: currently being processed.
    Running,
    /// Presumably: processing was interrupted.
    Preempted,
    /// Presumably: finished successfully.
    Completed,
    /// Presumably: finished with an error.
    Failed,
    /// Presumably: cancelled before completion.
    Cancelled,
}
687
/// An [`InferenceRequest`] together with its scheduling bookkeeping.
#[derive(Debug, Clone)]
pub struct ScheduledRequest {
    /// The wrapped request.
    pub request: InferenceRequest,
    /// Current lifecycle state.
    pub state: RequestState,
    /// Blocks allocated to the request so far (see `add_blocks`).
    pub allocated_blocks: Vec<crate::BlockId>,
    /// Number of tokens processed so far (see `update_progress`).
    pub tokens_processed: usize,
    /// Estimated completion time, if one is set.
    pub estimated_completion: Option<DateTime<Utc>>,
}
702
703impl ScheduledRequest {
704 pub fn new(request: InferenceRequest) -> Self {
706 Self {
707 request,
708 state: RequestState::Waiting,
709 allocated_blocks: Vec::new(),
710 tokens_processed: 0,
711 estimated_completion: None,
712 }
713 }
714
715 pub fn set_state(&mut self, state: RequestState) {
717 self.state = state;
718 }
719
720 pub fn add_blocks(&mut self, blocks: Vec<crate::BlockId>) {
722 self.allocated_blocks.extend(blocks);
723 }
724
725 pub fn update_progress(&mut self, tokens_processed: usize) {
727 self.tokens_processed = tokens_processed;
728 }
729}