// Source file: ferrum_types/requests.rs

1//! Request and response types for inference
2
3use crate::{ids::*, models::TokenUsage, FinishReason, Priority, SamplingParams, TokenId};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Metadata map key `"ferrum_prompt_tokens"`.
///
/// NOTE(review): presumably used as a key in `InferenceRequest::metadata` to
/// carry prompt-token information; no use site is visible in this file, so
/// confirm against the callers.
pub const PROMPT_TOKENS_METADATA_KEY: &str = "ferrum_prompt_tokens";
9
10/// Inference request
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct InferenceRequest {
13    /// Unique request identifier
14    pub id: RequestId,
15    /// Input prompt text
16    pub prompt: String,
17    /// Model to use for inference
18    pub model_id: ModelId,
19    /// Sampling parameters
20    pub sampling_params: SamplingParams,
21    /// Whether to stream response
22    pub stream: bool,
23    /// Request priority
24    pub priority: Priority,
25    /// Client identifier
26    pub client_id: Option<ClientId>,
27    /// Session identifier for stateful interactions
28    pub session_id: Option<SessionId>,
29    /// Request creation timestamp
30    pub created_at: DateTime<Utc>,
31    /// Structured product/API request context. `prompt` remains the rendered
32    /// model input for current engines; this carries the original semantic
33    /// request boundary for API features such as tools and response formats.
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub api_request: Option<ApiRequest>,
36    /// Additional metadata
37    pub metadata: HashMap<String, serde_json::Value>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
41#[serde(tag = "kind", rename_all = "snake_case")]
42pub enum ApiRequest {
43    Chat(ApiChatRequest),
44    Completion(ApiCompletionRequest),
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
48#[serde(tag = "kind", rename_all = "snake_case")]
49pub enum ApiResponse {
50    Chat(ApiChatResponse),
51    Completion(ApiCompletionResponse),
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
55pub struct ApiChatRequest {
56    pub messages: Vec<ApiChatMessage>,
57    #[serde(default, skip_serializing_if = "Vec::is_empty")]
58    pub tools: Vec<ApiTool>,
59    #[serde(default, skip_serializing_if = "Option::is_none")]
60    pub tool_choice: Option<ApiToolChoice>,
61    #[serde(default, skip_serializing_if = "Vec::is_empty")]
62    pub legacy_functions: Vec<ApiFunction>,
63    #[serde(default, skip_serializing_if = "Option::is_none")]
64    pub legacy_function_call: Option<ApiFunctionCallChoice>,
65    #[serde(default, skip_serializing_if = "Option::is_none")]
66    pub response_format: Option<ApiResponseFormat>,
67    #[serde(default, skip_serializing_if = "Option::is_none")]
68    pub stream_options: Option<ApiStreamOptions>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72pub struct ApiCompletionRequest {
73    pub prompt: String,
74    #[serde(default, skip_serializing_if = "Option::is_none")]
75    pub response_format: Option<ApiResponseFormat>,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
79pub struct ApiChatResponse {
80    pub message: ApiChatMessage,
81    #[serde(default, skip_serializing_if = "Option::is_none")]
82    pub finish_reason: Option<String>,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86pub struct ApiCompletionResponse {
87    pub text: String,
88    #[serde(default, skip_serializing_if = "Option::is_none")]
89    pub finish_reason: Option<String>,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
93pub struct ApiChatMessage {
94    pub role: ApiMessageRole,
95    pub content: String,
96    #[serde(default, skip_serializing_if = "Option::is_none")]
97    pub name: Option<String>,
98    #[serde(default, skip_serializing_if = "Vec::is_empty")]
99    pub tool_calls: Vec<ApiToolCall>,
100    #[serde(default, skip_serializing_if = "Option::is_none")]
101    pub tool_call_id: Option<String>,
102    #[serde(default, skip_serializing_if = "Option::is_none")]
103    pub function_call: Option<ApiFunctionCall>,
104}
105
106#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
107#[serde(rename_all = "lowercase")]
108pub enum ApiMessageRole {
109    System,
110    User,
111    Assistant,
112    Function,
113    Tool,
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
117pub struct ApiTool {
118    #[serde(rename = "type")]
119    pub tool_type: String,
120    pub function: ApiFunction,
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
124pub struct ApiFunction {
125    pub name: String,
126    #[serde(default, skip_serializing_if = "Option::is_none")]
127    pub description: Option<String>,
128    #[serde(default, skip_serializing_if = "Option::is_none")]
129    pub parameters: Option<serde_json::Value>,
130    #[serde(default, skip_serializing_if = "Option::is_none")]
131    pub strict: Option<bool>,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
135#[serde(untagged)]
136pub enum ApiToolChoice {
137    Mode(String),
138    Function {
139        #[serde(rename = "type")]
140        tool_type: String,
141        function: ApiToolChoiceFunction,
142    },
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
146pub struct ApiToolChoiceFunction {
147    pub name: String,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
151#[serde(untagged)]
152pub enum ApiFunctionCallChoice {
153    Mode(String),
154    Function { name: String },
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
158pub struct ApiToolCall {
159    pub id: String,
160    #[serde(rename = "type")]
161    pub tool_type: String,
162    pub function: ApiFunctionCall,
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
166pub struct ApiFunctionCall {
167    pub name: String,
168    pub arguments: String,
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
172pub struct ApiResponseFormat {
173    #[serde(rename = "type")]
174    pub format_type: String,
175    #[serde(default, skip_serializing_if = "Option::is_none")]
176    pub json_schema: Option<ApiJsonSchema>,
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
180pub struct ApiJsonSchema {
181    #[serde(default, skip_serializing_if = "Option::is_none")]
182    pub name: Option<String>,
183    pub schema: serde_json::Value,
184    #[serde(default, skip_serializing_if = "Option::is_none")]
185    pub strict: Option<bool>,
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
189pub struct ApiStreamOptions {
190    #[serde(default, skip_serializing_if = "Option::is_none")]
191    pub include_usage: Option<bool>,
192}
193
194pub fn api_response_from_generated_text(
195    request: &InferenceRequest,
196    text: &str,
197) -> Option<ApiResponse> {
198    let ApiRequest::Chat(chat_request) = request.api_request.as_ref()? else {
199        return None;
200    };
201    chat_api_response_from_generated_text(chat_request, text).map(ApiResponse::Chat)
202}
203
204pub fn chat_api_may_emit_tool_or_function_call(chat_request: &ApiChatRequest) -> bool {
205    (!chat_request.tools.is_empty() && !api_tool_choice_is_none(chat_request))
206        || (!chat_request.legacy_functions.is_empty()
207            && !api_function_call_choice_is_none(chat_request))
208}
209
210pub fn chat_api_response_from_generated_text(
211    chat_request: &ApiChatRequest,
212    text: &str,
213) -> Option<ApiChatResponse> {
214    if !chat_request.tools.is_empty() && !api_tool_choice_is_none(chat_request) {
215        if let Some(tool_calls) = parse_tool_calls_from_generated_text(text, chat_request) {
216            return Some(ApiChatResponse {
217                message: ApiChatMessage {
218                    role: ApiMessageRole::Assistant,
219                    content: String::new(),
220                    name: None,
221                    tool_calls,
222                    tool_call_id: None,
223                    function_call: None,
224                },
225                finish_reason: Some("tool_calls".to_string()),
226            });
227        }
228    }
229
230    if !chat_request.legacy_functions.is_empty() && !api_function_call_choice_is_none(chat_request)
231    {
232        if let Some(function_call) =
233            parse_legacy_function_call_from_generated_text(text, chat_request)
234        {
235            return Some(ApiChatResponse {
236                message: ApiChatMessage {
237                    role: ApiMessageRole::Assistant,
238                    content: String::new(),
239                    name: None,
240                    tool_calls: Vec::new(),
241                    tool_call_id: None,
242                    function_call: Some(function_call),
243                },
244                finish_reason: Some("function_call".to_string()),
245            });
246        }
247    }
248
249    None
250}
251
252fn api_tool_choice_is_none(chat_request: &ApiChatRequest) -> bool {
253    matches!(
254        chat_request.tool_choice.as_ref(),
255        Some(ApiToolChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none")
256    )
257}
258
259fn api_function_call_choice_is_none(chat_request: &ApiChatRequest) -> bool {
260    matches!(
261        chat_request.legacy_function_call.as_ref(),
262        Some(ApiFunctionCallChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none")
263    )
264}
265
266fn parse_tool_calls_from_generated_text(
267    text: &str,
268    chat_request: &ApiChatRequest,
269) -> Option<Vec<ApiToolCall>> {
270    let value = parse_json_value_from_generated_text(text)?;
271    if let Some(calls) = value.get("tool_calls").and_then(|value| value.as_array()) {
272        let parsed = calls
273            .iter()
274            .enumerate()
275            .filter_map(|(index, value)| parse_tool_call_value(value, index, chat_request))
276            .collect::<Vec<_>>();
277        return (!parsed.is_empty()).then_some(parsed);
278    }
279    if let Some(tool_call) = value.get("tool_call") {
280        return parse_tool_call_value(tool_call, 0, chat_request).map(|call| vec![call]);
281    }
282    parse_tool_call_value(&value, 0, chat_request).map(|call| vec![call])
283}
284
285fn parse_tool_call_value(
286    value: &serde_json::Value,
287    index: usize,
288    chat_request: &ApiChatRequest,
289) -> Option<ApiToolCall> {
290    let tool_type = value
291        .get("type")
292        .and_then(|value| value.as_str())
293        .unwrap_or("function");
294    if tool_type != "function" {
295        return None;
296    }
297    let function = value.get("function").unwrap_or(value);
298    let name = function.get("name").and_then(|value| value.as_str())?;
299    if !api_tool_name_allowed(chat_request, name) {
300        return None;
301    }
302    let arguments = api_arguments_to_string(function.get("arguments"));
303    let id = value
304        .get("id")
305        .and_then(|value| value.as_str())
306        .map(str::to_string)
307        .unwrap_or_else(|| format!("call_{index}"));
308
309    Some(ApiToolCall {
310        id,
311        tool_type: "function".to_string(),
312        function: ApiFunctionCall {
313            name: name.to_string(),
314            arguments,
315        },
316    })
317}
318
319fn parse_legacy_function_call_from_generated_text(
320    text: &str,
321    chat_request: &ApiChatRequest,
322) -> Option<ApiFunctionCall> {
323    let value = parse_json_value_from_generated_text(text)?;
324    let function = value.get("function_call").unwrap_or(&value);
325    let name = function.get("name").and_then(|value| value.as_str())?;
326    if !api_function_name_allowed(chat_request, name) {
327        return None;
328    }
329    Some(ApiFunctionCall {
330        name: name.to_string(),
331        arguments: api_arguments_to_string(function.get("arguments")),
332    })
333}
334
335fn api_tool_name_allowed(chat_request: &ApiChatRequest, name: &str) -> bool {
336    match chat_request.tool_choice.as_ref() {
337        Some(ApiToolChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none") => false,
338        Some(ApiToolChoice::Function {
339            tool_type,
340            function,
341        }) => {
342            tool_type == "function"
343                && function.name == name
344                && chat_request
345                    .tools
346                    .iter()
347                    .any(|tool| tool.function.name == name)
348        }
349        _ => chat_request
350            .tools
351            .iter()
352            .any(|tool| tool.function.name == name),
353    }
354}
355
356fn api_function_name_allowed(chat_request: &ApiChatRequest, name: &str) -> bool {
357    match chat_request.legacy_function_call.as_ref() {
358        Some(ApiFunctionCallChoice::Mode(mode)) if mode.eq_ignore_ascii_case("none") => false,
359        Some(ApiFunctionCallChoice::Function { name: selected }) => {
360            selected == name
361                && chat_request
362                    .legacy_functions
363                    .iter()
364                    .any(|function| function.name == name)
365        }
366        _ => chat_request
367            .legacy_functions
368            .iter()
369            .any(|function| function.name == name),
370    }
371}
372
373fn parse_json_value_from_generated_text(text: &str) -> Option<serde_json::Value> {
374    let trimmed = strip_single_json_fence(text.trim());
375    serde_json::from_str(trimmed).ok().or_else(|| {
376        let start = trimmed.find('{')?;
377        let end = trimmed.rfind('}')?;
378        (start <= end)
379            .then(|| serde_json::from_str(&trimmed[start..=end]).ok())
380            .flatten()
381    })
382}
383
/// Remove one surrounding Markdown code fence from `text`.
///
/// When `text` starts with a triple-backtick fence (optionally followed by a
/// `json` language tag in any ASCII case) and ends with a closing fence, the
/// trimmed content between the fences is returned. In every other case,
/// including an opening fence without a closing one, `text` is returned
/// unchanged. The caller is expected to pass already-trimmed text.
fn strip_single_json_fence(text: &str) -> &str {
    let Some(after_open) = text.strip_prefix("```") else {
        return text;
    };

    // `str::get` yields `None` for short input or a non-boundary index, so
    // probing for the 4-byte tag can never panic on multi-byte text.
    let after_tag = match after_open.get(..4) {
        Some(tag) if tag.eq_ignore_ascii_case("json") => &after_open[4..],
        _ => after_open,
    };

    match after_tag.trim_start().strip_suffix("```") {
        Some(inner) => inner.trim(),
        // No closing fence: leave the input untouched.
        None => text,
    }
}
391
392fn api_arguments_to_string(arguments: Option<&serde_json::Value>) -> String {
393    match arguments {
394        Some(serde_json::Value::String(raw)) => raw.clone(),
395        Some(value) => serde_json::to_string(value).unwrap_or_else(|_| "{}".to_string()),
396        None => "{}".to_string(),
397    }
398}
399
400impl InferenceRequest {
401    /// Create a new inference request
402    pub fn new(prompt: impl Into<String>, model_id: impl Into<ModelId>) -> Self {
403        Self {
404            id: RequestId::new(),
405            prompt: prompt.into(),
406            model_id: model_id.into(),
407            sampling_params: SamplingParams::default(),
408            stream: false,
409            priority: Priority::default(),
410            client_id: None,
411            session_id: None,
412            created_at: Utc::now(),
413            api_request: None,
414            metadata: HashMap::new(),
415        }
416    }
417
418    /// Set sampling parameters
419    pub fn with_sampling_params(mut self, params: SamplingParams) -> Self {
420        self.sampling_params = params;
421        self
422    }
423
424    /// Enable streaming
425    pub fn with_stream(mut self, stream: bool) -> Self {
426        self.stream = stream;
427        self
428    }
429
430    /// Set priority
431    pub fn with_priority(mut self, priority: Priority) -> Self {
432        self.priority = priority;
433        self
434    }
435
436    /// Set client ID
437    pub fn with_client_id(mut self, client_id: impl Into<ClientId>) -> Self {
438        self.client_id = Some(client_id.into());
439        self
440    }
441
442    /// Set session ID
443    pub fn with_session_id(mut self, session_id: SessionId) -> Self {
444        self.session_id = Some(session_id);
445        self
446    }
447
448    /// Set structured product/API request context.
449    pub fn with_api_request(mut self, api_request: ApiRequest) -> Self {
450        self.api_request = Some(api_request);
451        self
452    }
453
454    /// Add metadata
455    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
456        self.metadata.insert(key.into(), value);
457        self
458    }
459}
460
461/// Inference response
462#[derive(Debug, Clone, Serialize, Deserialize)]
463pub struct InferenceResponse {
464    /// Request ID this response corresponds to
465    pub request_id: RequestId,
466    /// Generated text
467    pub text: String,
468    /// Generated token IDs
469    pub tokens: Vec<TokenId>,
470    /// Reason for completion
471    pub finish_reason: FinishReason,
472    /// Token usage statistics
473    pub usage: TokenUsage,
474    /// Total latency in milliseconds
475    pub latency_ms: u64,
476    /// Response creation timestamp
477    pub created_at: DateTime<Utc>,
478    /// Additional response metadata
479    pub metadata: HashMap<String, serde_json::Value>,
480    /// Structured product/API response context. Engines that can produce
481    /// product-native outputs, such as assistant tool calls, can populate
482    /// this without overloading plain text or ad hoc metadata.
483    #[serde(default, skip_serializing_if = "Option::is_none")]
484    pub api_response: Option<ApiResponse>,
485}
486
487/// Streaming response chunk
488#[derive(Debug, Clone, Serialize, Deserialize)]
489pub struct StreamChunk {
490    /// Request ID this chunk corresponds to
491    pub request_id: RequestId,
492    /// Text delta for this chunk
493    pub text: String,
494    /// Token ID for this chunk (if available)
495    pub token: Option<TokenId>,
496    /// Finish reason if this is the final chunk
497    pub finish_reason: Option<FinishReason>,
498    /// Token usage (typically only in final chunk)
499    pub usage: Option<TokenUsage>,
500    /// Chunk creation timestamp
501    pub created_at: DateTime<Utc>,
502    /// Chunk metadata
503    pub metadata: HashMap<String, serde_json::Value>,
504    /// Structured product/API response context for final streaming chunks.
505    /// This mirrors `InferenceResponse::api_response` so streaming endpoints
506    /// can return native tool/function-call payloads without reparsing text.
507    #[serde(default, skip_serializing_if = "Option::is_none")]
508    pub api_response: Option<ApiResponse>,
509}
510
511/// Batch request for processing multiple requests together
512#[derive(Debug, Clone, Serialize, Deserialize)]
513pub struct BatchRequest {
514    /// Batch identifier
515    pub batch_id: BatchId,
516    /// Requests in this batch
517    pub requests: Vec<InferenceRequest>,
518    /// Maximum sequence length for this batch
519    pub max_sequence_length: usize,
520    /// Batch creation timestamp
521    pub created_at: DateTime<Utc>,
522}
523
524impl BatchRequest {
525    /// Create a new batch request
526    pub fn new(requests: Vec<InferenceRequest>) -> Self {
527        let max_sequence_length = requests
528            .iter()
529            .map(|r| r.sampling_params.max_tokens)
530            .max()
531            .unwrap_or(512);
532
533        Self {
534            batch_id: BatchId::new(),
535            requests,
536            max_sequence_length,
537            created_at: Utc::now(),
538        }
539    }
540
541    /// Get the number of requests in this batch
542    pub fn size(&self) -> usize {
543        self.requests.len()
544    }
545
546    /// Check if batch is empty
547    pub fn is_empty(&self) -> bool {
548        self.requests.is_empty()
549    }
550}
551
552/// Request state in the scheduler
553#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
554pub enum RequestState {
555    /// Request is waiting in queue
556    Waiting,
557    /// Request is being processed
558    Running,
559    /// Request was preempted and is waiting to resume
560    Preempted,
561    /// Request completed successfully
562    Completed,
563    /// Request failed with error
564    Failed,
565    /// Request was cancelled
566    Cancelled,
567}
568
569/// Scheduled request with additional state information
570#[derive(Debug, Clone)]
571pub struct ScheduledRequest {
572    /// The original request
573    pub request: InferenceRequest,
574    /// Current state in scheduler
575    pub state: RequestState,
576    /// Allocated cache blocks
577    pub allocated_blocks: Vec<crate::BlockId>,
578    /// Number of tokens processed so far
579    pub tokens_processed: usize,
580    /// Estimated completion time
581    pub estimated_completion: Option<DateTime<Utc>>,
582}
583
584impl ScheduledRequest {
585    /// Create a new scheduled request
586    pub fn new(request: InferenceRequest) -> Self {
587        Self {
588            request,
589            state: RequestState::Waiting,
590            allocated_blocks: Vec::new(),
591            tokens_processed: 0,
592            estimated_completion: None,
593        }
594    }
595
596    /// Update request state
597    pub fn set_state(&mut self, state: RequestState) {
598        self.state = state;
599    }
600
601    /// Add allocated cache blocks
602    pub fn add_blocks(&mut self, blocks: Vec<crate::BlockId>) {
603        self.allocated_blocks.extend(blocks);
604    }
605
606    /// Update tokens processed
607    pub fn update_progress(&mut self, tokens_processed: usize) {
608        self.tokens_processed = tokens_processed;
609    }
610}