use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use validator;
pub fn default_unknown_model() -> String {
super::UNKNOWN_MODEL_ID.to_string()
}
pub fn default_true() -> bool {
true
}
pub fn deserialize_null_as_false<'de, D>(deserializer: D) -> Result<bool, D::Error>
where
D: serde::Deserializer<'de>,
{
Option::<bool>::deserialize(deserializer).map(|opt| opt.unwrap_or(false))
}
pub trait GenerationRequest: Send + Sync {
fn is_stream(&self) -> bool;
fn get_model(&self) -> Option<&str>;
fn extract_text_for_routing(&self) -> String;
}
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum StringOrArray {
String(String),
Array(Vec<String>),
}
impl StringOrArray {
pub fn len(&self) -> usize {
match self {
StringOrArray::String(_) => 1,
StringOrArray::Array(arr) => arr.len(),
}
}
pub fn is_empty(&self) -> bool {
match self {
StringOrArray::String(s) => s.is_empty(),
StringOrArray::Array(arr) => arr.is_empty(),
}
}
pub fn to_vec(&self) -> Vec<String> {
match self {
StringOrArray::String(s) => vec![s.clone()],
StringOrArray::Array(arr) => arr.clone(),
}
}
pub fn iter(&self) -> StringOrArrayIter<'_> {
StringOrArrayIter {
inner: self,
index: 0,
}
}
pub fn first(&self) -> Option<&str> {
match self {
StringOrArray::String(s) => {
if s.is_empty() {
None
} else {
Some(s)
}
}
StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
}
}
}
pub struct StringOrArrayIter<'a> {
inner: &'a StringOrArray,
index: usize,
}
impl<'a> Iterator for StringOrArrayIter<'a> {
type Item = &'a str;
fn next(&mut self) -> Option<Self::Item> {
match self.inner {
StringOrArray::String(s) => {
if self.index == 0 {
self.index = 1;
Some(s.as_str())
} else {
None
}
}
StringOrArray::Array(arr) => {
if self.index < arr.len() {
let item = &arr[self.index];
self.index += 1;
Some(item.as_str())
} else {
None
}
}
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let remaining = match self.inner {
StringOrArray::String(_) => 1 - self.index,
StringOrArray::Array(arr) => arr.len() - self.index,
};
(remaining, Some(remaining))
}
}
impl<'a> ExactSizeIterator for StringOrArrayIter<'a> {}
pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> {
match stop {
StringOrArray::String(s) => {
if s.is_empty() {
return Err(validator::ValidationError::new(
"stop sequences cannot be empty",
));
}
}
StringOrArray::Array(arr) => {
if arr.len() > 4 {
return Err(validator::ValidationError::new(
"maximum 4 stop sequences allowed",
));
}
for s in arr {
if s.is_empty() {
return Err(validator::ValidationError::new(
"stop sequences cannot be empty",
));
}
}
}
}
Ok(())
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)]
#[serde(tag = "type")]
pub enum ContentPart {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrl },
#[serde(rename = "video_url")]
VideoUrl { video_url: VideoUrl },
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)]
pub struct ImageUrl {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>, }
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)]
pub struct VideoUrl {
pub url: String,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(tag = "type")]
pub enum ResponseFormat {
#[serde(rename = "text")]
Text,
#[serde(rename = "json_object")]
JsonObject,
#[serde(rename = "json_schema")]
JsonSchema { json_schema: JsonSchemaFormat },
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct JsonSchemaFormat {
pub name: String,
pub schema: Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub strict: Option<bool>,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct StreamOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub include_usage: Option<bool>,
}
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ToolCallDelta {
pub index: u32,
pub id: Option<String>,
#[serde(rename = "type")]
pub tool_type: Option<String>,
pub function: Option<FunctionCallDelta>,
}
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct FunctionCallDelta {
pub name: Option<String>,
pub arguments: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoiceValue {
Auto,
Required,
None,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum ToolChoice {
Value(ToolChoiceValue),
Function {
#[serde(rename = "type")]
tool_type: String, function: FunctionChoice,
},
AllowedTools {
#[serde(rename = "type")]
tool_type: String, mode: String, tools: Vec<ToolReference>,
},
}
impl Default for ToolChoice {
fn default() -> Self {
Self::Value(ToolChoiceValue::Auto)
}
}
impl ToolChoice {
pub fn serialize_to_string(tool_choice: Option<&ToolChoice>) -> String {
tool_choice
.map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
.unwrap_or_else(|| "auto".to_string())
}
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct FunctionChoice {
pub name: String,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ToolReference {
#[serde(rename = "function")]
Function { name: String },
#[serde(rename = "mcp")]
Mcp {
server_label: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
#[serde(rename = "file_search")]
FileSearch,
#[serde(rename = "web_search_preview")]
WebSearchPreview,
#[serde(rename = "computer_use_preview")]
ComputerUsePreview,
#[serde(rename = "code_interpreter")]
CodeInterpreter,
#[serde(rename = "image_generation")]
ImageGeneration,
}
impl ToolReference {
pub fn identifier(&self) -> String {
match self {
ToolReference::Function { name } => format!("function:{name}"),
ToolReference::Mcp { server_label, name } => {
if let Some(n) = name {
format!("mcp:{server_label}:{n}")
} else {
format!("mcp:{server_label}")
}
}
ToolReference::FileSearch => "file_search".to_string(),
ToolReference::WebSearchPreview => "web_search_preview".to_string(),
ToolReference::ComputerUsePreview => "computer_use_preview".to_string(),
ToolReference::CodeInterpreter => "code_interpreter".to_string(),
ToolReference::ImageGeneration => "image_generation".to_string(),
}
}
pub fn function_name(&self) -> Option<&str> {
match self {
ToolReference::Function { name } => Some(name.as_str()),
_ => None,
}
}
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String, pub function: Function,
}
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct Function {
pub name: String,
pub description: Option<String>,
pub parameters: Value, pub strict: Option<bool>,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: String, pub function: FunctionCallResponse,
}
#[derive(Debug, Clone)]
pub enum FunctionCall {
None,
Auto,
Function { name: String },
}
impl Serialize for FunctionCall {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
match self {
FunctionCall::None => serializer.serialize_str("none"),
FunctionCall::Auto => serializer.serialize_str("auto"),
FunctionCall::Function { name } => {
use serde::ser::SerializeMap;
let mut map = serializer.serialize_map(Some(1))?;
map.serialize_entry("name", name)?;
map.end()
}
}
}
}
impl<'de> Deserialize<'de> for FunctionCall {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let value = Value::deserialize(deserializer)?;
match &value {
Value::String(s) => match s.as_str() {
"none" => Ok(FunctionCall::None),
"auto" => Ok(FunctionCall::Auto),
other => Err(serde::de::Error::custom(format!(
"unknown function_call value: \"{other}\""
))),
},
Value::Object(map) => {
if let Some(Value::String(name)) = map.get("name") {
Ok(FunctionCall::Function { name: name.clone() })
} else {
Err(serde::de::Error::custom(
"function_call object must have a \"name\" string field",
))
}
}
_ => Err(serde::de::Error::custom(
"function_call must be a string or object",
)),
}
}
}
impl schemars::JsonSchema for FunctionCall {
fn schema_name() -> String {
"FunctionCall".to_string()
}
fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
use schemars::schema::*;
let string_schema = SchemaObject {
instance_type: Some(InstanceType::String.into()),
enum_values: Some(vec!["none".into(), "auto".into()]),
..Default::default()
};
let object_schema = SchemaObject {
instance_type: Some(InstanceType::Object.into()),
object: Some(Box::new(ObjectValidation {
properties: {
let mut map = schemars::Map::new();
map.insert("name".to_string(), gen.subschema_for::<String>());
map
},
required: {
let mut set = std::collections::BTreeSet::new();
set.insert("name".to_string());
set
},
..Default::default()
})),
..Default::default()
};
SchemaObject {
subschemas: Some(Box::new(SubschemaValidation {
any_of: Some(vec![string_schema.into(), object_schema.into()]),
..Default::default()
})),
..Default::default()
}
.into()
}
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct FunctionCallResponse {
pub name: String,
#[serde(default)]
pub arguments: Option<String>, }
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
pub completion_tokens_details: Option<CompletionTokensDetails>,
}
impl Usage {
pub fn from_counts(prompt_tokens: u32, completion_tokens: u32) -> Self {
Self {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
prompt_tokens_details: None,
completion_tokens_details: None,
}
}
pub fn with_cached_tokens(mut self, cached_tokens: u32) -> Self {
if cached_tokens > 0 {
self.prompt_tokens_details = Some(PromptTokenUsageInfo { cached_tokens });
}
self
}
pub fn with_reasoning_tokens(mut self, reasoning_tokens: u32) -> Self {
if reasoning_tokens > 0 {
self.completion_tokens_details = Some(CompletionTokensDetails {
reasoning_tokens: Some(reasoning_tokens),
accepted_prediction_tokens: None,
rejected_prediction_tokens: None,
});
}
self
}
}
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CompletionTokensDetails {
pub reasoning_tokens: Option<u32>,
pub accepted_prediction_tokens: Option<u32>,
pub rejected_prediction_tokens: Option<u32>,
}
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct UsageInfo {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
pub reasoning_tokens: Option<u32>,
pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct PromptTokenUsageInfo {
pub cached_tokens: u32,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct LogProbs {
pub tokens: Vec<String>,
pub token_logprobs: Vec<Option<f32>>,
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
pub text_offset: Vec<u32>,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum ChatLogProbs {
Detailed {
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<Vec<ChatLogProbsContent>>,
},
Raw(Value),
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatLogProbsContent {
pub token: String,
pub logprob: f32,
pub bytes: Option<Vec<u8>>,
pub top_logprobs: Vec<TopLogProb>,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct TopLogProb {
pub token: String,
pub logprob: f32,
pub bytes: Option<Vec<u8>>,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ErrorResponse {
pub error: ErrorDetail,
}
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ErrorDetail {
pub message: String,
#[serde(rename = "type")]
pub error_type: String,
pub param: Option<String>,
pub code: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum InputIds {
Single(Vec<i32>),
Batch(Vec<Vec<i32>>),
}
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum LoRAPath {
Single(Option<String>),
Batch(Vec<Option<String>>),
}
#[derive(Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct Redacted(pub String);
impl std::fmt::Debug for Redacted {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("[REDACTED]")
}
}
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ResponsePrompt {
pub id: String,
pub variables: Option<HashMap<String, PromptVariable>>,
pub version: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum PromptVariable {
String(String),
Typed(PromptVariableTyped),
}
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(tag = "type")]
#[expect(
clippy::enum_variant_names,
reason = "variant names match OpenAI API spec"
)]
pub enum PromptVariableTyped {
#[serde(rename = "input_text")]
ResponseInputText { text: String },
#[serde(rename = "input_image")]
ResponseInputImage {
detail: Option<Detail>,
file_id: Option<String>,
image_url: Option<String>,
},
#[serde(rename = "input_file")]
ResponseInputFile {
file_data: Option<String>,
file_id: Option<String>,
file_url: Option<String>,
filename: Option<String>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize, Default, schemars::JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum Detail {
Low,
High,
#[default]
Auto,
}
#[cfg(test)]
mod tests {
use serde::Deserialize;
use serde_json::json;
use super::*;
#[derive(Deserialize)]
struct NullableBoolTest {
#[serde(default, deserialize_with = "deserialize_null_as_false")]
field: bool,
}
#[test]
fn test_deserialize_null_as_false() {
let cases = [
(json!({"field": true}), true),
(json!({"field": false}), false),
(json!({"field": null}), false),
(json!({}), false),
];
for (input, expected) in cases {
let t: NullableBoolTest = serde_json::from_value(input).unwrap();
assert_eq!(t.field, expected);
}
}
#[test]
fn test_deserialize_null_as_false_rejects_non_bool() {
let result = serde_json::from_value::<NullableBoolTest>(json!({"field": "yes"}));
assert!(result.is_err());
}
}