1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator;
6
7use super::UNKNOWN_MODEL_ID;
8
/// Returns `UNKNOWN_MODEL_ID` (imported from the parent module) converted to
/// an owned `String`.
///
/// NOTE(review): presumably used as a `#[serde(default = "...")]` provider for
/// `model` fields on request types; no such use is visible here — confirm.
pub(crate) fn default_model() -> String {
    UNKNOWN_MODEL_ID.to_string()
}
17
/// Always returns `true`.
///
/// NOTE(review): presumably a `#[serde(default = "...")]` helper for boolean
/// fields that should default to enabled; no such use is visible here — confirm.
pub fn default_true() -> bool {
    true
}
22
/// Common interface over request types.
///
/// Implementors must be `Send + Sync`. The exact semantics of each method are
/// defined by the implementors, which are not visible in this part of the file.
pub trait GenerationRequest: Send + Sync {
    /// Reports whether the request is a streaming request.
    fn is_stream(&self) -> bool;

    /// Returns the model name carried by the request, or `None` if there is none.
    fn get_model(&self) -> Option<&str>;

    /// Returns an owned string with the request text to use for routing.
    fn extract_text_for_routing(&self) -> String;
}
40
/// A value that is either a single string or a list of strings.
///
/// `#[serde(untagged)]` means there is no wrapper tag on the wire: a bare
/// string maps to `String` and a sequence of strings maps to `Array`
/// (variants are tried in declaration order when deserializing).
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
    /// A single string value.
    String(String),
    /// A list of string values.
    Array(Vec<String>),
}
52
53impl StringOrArray {
54 pub fn len(&self) -> usize {
56 match self {
57 StringOrArray::String(_) => 1,
58 StringOrArray::Array(arr) => arr.len(),
59 }
60 }
61
62 pub fn is_empty(&self) -> bool {
64 match self {
65 StringOrArray::String(s) => s.is_empty(),
66 StringOrArray::Array(arr) => arr.is_empty(),
67 }
68 }
69
70 pub fn to_vec(&self) -> Vec<String> {
72 match self {
73 StringOrArray::String(s) => vec![s.clone()],
74 StringOrArray::Array(arr) => arr.clone(),
75 }
76 }
77
78 pub fn iter(&self) -> StringOrArrayIter<'_> {
81 StringOrArrayIter {
82 inner: self,
83 index: 0,
84 }
85 }
86
87 pub fn first(&self) -> Option<&str> {
89 match self {
90 StringOrArray::String(s) => {
91 if s.is_empty() {
92 None
93 } else {
94 Some(s)
95 }
96 }
97 StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
98 }
99 }
100}
101
/// Borrowing iterator over the strings of a [`StringOrArray`].
pub struct StringOrArrayIter<'a> {
    /// The value being iterated.
    inner: &'a StringOrArray,
    /// Position of the next item to yield (number of items already yielded).
    index: usize,
}
107
108impl<'a> Iterator for StringOrArrayIter<'a> {
109 type Item = &'a str;
110
111 fn next(&mut self) -> Option<Self::Item> {
112 match self.inner {
113 StringOrArray::String(s) => {
114 if self.index == 0 {
115 self.index = 1;
116 Some(s.as_str())
117 } else {
118 None
119 }
120 }
121 StringOrArray::Array(arr) => {
122 if self.index < arr.len() {
123 let item = &arr[self.index];
124 self.index += 1;
125 Some(item.as_str())
126 } else {
127 None
128 }
129 }
130 }
131 }
132
133 fn size_hint(&self) -> (usize, Option<usize>) {
134 let remaining = match self.inner {
135 StringOrArray::String(_) => 1 - self.index,
136 StringOrArray::Array(arr) => arr.len() - self.index,
137 };
138 (remaining, Some(remaining))
139 }
140}
141
/// Marks the iterator as exact-size; the `Iterator::size_hint` implementation
/// for this type returns identical lower and upper bounds.
impl<'a> ExactSizeIterator for StringOrArrayIter<'a> {}
143
144pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> {
147 match stop {
148 StringOrArray::String(s) => {
149 if s.is_empty() {
150 return Err(validator::ValidationError::new(
151 "stop sequences cannot be empty",
152 ));
153 }
154 }
155 StringOrArray::Array(arr) => {
156 if arr.len() > 4 {
157 return Err(validator::ValidationError::new(
158 "maximum 4 stop sequences allowed",
159 ));
160 }
161 for s in arr {
162 if s.is_empty() {
163 return Err(validator::ValidationError::new(
164 "stop sequences cannot be empty",
165 ));
166 }
167 }
168 }
169 }
170 Ok(())
171}
172
173#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
178#[serde(tag = "type")]
179pub enum ContentPart {
180 #[serde(rename = "text")]
181 Text { text: String },
182 #[serde(rename = "image_url")]
183 ImageUrl { image_url: ImageUrl },
184 #[serde(rename = "video_url")]
185 VideoUrl { video_url: VideoUrl },
186}
187
/// Image reference used by [`ContentPart::ImageUrl`].
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
    /// Optional detail setting; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
