Skip to main content

ferrum_types/
requests.rs

1//! Request and response types for inference
2
3use crate::{ids::*, models::TokenUsage, FinishReason, Priority, SamplingParams, TokenId};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Metadata key `"ferrum_prompt_tokens"`, intended (per its name) for the
/// `metadata` maps on requests/responses.
/// NOTE(review): whether the value holds prompt token ids or a token count is
/// not visible in this module — confirm with the producers/consumers.
pub const PROMPT_TOKENS_METADATA_KEY: &str = "ferrum_prompt_tokens";
/// Metadata key `"ferrum_default_max_tokens"`, intended (per its name) for the
/// `metadata` maps on requests/responses.
/// NOTE(review): presumably records that `max_tokens` was filled from a default
/// rather than supplied by the client — confirm with callers.
pub const DEFAULT_MAX_TOKENS_METADATA_KEY: &str = "ferrum_default_max_tokens";
10
/// Inference request
///
/// One unit of work submitted to an inference engine. Construct it with
/// [`InferenceRequest::new`] and refine it with the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
    /// Unique request identifier
    pub id: RequestId,
    /// Input prompt text
    pub prompt: String,
    /// Model to use for inference
    pub model_id: ModelId,
    /// Sampling parameters
    pub sampling_params: SamplingParams,
    /// Whether to stream response
    pub stream: bool,
    /// Request priority
    pub priority: Priority,
    /// Client identifier
    pub client_id: Option<ClientId>,
    /// Session identifier for stateful interactions
    pub session_id: Option<SessionId>,
    /// Request creation timestamp
    pub created_at: DateTime<Utc>,
    /// Structured product/API request context. `prompt` remains the rendered
    /// model input for current engines; this carries the original semantic
    /// request boundary for API features such as tools and response formats.
    /// Omitted from serialized output when `None`; defaults to `None` when the
    /// field is absent on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_request: Option<ApiRequest>,
    /// Additional metadata
    ///
    /// Free-form key/value pairs; `PROMPT_TOKENS_METADATA_KEY` and
    /// `DEFAULT_MAX_TOKENS_METADATA_KEY` are keys defined in this module.
    /// Unlike `api_request`, this field has no `#[serde(default)]`, so it must
    /// be present when deserializing.
    pub metadata: HashMap<String, serde_json::Value>,
}
40
/// Original API-level request that an [`InferenceRequest`] was rendered from.
///
/// Serialized as an internally tagged object: a `"kind"` field (`"chat"` or
/// `"completion"`) sits alongside the wrapped struct's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ApiRequest {
    /// Chat-style request (messages, tools, legacy functions).
    Chat(ApiChatRequest),
    /// Plain text-completion request.
    Completion(ApiCompletionRequest),
}
47
/// Structured API-level response derived from generated output.
///
/// Uses the same `"kind"`-tagged encoding as [`ApiRequest`] (`"chat"` /
/// `"completion"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum ApiResponse {
    /// Assistant chat message, possibly carrying tool/function calls.
    Chat(ApiChatResponse),
    /// Plain completion text.
    Completion(ApiCompletionResponse),
}
54
/// Chat request context: the conversation plus tool/function-calling options.
///
/// Empty/`None` optional fields are skipped when serializing and take their
/// default when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiChatRequest {
    /// Conversation so far, in order.
    pub messages: Vec<ApiChatMessage>,
    /// Tools the model may call.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ApiTool>,
    /// Tool-selection policy. `None` imposes no restriction beyond the
    /// declared `tools`; the string mode `"none"` disables tool calls.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ApiToolChoice>,
    /// Legacy function definitions (presumably mirrors OpenAI's deprecated
    /// `functions` parameter — note the wire name here is `legacy_functions`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub legacy_functions: Vec<ApiFunction>,
    /// Legacy function-call policy (presumably mirrors OpenAI's deprecated
    /// `function_call`); the string mode `"none"` disables legacy calls.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub legacy_function_call: Option<ApiFunctionCallChoice>,
    /// Requested output format, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ApiResponseFormat>,
    /// Streaming options, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<ApiStreamOptions>,
}
71
/// Plain text-completion request context.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiCompletionRequest {
    /// Prompt as supplied by the API client.
    pub prompt: String,
    /// Requested output format, if any (skipped when `None`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ApiResponseFormat>,
}
78
/// Assistant reply to a chat request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiChatResponse {
    /// The assistant message, including any tool/function calls.
    pub message: ApiChatMessage,
    /// Free-form finish reason; this module emits `"tool_calls"` or
    /// `"function_call"` when it detects a structured call.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
