1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator;
6
7use super::UNKNOWN_MODEL_ID;
8
9pub(crate) fn default_model() -> String {
15 UNKNOWN_MODEL_ID.to_string()
16}
17
/// Always returns `true`.
///
/// The zero-argument signature makes it suitable as a `#[serde(default = "...")]` helper
/// for boolean fields that should be enabled when omitted.
pub fn default_true() -> bool {
    true
}
22
/// Common interface implemented by generation-style request payloads.
///
/// The `Send + Sync` supertraits let implementors be shared and moved across threads.
pub trait GenerationRequest: Send + Sync {
    /// Returns the model identifier carried by the request, or `None` if the request has
    /// none (exact meaning is up to each implementor).
    fn get_model(&self) -> Option<&str>;

    /// Returns whether this request asks for a streaming response.
    fn is_stream(&self) -> bool;

    /// Builds an owned text representation of the request for use in routing decisions.
    fn extract_text_for_routing(&self) -> String;
}
40
41#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
47#[serde(untagged)]
48pub enum StringOrArray {
49 String(String),
50 Array(Vec<String>),
51}
52
53impl StringOrArray {
54 pub fn len(&self) -> usize {
56 match self {
57 StringOrArray::String(_) => 1,
58 StringOrArray::Array(arr) => arr.len(),
59 }
60 }
61
62 pub fn is_empty(&self) -> bool {
64 match self {
65 StringOrArray::String(s) => s.is_empty(),
66 StringOrArray::Array(arr) => arr.is_empty(),
67 }
68 }
69
70 pub fn to_vec(&self) -> Vec<String> {
72 match self {
73 StringOrArray::String(s) => vec![s.clone()],
74 StringOrArray::Array(arr) => arr.clone(),
75 }
76 }
77
78 pub fn iter(&self) -> StringOrArrayIter<'_> {
81 StringOrArrayIter {
82 inner: self,
83 index: 0,
84 }
85 }
86
87 pub fn first(&self) -> Option<&str> {
89 match self {
90 StringOrArray::String(s) => {
91 if s.is_empty() {
92 None
93 } else {
94 Some(s)
95 }
96 }
97 StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
98 }
99 }
100}
101
102pub struct StringOrArrayIter<'a> {
104 inner: &'a StringOrArray,
105 index: usize,
106}
107
108impl<'a> Iterator for StringOrArrayIter<'a> {
109 type Item = &'a str;
110
111 fn next(&mut self) -> Option<Self::Item> {
112 match self.inner {
113 StringOrArray::String(s) => {
114 if self.index == 0 {
115 self.index = 1;
116 Some(s.as_str())
117 } else {
118 None
119 }
120 }
121 StringOrArray::Array(arr) => {
122 if self.index < arr.len() {
123 let item = &arr[self.index];
124 self.index += 1;
125 Some(item.as_str())
126 } else {
127 None
128 }
129 }
130 }
131 }
132
133 fn size_hint(&self) -> (usize, Option<usize>) {
134 let remaining = match self.inner {
135 StringOrArray::String(_) => 1 - self.index,
136 StringOrArray::Array(arr) => arr.len() - self.index,
137 };
138 (remaining, Some(remaining))
139 }
140}
141
142impl<'a> ExactSizeIterator for StringOrArrayIter<'a> {}
143
144pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> {
147 match stop {
148 StringOrArray::String(s) => {
149 if s.is_empty() {
150 return Err(validator::ValidationError::new(
151 "stop sequences cannot be empty",
152 ));
153 }
154 }
155 StringOrArray::Array(arr) => {
156 if arr.len() > 4 {
157 return Err(validator::ValidationError::new(
158 "maximum 4 stop sequences allowed",
159 ));
160 }
161 for s in arr {
162 if s.is_empty() {
163 return Err(validator::ValidationError::new(
164 "stop sequences cannot be empty",
165 ));
166 }
167 }
168 }
169 }
170 Ok(())
171}
172
173#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
178#[serde(tag = "type")]
179pub enum ContentPart {
180 #[serde(rename = "text")]
181 Text { text: String },
182 #[serde(rename = "image_url")]
183 ImageUrl { image_url: ImageUrl },
184 #[serde(rename = "video_url")]
185 VideoUrl { video_url: VideoUrl },
186}
187
188#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
189pub struct ImageUrl {
190 pub url: String,
191 #[serde(skip_serializing_if = "Option::is_none")]
192 pub detail: Option<String>, }
194
195#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
196pub struct VideoUrl {
197 pub url: String,
198}
199
200#[derive(Debug, Clone, Deserialize, Serialize)]
205#[serde(tag = "type")]
206pub enum ResponseFormat {
207 #[serde(rename = "text")]
208 Text,
209 #[serde(rename = "json_object")]
210 JsonObject,
211 #[serde(rename = "json_schema")]
212 JsonSchema { json_schema: JsonSchemaFormat },
213}
214
215#[derive(Debug, Clone, Deserialize, Serialize)]
216pub struct JsonSchemaFormat {
217 pub name: String,
218 pub schema: Value,
219 #[serde(skip_serializing_if = "Option::is_none")]
220 pub strict: Option<bool>,
221}
222
223#[derive(Debug, Clone, Deserialize, Serialize)]
228pub struct StreamOptions {
229 #[serde(skip_serializing_if = "Option::is_none")]
230 pub include_usage: Option<bool>,
231}
232
233#[serde_with::skip_serializing_none]
234#[derive(Debug, Clone, Deserialize, Serialize)]
235pub struct ToolCallDelta {
236 pub index: u32,
237 pub id: Option<String>,
238 #[serde(rename = "type")]
239 pub tool_type: Option<String>,
240 pub function: Option<FunctionCallDelta>,
241}
242
243#[serde_with::skip_serializing_none]
244#[derive(Debug, Clone, Deserialize, Serialize)]
245pub struct FunctionCallDelta {
246 pub name: Option<String>,
247 pub arguments: Option<String>,
248}
249
250#[derive(Debug, Clone, Deserialize, Serialize)]
256#[serde(rename_all = "snake_case")]
257pub enum ToolChoiceValue {
258 Auto,
259 Required,
260 None,
261}
262
263#[derive(Debug, Clone, Deserialize, Serialize)]
265#[serde(untagged)]
266pub enum ToolChoice {
267 Value(ToolChoiceValue),
268 Function {
269 #[serde(rename = "type")]
270 tool_type: String, function: FunctionChoice,
272 },
273 AllowedTools {
274 #[serde(rename = "type")]
275 tool_type: String, mode: String, tools: Vec<ToolReference>,
278 },
279}
280
281impl Default for ToolChoice {
282 fn default() -> Self {
283 Self::Value(ToolChoiceValue::Auto)
284 }
285}
286
287impl ToolChoice {
288 pub fn serialize_to_string(tool_choice: Option<&ToolChoice>) -> String {
292 tool_choice
293 .map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
294 .unwrap_or_else(|| "auto".to_string())
295 }
296}
297
298#[derive(Debug, Clone, Deserialize, Serialize)]
300pub struct FunctionChoice {
301 pub name: String,
302}
303
304#[derive(Debug, Clone, Deserialize, Serialize)]
309#[serde(tag = "type")]
310#[serde(rename_all = "snake_case")]
311pub enum ToolReference {
312 #[serde(rename = "function")]
314 Function { name: String },
315
316 #[serde(rename = "mcp")]
318 Mcp {
319 server_label: String,
320 #[serde(skip_serializing_if = "Option::is_none")]
321 name: Option<String>,
322 },
323
324 #[serde(rename = "file_search")]
326 FileSearch,
327
328 #[serde(rename = "web_search_preview")]
330 WebSearchPreview,
331
332 #[serde(rename = "computer_use_preview")]
334 ComputerUsePreview,
335
336 #[serde(rename = "code_interpreter")]
338 CodeInterpreter,
339
340 #[serde(rename = "image_generation")]
342 ImageGeneration,
343}
344
345impl ToolReference {
346 pub fn identifier(&self) -> String {
348 match self {
349 ToolReference::Function { name } => format!("function:{name}"),
350 ToolReference::Mcp { server_label, name } => {
351 if let Some(n) = name {
352 format!("mcp:{server_label}:{n}")
353 } else {
354 format!("mcp:{server_label}")
355 }
356 }
357 ToolReference::FileSearch => "file_search".to_string(),
358 ToolReference::WebSearchPreview => "web_search_preview".to_string(),
359 ToolReference::ComputerUsePreview => "computer_use_preview".to_string(),
360 ToolReference::CodeInterpreter => "code_interpreter".to_string(),
361 ToolReference::ImageGeneration => "image_generation".to_string(),
362 }
363 }
364
365 pub fn function_name(&self) -> Option<&str> {
367 match self {
368 ToolReference::Function { name } => Some(name.as_str()),
369 _ => None,
370 }
371 }
372}
373
374#[derive(Debug, Clone, Deserialize, Serialize)]
375pub struct Tool {
376 #[serde(rename = "type")]
377 pub tool_type: String, pub function: Function,
379}
380
381#[serde_with::skip_serializing_none]
382#[derive(Debug, Clone, Deserialize, Serialize)]
383pub struct Function {
384 pub name: String,
385 pub description: Option<String>,
386 pub parameters: Value, pub strict: Option<bool>,
389}
390
391#[derive(Debug, Clone, Deserialize, Serialize)]
392pub struct ToolCall {
393 pub id: String,
394 #[serde(rename = "type")]
395 pub tool_type: String, pub function: FunctionCallResponse,
397}
398
/// The `function_call` setting.
///
/// It has hand-written serde impls: the string modes are `"none"` and `"auto"`, and a
/// specific function is encoded as `{"name": ...}`.
#[derive(Clone, Debug)]
pub enum FunctionCall {
    /// The `"none"` mode.
    None,
    /// The `"auto"` mode.
    Auto,
    /// A specific function, selected by name.
    Function { name: String },
}
407
408impl Serialize for FunctionCall {
409 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
410 match self {
411 FunctionCall::None => serializer.serialize_str("none"),
412 FunctionCall::Auto => serializer.serialize_str("auto"),
413 FunctionCall::Function { name } => {
414 use serde::ser::SerializeMap;
415 let mut map = serializer.serialize_map(Some(1))?;
416 map.serialize_entry("name", name)?;
417 map.end()
418 }
419 }
420 }
421}
422
423impl<'de> Deserialize<'de> for FunctionCall {
424 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
425 let value = Value::deserialize(deserializer)?;
426 match &value {
427 Value::String(s) => match s.as_str() {
428 "none" => Ok(FunctionCall::None),
429 "auto" => Ok(FunctionCall::Auto),
430 other => Err(serde::de::Error::custom(format!(
431 "unknown function_call value: \"{other}\""
432 ))),
433 },
434 Value::Object(map) => {
435 if let Some(Value::String(name)) = map.get("name") {
436 Ok(FunctionCall::Function { name: name.clone() })
437 } else {
438 Err(serde::de::Error::custom(
439 "function_call object must have a \"name\" string field",
440 ))
441 }
442 }
443 _ => Err(serde::de::Error::custom(
444 "function_call must be a string or object",
445 )),
446 }
447 }
448}
449
450#[derive(Debug, Clone, Deserialize, Serialize)]
451pub struct FunctionCallResponse {
452 pub name: String,
453 #[serde(default)]
454 pub arguments: Option<String>, }
456
457#[derive(Debug, Clone, Deserialize, Serialize)]
462pub struct Usage {
463 pub prompt_tokens: u32,
464 pub completion_tokens: u32,
465 pub total_tokens: u32,
466 #[serde(skip_serializing_if = "Option::is_none")]
467 pub completion_tokens_details: Option<CompletionTokensDetails>,
468}
469
470impl Usage {
471 pub fn from_counts(prompt_tokens: u32, completion_tokens: u32) -> Self {
473 Self {
474 prompt_tokens,
475 completion_tokens,
476 total_tokens: prompt_tokens + completion_tokens,
477 completion_tokens_details: None,
478 }
479 }
480
481 pub fn with_reasoning_tokens(mut self, reasoning_tokens: u32) -> Self {
483 if reasoning_tokens > 0 {
484 self.completion_tokens_details = Some(CompletionTokensDetails {
485 reasoning_tokens: Some(reasoning_tokens),
486 });
487 }
488 self
489 }
490}
491
492#[derive(Debug, Clone, Deserialize, Serialize)]
493pub struct CompletionTokensDetails {
494 pub reasoning_tokens: Option<u32>,
495}
496
497#[serde_with::skip_serializing_none]
499#[derive(Debug, Clone, Deserialize, Serialize)]
500pub struct UsageInfo {
501 pub prompt_tokens: u32,
502 pub completion_tokens: u32,
503 pub total_tokens: u32,
504 pub reasoning_tokens: Option<u32>,
505 pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
506}
507
508#[derive(Debug, Clone, Deserialize, Serialize)]
509pub struct PromptTokenUsageInfo {
510 pub cached_tokens: u32,
511}
512
513#[derive(Debug, Clone, Deserialize, Serialize)]
514pub struct LogProbs {
515 pub tokens: Vec<String>,
516 pub token_logprobs: Vec<Option<f32>>,
517 pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
518 pub text_offset: Vec<u32>,
519}
520
521#[derive(Debug, Clone, Deserialize, Serialize)]
522#[serde(untagged)]
523pub enum ChatLogProbs {
524 Detailed {
525 #[serde(skip_serializing_if = "Option::is_none")]
526 content: Option<Vec<ChatLogProbsContent>>,
527 },
528 Raw(Value),
529}
530
531#[derive(Debug, Clone, Deserialize, Serialize)]
532pub struct ChatLogProbsContent {
533 pub token: String,
534 pub logprob: f32,
535 pub bytes: Option<Vec<u8>>,
536 pub top_logprobs: Vec<TopLogProb>,
537}
538
539#[derive(Debug, Clone, Deserialize, Serialize)]
540pub struct TopLogProb {
541 pub token: String,
542 pub logprob: f32,
543 pub bytes: Option<Vec<u8>>,
544}
545
546#[derive(Debug, Clone, Deserialize, Serialize)]
551pub struct ErrorResponse {
552 pub error: ErrorDetail,
553}
554
555#[serde_with::skip_serializing_none]
556#[derive(Debug, Clone, Deserialize, Serialize)]
557pub struct ErrorDetail {
558 pub message: String,
559 #[serde(rename = "type")]
560 pub error_type: String,
561 pub param: Option<String>,
562 pub code: Option<String>,
563}
564
565#[derive(Debug, Clone, Deserialize, Serialize)]
570#[serde(untagged)]
571pub enum InputIds {
572 Single(Vec<i32>),
573 Batch(Vec<Vec<i32>>),
574}
575
576#[derive(Debug, Clone, Deserialize, Serialize)]
578#[serde(untagged)]
579pub enum LoRAPath {
580 Single(Option<String>),
581 Batch(Vec<Option<String>>),
582}
583
584#[derive(Clone, Serialize, Deserialize)]
588pub struct Redacted(pub String);
589
590impl std::fmt::Debug for Redacted {
591 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
592 f.write_str("[REDACTED]")
593 }
594}
595
596#[serde_with::skip_serializing_none]
602#[derive(Debug, Clone, Serialize, Deserialize)]
603pub struct ResponsePrompt {
604 pub id: String,
605 pub variables: Option<HashMap<String, PromptVariable>>,
606 pub version: Option<String>,
607}
608
609#[derive(Debug, Clone, Serialize, Deserialize)]
614#[serde(untagged)]
615pub enum PromptVariable {
616 String(String),
617 Typed(PromptVariableTyped),
618}
619
620#[serde_with::skip_serializing_none]
622#[derive(Debug, Clone, Serialize, Deserialize)]
623#[serde(tag = "type")]
624#[expect(
625 clippy::enum_variant_names,
626 reason = "variant names match OpenAI API spec"
627)]
628pub enum PromptVariableTyped {
629 #[serde(rename = "input_text")]
630 ResponseInputText { text: String },
631 #[serde(rename = "input_image")]
632 ResponseInputImage {
633 detail: Option<Detail>,
634 file_id: Option<String>,
635 image_url: Option<String>,
636 },
637 #[serde(rename = "input_file")]
638 ResponseInputFile {
639 file_data: Option<String>,
640 file_id: Option<String>,
641 file_url: Option<String>,
642 filename: Option<String>,
643 },
644}
645
646#[derive(Debug, Clone, Serialize, Deserialize, Default)]
648#[serde(rename_all = "snake_case")]
649pub enum Detail {
650 Low,
651 High,
652 #[default]
653 Auto,
654}