194
/// Video reference used by [`ContentPart::VideoUrl`].
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct VideoUrl {
    /// Location of the video.
    pub url: String,
}
199
200#[derive(Debug, Clone, Deserialize, Serialize)]
205#[serde(tag = "type")]
206pub enum ResponseFormat {
207 #[serde(rename = "text")]
208 Text,
209 #[serde(rename = "json_object")]
210 JsonObject,
211 #[serde(rename = "json_schema")]
212 JsonSchema { json_schema: JsonSchemaFormat },
213}
214
/// Payload of [`ResponseFormat::JsonSchema`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct JsonSchemaFormat {
    /// Name of the schema.
    pub name: String,
    /// The schema itself, kept as an arbitrary JSON value (not validated here).
    pub schema: Value,
    /// Optional strictness flag; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
222
/// Options that apply to streaming responses.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamOptions {
    /// Optional flag for including usage data; omitted from serialized output
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
232
/// Incremental (delta) piece of a tool call.
///
/// Every optional field is omitted from serialized output when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCallDelta {
    /// Index of the tool call this delta belongs to.
    pub index: u32,
    /// Tool call identifier, when present in this delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool type; appears on the wire under the key `type`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "type")]
    pub tool_type: Option<String>,
    /// Function-call portion of the delta, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCallDelta>,
}
244
/// Function-call portion of a [`ToolCallDelta`].
///
/// Both fields are optional and omitted from serialized output when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallDelta {
    /// Function name, when present in this delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Arguments text, when present in this delta.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