85
/// Text reply to a completion request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiCompletionResponse {
    /// Generated text.
    pub text: String,
    /// Free-form finish reason, if known (skipped when `None`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
92
/// A single chat message, used both for request history and assistant replies.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiChatMessage {
    /// Author role.
    pub role: ApiMessageRole,
    /// Text content; may be empty (call replies built in this module use `""`).
    pub content: String,
    /// Optional author name; semantics are set by the caller.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls requested by an assistant message.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ApiToolCall>,
    /// NOTE(review): presumably, on `Tool` messages, the id of the call being
    /// answered — confirm with callers.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Legacy single function call requested by an assistant message.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub function_call: Option<ApiFunctionCall>,
}
106
/// Author role of a chat message; serialized in lowercase
/// (`"system"`, `"user"`, `"assistant"`, `"function"`, `"tool"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ApiMessageRole {
    System,
    User,
    Assistant,
    Function,
    Tool,
}
116
/// A tool declaration offered to the model.
///
/// `tool_type` is serialized under the JSON key `"type"`. Name matching in this
/// module looks only at `function.name`; it does not inspect `tool_type`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiTool {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: ApiFunction,
}
123
/// A callable function definition (used by both `tools` and legacy `functions`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiFunction {
    /// Function name; generated calls must match it exactly.
    pub name: String,
    /// Optional human-readable description.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// NOTE(review): presumably a JSON Schema describing the arguments object —
    /// not validated in this module.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
    /// Optional strict-schema flag, passed through unchanged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
134
/// Tool-selection policy.
///
/// `#[serde(untagged)]`: a bare JSON string deserializes to `Mode`, and an
/// object with `type` and `function` fields deserializes to `Function`. This
/// module interprets only the mode `"none"` (case-insensitively); a `Function`
/// choice restricts calls to that one declared tool.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ApiToolChoice {
    Mode(String),
    Function {
        #[serde(rename = "type")]
        tool_type: String,
        function: ApiToolChoiceFunction,
    },
}
145
/// Names the single tool selected by [`ApiToolChoice::Function`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiToolChoiceFunction {
    pub name: String,
}
150
/// Legacy function-call policy.
///
/// `#[serde(untagged)]`: a bare JSON string deserializes to `Mode`, and
/// `{"name": ...}` deserializes to `Function`. This module interprets only the
/// mode `"none"` (case-insensitively).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ApiFunctionCallChoice {
    Mode(String),
    Function { name: String },
}
157
/// A tool call emitted by the assistant.
///
/// `tool_type` is serialized under `"type"`. Calls parsed by this module always
/// use `"function"`; when the model supplies no id, `call_{index}` is used.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: ApiFunctionCall,
}
165
/// A concrete function invocation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiFunctionCall {
    /// Name of the called function.
    pub name: String,
    /// Arguments as a JSON-encoded *string*, not a structured value.
    pub arguments: String,
}
171
/// Requested output format.
///
/// `format_type` is serialized under `"type"`. NOTE(review): presumably values
/// such as `"text"`, `"json_object"` or `"json_schema"`; they are not validated
/// here.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiResponseFormat {
    #[serde(rename = "type")]
    pub format_type: String,
    /// Schema details; presumably only meaningful for a JSON-schema format.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<ApiJsonSchema>,
}
179
/// JSON schema attached to an [`ApiResponseFormat`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ApiJsonSchema {
    /// Optional schema name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// The schema document itself (stored as raw JSON).
    pub schema: serde_json::Value,
    /// Optional strict-adherence flag, passed through unchanged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