252
/// String-valued tool-choice modes.
///
/// With `rename_all = "snake_case"` the variants are (de)serialized as the
/// strings `"auto"`, `"required"` and `"none"`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoiceValue {
    /// Wire value `"auto"`.
    Auto,
    /// Wire value `"required"`.
    Required,
    /// Wire value `"none"`.
    None,
}
265
/// Tool-choice setting: either a plain mode string or a structured object.
///
/// `#[serde(untagged)]`: variants are tried in declaration order when
/// deserializing, and no wrapper tag is written when serializing.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// One of the string modes of [`ToolChoiceValue`].
    Value(ToolChoiceValue),
    /// Object with a `type` and a `function` entry selecting one function.
    Function {
        /// Appears on the wire under the key `type`; any string is accepted
        /// at this level.
        #[serde(rename = "type")]
        tool_type: String,
        /// The function being selected.
        function: FunctionChoice,
    },
    /// Object with a `type`, a `mode` and a list of referenced `tools`.
    AllowedTools {
        /// Appears on the wire under the key `type`; any string is accepted
        /// at this level.
        #[serde(rename = "type")]
        tool_type: String,
        /// Mode string; not validated at this level.
        mode: String,
        /// The tools being referenced.
        tools: Vec<ToolReference>,
    },
}
283
284impl Default for ToolChoice {
285 fn default() -> Self {
286 Self::Value(ToolChoiceValue::Auto)
287 }
288}
289
290impl ToolChoice {
291 pub fn serialize_to_string(tool_choice: &Option<ToolChoice>) -> String {
295 tool_choice
296 .as_ref()
297 .map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
298 .unwrap_or_else(|| "auto".to_string())
299 }
300}
301
/// Names the function selected by [`ToolChoice::Function`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionChoice {
    /// Name of the function.
    pub name: String,
}
307
308#[derive(Debug, Clone, Deserialize, Serialize)]
313#[serde(tag = "type")]
314#[serde(rename_all = "snake_case")]
315pub enum ToolReference {
316 #[serde(rename = "function")]
318 Function { name: String },
319
320 #[serde(rename = "mcp")]
322 Mcp {
323 server_label: String,
324 #[serde(skip_serializing_if = "Option::is_none")]
325 name: Option<String>,
326 },
327
328 #[serde(rename = "file_search")]
330 FileSearch,
331
332 #[serde(rename = "web_search_preview")]
334 WebSearchPreview,
335
336 #[serde(rename = "computer_use_preview")]
338 ComputerUsePreview,
339
340 #[serde(rename = "code_interpreter")]
342 CodeInterpreter,
343
344 #[serde(rename = "image_generation")]
346 ImageGeneration,
347}
348
349impl ToolReference {
350 pub fn identifier(&self) -> String {
352 match self {
353 ToolReference::Function { name } => format!("function:{}", name),
354 ToolReference::Mcp { server_label, name } => {
355 if let Some(n) = name {
356 format!("mcp:{}:{}", server_label, n)
357 } else {
358 format!("mcp:{}", server_label)
359 }
360 }
361 ToolReference::FileSearch => "file_search".to_string(),
362 ToolReference::WebSearchPreview => "web_search_preview".to_string(),
363 ToolReference::ComputerUsePreview => "computer_use_preview".to_string(),
364 ToolReference::CodeInterpreter => "code_interpreter".to_string(),
365 ToolReference::ImageGeneration => "image_generation".to_string(),
366 }
367 }
368
369 pub fn function_name(&self) -> Option<&str> {
371 match self {
372 ToolReference::Function { name } => Some(name.as_str()),
373 _ => None,
374 }
375 }
376}
377
/// A tool definition wrapping a [`Function`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Tool {
    /// Tool type; appears on the wire under the key `type`. Any string is
    /// accepted at this level.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function this tool exposes.
    pub function: Function,
}
384
/// Function definition carried by a [`Tool`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Function {
    /// Name of the function.
    pub name: String,
    /// Optional description; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter specification, kept as an arbitrary JSON value (not
    /// validated here).
    pub parameters: Value,
    /// Optional strictness flag; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
395
/// A complete tool call.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCall {
    /// Identifier of the tool call.
    pub id: String,
    /// Tool type; appears on the wire under the key `type`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function invocation (name and arguments).
    pub function: FunctionCallResponse,
}
403
404#[derive(Debug, Clone, Deserialize, Serialize)]
405#[serde(untagged)]
406pub enum FunctionCall {
407 None,
408 Auto,
409 Function { name: String },
410}
411
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallResponse {
    /// Name of the function.
    pub name: String,
    /// Arguments text. `#[serde(default)]` makes a missing key deserialize to
    /// `None`; there is no skip attribute, so `None` is serialized as `null`.
    #[serde(default)]
    pub arguments: Option<String>,
}
418
/// Token usage counters.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
    /// Number of prompt tokens.
    pub prompt_tokens: u32,
    /// Number of completion tokens.
    pub completion_tokens: u32,
    /// Total number of tokens. Stored as given; this type does not compute or
    /// check it against the other two counters.
    pub total_tokens: u32,
    /// Optional breakdown of completion tokens; omitted from serialized output
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
}
431
/// Breakdown of completion tokens, used by [`Usage`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionTokensDetails {
    /// Number of reasoning tokens, if reported. There is no skip attribute, so
    /// `None` is serialized as `null`.
    pub reasoning_tokens: Option<u32>,
}
436
/// Token usage counters with optional reasoning and prompt details.
///
/// NOTE(review): overlaps with [`Usage`] but uses a flat `reasoning_tokens`
/// field and adds `prompt_tokens_details`; which API each one serves is not
/// visible here — confirm.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UsageInfo {
    /// Number of prompt tokens.
    pub prompt_tokens: u32,
    /// Number of completion tokens.
    pub completion_tokens: u32,
    /// Total number of tokens (stored as given, not computed here).
    pub total_tokens: u32,
    /// Number of reasoning tokens; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    /// Prompt token details; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}
448
/// Prompt token details used by [`UsageInfo`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PromptTokenUsageInfo {
    /// Number of cached tokens.
    pub cached_tokens: u32,
}
453
/// Log-probability data stored as parallel vectors.
///
/// NOTE(review): the four vectors are presumably index-aligned (one entry per
/// token); nothing in this type enforces equal lengths — confirm with producers.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogProbs {
    /// Token strings.
    pub tokens: Vec<String>,
    /// Log probability per token; an entry may be `None`.
    pub token_logprobs: Vec<Option<f32>>,
    /// Per-token map from candidate token to log probability; an entry may be
    /// `None`.
    pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
    /// Text offset per token.
    pub text_offset: Vec<u32>,
}
461
/// Log-probability data, either structured or as raw JSON.
///
/// `#[serde(untagged)]`: `Detailed` is tried first and `Raw` is the fallback.
///
/// NOTE(review): because `content` is optional and unknown fields are not
/// denied, it looks like any JSON object whose `content` key is absent or
/// well-formed matches `Detailed` (dropping other keys), so `Raw` would mainly
/// capture values that fail that shape — confirm this is intended.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ChatLogProbs {
    /// Structured form with an optional `content` list.
    Detailed {
        /// Per-token entries; omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<Vec<ChatLogProbsContent>>,
    },
    /// Any other JSON value, kept as-is.
    Raw(Value),
}
471
/// Log-probability entry for one token, used by [`ChatLogProbs::Detailed`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbsContent {
    /// The token string.
    pub token: String,
    /// Log probability of the token.
    pub logprob: f32,
    /// Optional bytes for the token. There is no skip attribute, so `None` is
    /// serialized as `null`.
    pub bytes: Option<Vec<u8>>,
    /// Candidate entries for this position.
    pub top_logprobs: Vec<TopLogProb>,
}
479
/// One candidate entry in [`ChatLogProbsContent::top_logprobs`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TopLogProb {
    /// The candidate token string.
    pub token: String,
    /// Log probability of the candidate.
    pub logprob: f32,
    /// Optional bytes for the candidate. There is no skip attribute, so `None`
    /// is serialized as `null`.
    pub bytes: Option<Vec<u8>>,
}
486
/// Error envelope: an object with a single `error` field.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorResponse {
    /// The error details.
    pub error: ErrorDetail,
}
495
/// Details of an error, carried by [`ErrorResponse`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorDetail {
    /// Error message.
    pub message: String,
    /// Error type; appears on the wire under the key `type`.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Optional related parameter; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub param: Option<String>,
    /// Optional error code; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}
506
/// Input ids given either as one sequence or as a batch of sequences.
///
/// `#[serde(untagged)]`: `Single` is tried before `Batch`, so an empty array
/// deserializes as `Single`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum InputIds {
    /// One sequence of ids.
    Single(Vec<i32>),
    /// Several sequences of ids.
    Batch(Vec<Vec<i32>>),
}
517
/// LoRA path given either as one optional value or as a list of optional values.
///
/// `#[serde(untagged)]`: `Single` is tried before `Batch`, so a bare string or
/// `null` maps to `Single` and an array maps to `Batch`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LoRAPath {
    /// One optional path.
    Single(Option<String>),
    /// A list of optional paths.
    Batch(Vec<Option<String>>),
}