188
/// Options applied to streaming responses.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ApiStreamOptions {
    /// NOTE(review): presumably requests a usage summary in the stream —
    /// confirm with the streaming endpoint.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
194
195pub fn api_response_from_generated_text(
196    request: &InferenceRequest,
197    text: &str,
198) -> Option<ApiResponse> {
199    let ApiRequest::Chat(chat_request) = request.api_request.as_ref()? else {
200        return None;
201    };
202    chat_api_response_from_generated_text(chat_request, text).map(ApiResponse::Chat)
203}
204
205pub fn chat_api_may_emit_tool_or_function_call(chat_request: &ApiChatRequest) -> bool {
206    (!chat_request.tools.is_empty() && !api_tool_choice_is_none(chat_request))
207        || (!chat_request.legacy_functions.is_empty()
208            && !api_function_call_choice_is_none(chat_request))
209}
210
211pub fn chat_api_response_from_generated_text(
212    chat_request: &ApiChatRequest,
213    text: &str,
214) -> Option<ApiChatResponse> {
215    if !chat_request.tools.is_empty() && !api_tool_choice_is_none(chat_request) {
216        if let Some(tool_calls) = parse_tool_calls_from_generated_text(text, chat_request) {
217            return Some(ApiChatResponse {
218                message: ApiChatMessage {
219                    role: ApiMessageRole::Assistant,
220                    content: String::new(),
221                    name: None,
222                    tool_calls,
223                    tool_call_id: None,
224                    function_call: None,
225                },
226                finish_reason: Some("tool_calls".to_string()),
227            });
228        }
229    }
230
231    if !chat_request.legacy_functions.is_empty() && !api_function_call_choice_is_none(chat_request)
232    {
233        if let Some(function_call) =
234            parse_legacy_function_call_from_generated_text(text, chat_request)
235        {
236            return Some(ApiChatResponse {
237                message: ApiChatMessage {
238                    role: ApiMessageRole::Assistant,
239                    content: String::new(),
240                    name: None,
241                    tool_calls: Vec::new(),
242                    tool_call_id: None,
243                    function_call: Some(function_call),
244                },
245                finish_reason: Some("function_call".to_string()),
246            });
247        }
248    }
249
250    None
251}
252
253fn api_tool_choice_is_none(chat_request: &ApiChatRequest) -> bool {
254    matches!(
255        chat_request.tool_choice.as_ref(),
256        Some(ApiToolChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none")
257    )
258}
259
260fn api_function_call_choice_is_none(chat_request: &ApiChatRequest) -> bool {
261    matches!(
262        chat_request.legacy_function_call.as_ref(),
263        Some(ApiFunctionCallChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none")
264    )
265}
266
267fn parse_tool_calls_from_generated_text(
268    text: &str,
269    chat_request: &ApiChatRequest,
270) -> Option<Vec<ApiToolCall>> {
271    let value = parse_json_value_from_generated_text(text)?;
272    if let Some(calls) = value.get("tool_calls").and_then(|value| value.as_array()) {
273        let parsed = calls
274            .iter()
275            .enumerate()
276            .filter_map(|(index, value)| parse_tool_call_value(value, index, chat_request))
277            .collect::<Vec<_>>();
278        return (!parsed.is_empty()).then_some(parsed);
279    }
280    if let Some(tool_call) = value.get("tool_call") {
281        return parse_tool_call_value(tool_call, 0, chat_request).map(|call| vec![call]);
282    }
283    parse_tool_call_value(&value, 0, chat_request)
284        .or_else(|| parse_forced_tool_arguments_value(&value, 0, chat_request))
285        .map(|call| vec![call])
286}
287
288fn parse_tool_call_value(
289    value: &serde_json::Value,
290    index: usize,
291    chat_request: &ApiChatRequest,
292) -> Option<ApiToolCall> {
293    let tool_type = value
294        .get("type")
295        .and_then(|value| value.as_str())
296        .unwrap_or("function");
297    if tool_type != "function" {
298        return None;
299    }
300    let function = value.get("function").unwrap_or(value);
301    let name = function.get("name").and_then(|value| value.as_str())?;
302    if !api_tool_name_allowed(chat_request, name) {
303        return None;
304    }
305    let arguments = api_arguments_to_string(function.get("arguments"));
306    let id = value
307        .get("id")
308        .and_then(|value| value.as_str())
309        .map(str::to_string)
310        .unwrap_or_else(|| format!("call_{index}"));
311
312    Some(ApiToolCall {
313        id,
314        tool_type: "function".to_string(),
315        function: ApiFunctionCall {
316            name: name.to_string(),
317            arguments,
318        },
319    })
320}
321
322fn parse_forced_tool_arguments_value(
323    value: &serde_json::Value,
324    index: usize,
325    chat_request: &ApiChatRequest,
326) -> Option<ApiToolCall> {
327    let name = forced_tool_choice_name(chat_request)?;
328    if value.get("tool_calls").is_some()
329        || value.get("tool_call").is_some()
330        || value.get("function").is_some()
331        || value.get("name").is_some()
332    {
333        return None;
334    }
335
336    Some(ApiToolCall {
337        id: format!("call_{index}"),
338        tool_type: "function".to_string(),
339        function: ApiFunctionCall {
340            name: name.to_string(),
341            arguments: serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()),
342        },
343    })
344}
345
346fn forced_tool_choice_name(chat_request: &ApiChatRequest) -> Option<&str> {
347    match chat_request.tool_choice.as_ref() {
348        Some(ApiToolChoice::Function {
349            tool_type,
350            function,
351        }) if tool_type == "function" && api_tool_name_allowed(chat_request, &function.name) => {
352            Some(function.name.as_str())
353        }
354        _ => None,
355    }
356}
357
358fn parse_legacy_function_call_from_generated_text(
359    text: &str,
360    chat_request: &ApiChatRequest,
361) -> Option<ApiFunctionCall> {
362    let value = parse_json_value_from_generated_text(text)?;
363    let function = value.get("function_call").unwrap_or(&value);
364    let name = function.get("name").and_then(|value| value.as_str())?;
365    if !api_function_name_allowed(chat_request, name) {
366        return None;
367    }
368    Some(ApiFunctionCall {
369        name: name.to_string(),
370        arguments: api_arguments_to_string(function.get("arguments")),
371    })
372}
373
374fn api_tool_name_allowed(chat_request: &ApiChatRequest, name: &str) -> bool {
375    match chat_request.tool_choice.as_ref() {
376        Some(ApiToolChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none") => false,
377        Some(ApiToolChoice::Function {
378            tool_type,
379            function,
380        }) => {
381            tool_type == "function"
382                && function.name == name
383                && chat_request
384                    .tools
385                    .iter()
386                    .any(|tool| tool.function.name == name)
387        }
388        _ => chat_request
389            .tools
390            .iter()
391            .any(|tool| tool.function.name == name),
392    }
393}
394
395fn api_function_name_allowed(chat_request: &ApiChatRequest, name: &str) -> bool {
396    match chat_request.legacy_function_call.as_ref() {
397        Some(ApiFunctionCallChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none") => false,
398        Some(ApiFunctionCallChoice::Function { name: selected }) => {
399            selected == name
400                && chat_request
401                    .legacy_functions
402                    .iter()
403                    .any(|function| function.name == name)
404        }
405        _ => chat_request
406            .legacy_functions
407            .iter()
408            .any(|function| function.name == name),
409    }
410}
411
412fn parse_json_value_from_generated_text(text: &str) -> Option<serde_json::Value> {
413    let trimmed = strip_single_json_fence(text.trim());
414    serde_json::from_str(trimmed).ok().or_else(|| {
415        let start = trimmed.find('{')?;
416        let end = trimmed.rfind('}')?;
417        (start <= end)
418            .then(|| serde_json::from_str(&trimmed[start..=end]).ok())
419            .flatten()
420    })
421}
422
/// Removes one Markdown code fence (```` ``` ```` or ```` ```json ````) that
/// wraps the whole of `text`, trimming whitespace inside it.
///
/// The input is returned unchanged when it does not start with a fence or the
/// closing fence is missing.
fn strip_single_json_fence(text: &str) -> &str {
    match text.strip_prefix("```") {
        None => text,
        Some(after_open) => {
            let body = after_open.strip_prefix("json").unwrap_or(after_open);
            match body.trim_start().strip_suffix("```") {
                Some(inner) => inner.trim(),
                None => text,
            }
        }
    }
}
430
431fn api_arguments_to_string(arguments: Option<&serde_json::Value>) -> String {
432    match arguments {
433        Some(serde_json::Value::String(raw)) => raw.clone(),
434        Some(value) => serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()),
435        None => "{}".to_string(),
436    }
437}
438
439impl InferenceRequest {
440    /// Create a new inference request
441    pub fn new(prompt: impl Into<String>, model_id: impl Into<ModelId>) -> Self {
442        Self {
443            id: RequestId::new(),
444            prompt: prompt.into(),
445            model_id: model_id.into(),
446            sampling_params: SamplingParams::default(),
447            stream: false,
448            priority: Priority::default(),
449            client_id: None,
450            session_id: None,
451            created_at: Utc::now(),
452            api_request: None,
453            metadata: HashMap::new(),
454        }
455    }
456
457    /// Set sampling parameters
458    pub fn with_sampling_params(mut self, params: SamplingParams) -> Self {
459        self.sampling_params = params;
460        self
461    }
462
463    /// Enable streaming
464    pub fn with_stream(mut self, stream: bool) -> Self {
465        self.stream = stream;
466        self
467    }
468
469    /// Set priority
470    pub fn with_priority(mut self, priority: Priority) -> Self {
471        self.priority = priority;
472        self
473    }
474
475    /// Set client ID
476    pub fn with_client_id(mut self, client_id: impl Into<ClientId>) -> Self {
477        self.client_id = Some(client_id.into());
478        self
479    }
480
481    /// Set session ID
482    pub fn with_session_id(mut self, session_id: SessionId) -> Self {
483        self.session_id = Some(session_id);
484        self
485    }
486
487    /// Set structured product/API request context.
488    pub fn with_api_request(mut self, api_request: ApiRequest) -> Self {
489        self.api_request = Some(api_request);
490        self
491    }
492
493    /// Add metadata
494    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
495        self.metadata.insert(key.into(), value);
496        self
497    }
498}
499
/// Inference response
///
/// Final, non-streaming result for one [`InferenceRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// Request ID this response corresponds to
    pub request_id: RequestId,
    /// Generated text
    pub text: String,
    /// Generated token IDs
    pub tokens: Vec<TokenId>,
    /// Reason for completion
    pub finish_reason: FinishReason,
    /// Token usage statistics
    pub usage: TokenUsage,
    /// Total latency in milliseconds
    pub latency_ms: u64,
    /// Response creation timestamp
    pub created_at: DateTime<Utc>,
    /// Additional response metadata
    pub metadata: HashMap<String, serde_json::Value>,
    /// Structured product/API response context. Engines that can produce
    /// product-native outputs, such as assistant tool calls, can populate
    /// this without overloading plain text or ad hoc metadata.
    /// Omitted from serialized output when `None`; defaults to `None` on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_response: Option<ApiResponse>,
}
525
/// Streaming response chunk
///
/// One incremental piece of a streamed response for a single request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// Request ID this chunk corresponds to
    pub request_id: RequestId,
    /// Text delta for this chunk
    pub text: String,
    /// Token ID for this chunk (if available)
    pub token: Option<TokenId>,
    /// Finish reason if this is the final chunk
    pub finish_reason: Option<FinishReason>,
    /// Token usage (typically only in final chunk)
    pub usage: Option<TokenUsage>,
    /// Chunk creation timestamp
    pub created_at: DateTime<Utc>,
    /// Chunk metadata
    pub metadata: HashMap<String, serde_json::Value>,
    /// Structured product/API response context for final streaming chunks.
    /// This mirrors `InferenceResponse::api_response` so streaming endpoints
    /// can return native tool/function-call payloads without reparsing text.
    /// Omitted from serialized output when `None`; defaults to `None` on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_response: Option<ApiResponse>,
}
549
/// Batch request for processing multiple requests together
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Batch identifier
    pub batch_id: BatchId,
    /// Requests in this batch
    pub requests: Vec<InferenceRequest>,
    /// Maximum sequence length for this batch
    ///
    /// NOTE(review): `BatchRequest::new` sets this to the largest
    /// `sampling_params.max_tokens` in the batch (512 when empty). That is the
    /// generation budget only and excludes prompt length, so confirm this
    /// matches how the field is consumed.
    pub max_sequence_length: usize,
    /// Batch creation timestamp
    pub created_at: DateTime<Utc>,
}
562
563impl BatchRequest {
564    /// Create a new batch request
565    pub fn new(requests: Vec<InferenceRequest>) -> Self {
566        let max_sequence_length = requests
567            .iter()
568            .map(|r| r.sampling_params.max_tokens)
569            .max()
570            .unwrap_or(512);
571
572        Self {
573            batch_id: BatchId::new(),
574            requests,
575            max_sequence_length,
576            created_at: Utc::now(),
577        }
578    }
579
580    /// Get the number of requests in this batch
581    pub fn size(&self) -> usize {
582        self.requests.len()
583    }
584
585    /// Check if batch is empty
586    pub fn is_empty(&self) -> bool {
587        self.requests.is_empty()
588    }
589}
590
/// Request state in the scheduler
///
/// This type does not enforce transitions; `ScheduledRequest::set_state`
/// assigns any value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RequestState {
    /// Request is waiting in queue
    Waiting,
    /// Request is being processed
    Running,
    /// Request was preempted and is waiting to resume
    Preempted,
    /// Request completed successfully
    Completed,
    /// Request failed with error
    Failed,
    /// Request was cancelled
    Cancelled,
}
607
/// Scheduled request with additional state information
///
/// Scheduler-side wrapper around an [`InferenceRequest`]. Unlike the request
/// itself, it is not serializable.
#[derive(Debug, Clone)]
pub struct ScheduledRequest {
    /// The original request
    pub request: InferenceRequest,
    /// Current state in scheduler
    pub state: RequestState,
    /// Allocated cache blocks (appended to by `add_blocks`, never deduplicated)
    pub allocated_blocks: Vec<crate::BlockId>,
    /// Number of tokens processed so far (an absolute count; `update_progress`
    /// overwrites it rather than adding to it)
    pub tokens_processed: usize,
    /// Estimated completion time
    pub estimated_completion: Option<DateTime<Utc>>,
}
622
623impl ScheduledRequest {
624    /// Create a new scheduled request
625    pub fn new(request: InferenceRequest) -> Self {
626        Self {
627            request,
628            state: RequestState::Waiting,
629            allocated_blocks: Vec::new(),
630            tokens_processed: 0,
631            estimated_completion: None,
632        }
633    }
634
635    /// Update request state
636    pub fn set_state(&mut self, state: RequestState) {
637        self.state = state;
638    }
639
640    /// Add allocated cache blocks
641    pub fn add_blocks(&mut self, blocks: Vec<crate::BlockId>) {
642        self.allocated_blocks.extend(blocks);
643    }
644
645    /// Update tokens processed
646    pub fn update_progress(&mut self, tokens_processed: usize) {
647        self.tokens_processed = tokens_processed;
648    }
649}