use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use validator;
use super::UNKNOWN_MODEL_ID;
/// Returns an owned copy of [`UNKNOWN_MODEL_ID`].
///
/// NOTE(review): presumably wired up as a serde `default` for model fields
/// that the client omitted — confirm at the use sites.
pub(crate) fn default_model() -> String {
    UNKNOWN_MODEL_ID.to_owned()
}
/// Constant `true`, suitable as a serde `default = "default_true"` provider
/// for boolean fields that should be enabled when omitted.
pub fn default_true() -> bool {
    let enabled = true;
    enabled
}
/// Minimal read-only view shared by generation request types.
///
/// Implementors must be thread-safe (`Sync + Send`), so request objects can
/// be shared across threads or tasks.
pub trait GenerationRequest: Sync + Send {
    /// `true` when the request asks for a streamed response.
    fn is_stream(&self) -> bool;

    /// The model named in the request, or `None` when it names none.
    fn get_model(&self) -> Option<&str>;

    /// Text pulled out of the request for use in routing decisions; what text
    /// is extracted is up to each implementor.
    fn extract_text_for_routing(&self) -> String;
}
/// A JSON value that is either one string or an array of strings.
///
/// Because the enum is `#[serde(untagged)]`, a JSON string deserializes to
/// [`StringOrArray::String`] and a JSON array of strings to
/// [`StringOrArray::Array`]. Serialization writes the same shape back.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum StringOrArray {
    /// A single bare string.
    String(String),
    /// A list of strings.
    Array(Vec<String>),
}
impl StringOrArray {
    /// Number of entries.
    ///
    /// A bare string always counts as one entry, even when it is empty. An
    /// array counts its elements.
    ///
    /// NOTE(review): for `String("")` this returns 1 while [`Self::is_empty`]
    /// returns `true`, so callers must not assume `is_empty() == (len() == 0)`.
    pub fn len(&self) -> usize {
        match self {
            StringOrArray::Array(items) => items.len(),
            StringOrArray::String(_) => 1,
        }
    }

    /// `true` for an empty bare string or for an array with no elements.
    ///
    /// An array containing only empty strings is *not* empty.
    pub fn is_empty(&self) -> bool {
        match self {
            StringOrArray::Array(items) => items.is_empty(),
            StringOrArray::String(text) => text.is_empty(),
        }
    }

    /// Clones every entry into a new `Vec`.
    ///
    /// A bare string (even an empty one) becomes a one-element vector.
    pub fn to_vec(&self) -> Vec<String> {
        match self {
            StringOrArray::Array(items) => items.to_vec(),
            StringOrArray::String(text) => Vec::from([text.clone()]),
        }
    }

    /// Borrowing iterator over the entries; yields exactly [`Self::len`] items.
    pub fn iter(&self) -> StringOrArrayIter<'_> {
        StringOrArrayIter {
            index: 0,
            inner: self,
        }
    }

    /// The first entry, if any.
    ///
    /// An empty bare string yields `None`. For arrays, the first element is
    /// returned as-is, even when it is the empty string.
    pub fn first(&self) -> Option<&str> {
        match self {
            StringOrArray::Array(items) => items.first().map(String::as_str),
            StringOrArray::String(text) => Some(text.as_str()).filter(|t| !t.is_empty()),
        }
    }
}
/// Borrowing iterator over the strings held by a [`StringOrArray`].
///
/// Created by [`StringOrArray::iter`]; yields `&str` items.
pub struct StringOrArrayIter<'a> {
    /// The value being walked.
    inner: &'a StringOrArray,
    /// Zero-based position of the next entry to yield.
    index: usize,
}
impl<'a> Iterator for StringOrArrayIter<'a> {
type Item = &'a str;
fn next(&mut self) -> Option<Self::Item> {
match self.inner {
StringOrArray::String(s) => {
if self.index == 0 {
self.index = 1;
Some(s.as_str())
} else {
None
}
}
StringOrArray::Array(arr) => {
if self.index < arr.len() {
let item = &arr[self.index];
self.index += 1;
Some(item.as_str())
} else {
None
}
}
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let remaining = match self.inner {
StringOrArray::String(_) => 1 - self.index,
StringOrArray::Array(arr) => arr.len() - self.index,
};
(remaining, Some(remaining))
}
}
impl<'a> ExactSizeIterator for StringOrArrayIter<'a> {}
/// Validates a `stop` parameter: at most 4 sequences, none of them empty.
///
/// # Errors
///
/// Returns a [`validator::ValidationError`] built from:
/// - `"maximum 4 stop sequences allowed"` when an array holds more than four
///   entries (checked before emptiness), or
/// - `"stop sequences cannot be empty"` when the bare string, or any array
///   entry, is empty.
///
/// NOTE(review): the text is passed to `ValidationError::new`, which the
/// validator crate treats as the error *code*; confirm that this, rather than
/// `with_message`, is intended.
pub fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> {
    const MAX_STOP_SEQUENCES: usize = 4;

    // Normalize both shapes to a slice; only the array form is length-limited.
    let sequences: &[String] = match stop {
        StringOrArray::String(single) => std::slice::from_ref(single),
        StringOrArray::Array(many) if many.len() > MAX_STOP_SEQUENCES => {
            return Err(validator::ValidationError::new(
                "maximum 4 stop sequences allowed",
            ));
        }
        StringOrArray::Array(many) => many.as_slice(),
    };

    if sequences.iter().any(|seq| seq.is_empty()) {
        return Err(validator::ValidationError::new(
            "stop sequences cannot be empty",
        ));
    }

    Ok(())
}
/// One part of a multi-part message content array.
///
/// Internally tagged by `type`, using snake_case tags: `{"type": "text", "text": ...}`,
/// `{"type": "image_url", "image_url": {...}}`, `{"type": "video_url", "video_url": {...}}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// Plain text part.
    Text { text: String },
    /// Image referenced by URL.
    ImageUrl { image_url: ImageUrl },
    /// Video referenced by URL.
    VideoUrl { video_url: VideoUrl },
}

/// Payload of [`ContentPart::ImageUrl`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageUrl {
    pub url: String,
    /// Optional detail hint; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}

/// Payload of [`ContentPart::VideoUrl`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct VideoUrl {
    pub url: String,
}
/// Requested output format, internally tagged by `type`.
///
/// The snake_case tags are `"text"`, `"json_object"` and `"json_schema"`. The
/// schema variant also carries a `json_schema` object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    Text,
    JsonObject,
    JsonSchema { json_schema: JsonSchemaFormat },
}

/// Named schema attached to [`ResponseFormat::JsonSchema`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonSchemaFormat {
    pub name: String,
    /// Arbitrary JSON value; its structure is not checked here.
    pub schema: Value,
    /// Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
/// Options that accompany a streaming request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
    /// Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
/// Incremental fragment of a tool call in a streamed response.
///
/// Every optional field is omitted from output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Wire name `type`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub tool_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCallDelta>,
}

/// Incremental fragment of a function call; absent parts are omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
/// Bare-string `tool_choice` modes, spelled `"auto"`, `"required"` and
/// `"none"` on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoiceValue {
    Auto,
    Required,
    None,
}

/// The `tool_choice` field, accepted in one of three shapes.
///
/// `#[serde(untagged)]` tries the variants top to bottom:
/// - a plain string becomes [`ToolChoice::Value`];
/// - an object with `type` and `function` becomes [`ToolChoice::Function`];
/// - an object with `type`, `mode` and `tools` becomes [`ToolChoice::AllowedTools`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// One of the bare-string modes.
    Value(ToolChoiceValue),
    /// Object form naming a single function.
    Function {
        /// Wire name `type`.
        #[serde(rename = "type")]
        tool_type: String,
        function: FunctionChoice,
    },
    /// Object form with a `mode` string and a list of tool references.
    AllowedTools {
        /// Wire name `type`.
        #[serde(rename = "type")]
        tool_type: String,
        mode: String,
        tools: Vec<ToolReference>,
    },
}
impl Default for ToolChoice {
fn default() -> Self {
Self::Value(ToolChoiceValue::Auto)
}
}
impl ToolChoice {
pub fn serialize_to_string(tool_choice: &Option<ToolChoice>) -> String {
tool_choice
.as_ref()
.map(|tc| serde_json::to_string(tc).unwrap_or_else(|_| "auto".to_string()))
.unwrap_or_else(|| "auto".to_string())
}
}
/// The function selected by [`ToolChoice::Function`], identified by name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionChoice {
    pub name: String,
}
/// A reference to one tool, internally tagged by `type`.
///
/// Tags are the snake_case variant names: `"function"`, `"mcp"`,
/// `"file_search"`, `"web_search_preview"`, `"computer_use_preview"`,
/// `"code_interpreter"`, `"image_generation"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolReference {
    /// `{"type": "function", "name": ...}`
    Function { name: String },
    /// `{"type": "mcp", "server_label": ..., "name": ...}`; `name` is omitted
    /// from output when `None`.
    Mcp {
        server_label: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    FileSearch,
    WebSearchPreview,
    ComputerUsePreview,
    CodeInterpreter,
    ImageGeneration,
}
impl ToolReference {
pub fn identifier(&self) -> String {
match self {
ToolReference::Function { name } => format!("function:{}", name),
ToolReference::Mcp { server_label, name } => {
if let Some(n) = name {
format!("mcp:{}:{}", server_label, n)
} else {
format!("mcp:{}", server_label)
}
}
ToolReference::FileSearch => "file_search".to_string(),
ToolReference::WebSearchPreview => "web_search_preview".to_string(),
ToolReference::ComputerUsePreview => "computer_use_preview".to_string(),
ToolReference::CodeInterpreter => "code_interpreter".to_string(),
ToolReference::ImageGeneration => "image_generation".to_string(),
}
}
pub fn function_name(&self) -> Option<&str> {
match self {
ToolReference::Function { name } => Some(name.as_str()),
_ => None,
}
}
}
/// A tool definition; `tool_type` is the wire field `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: Function,
}

/// Declaration of a function inside a [`Tool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    /// Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Arbitrary JSON value, not validated here.
    /// NOTE(review): presumably a JSON Schema for the arguments — confirm.
    pub parameters: Value,
    /// Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
/// A complete tool call; `tool_type` is the wire field `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: FunctionCallResponse,
}
/// Legacy `function_call` request parameter.
///
/// Wire forms: the string `"none"`, the string `"auto"`, or an object
/// `{"name": "<function>"}`. Previously the whole enum was
/// `#[serde(untagged)]`, which caused three problems:
/// - both unit variants (de)serialized as JSON `null`;
/// - `Auto` could never be produced by deserialization;
/// - the string forms were rejected.
///
/// Now `None`/`Auto` use snake_case string tags, consistent with
/// [`ToolChoiceValue`], and only the object form is untagged.
///
/// NOTE(review): variant-level `#[serde(untagged)]` requires serde >= 1.0.181.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FunctionCall {
    /// `"none"`
    None,
    /// `"auto"`
    Auto,
    /// `{"name": ...}`. This variant is untagged, so it carries no wrapper key.
    /// serde tries it only after the tagged variants above, and untagged
    /// variants must come last.
    #[serde(untagged)]
    Function { name: String },
}
/// Function name plus its arguments string, as carried by a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallResponse {
    pub name: String,
    /// `None` when missing on input; serialized as `null` when `None`.
    /// NOTE(review): presumably JSON-encoded arguments — confirm with producers.
    #[serde(default)]
    pub arguments: Option<String>,
}
/// Token accounting for a response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    /// Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
}

/// Breakdown of completion tokens.
///
/// `reasoning_tokens` has no skip attribute, so `None` serializes as `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionTokensDetails {
    pub reasoning_tokens: Option<u32>,
}
/// Token accounting with optional reasoning and prompt-cache details.
///
/// Both optional fields are omitted from output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageInfo {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}

/// Prompt-side token details.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTokenUsageInfo {
    pub cached_tokens: u32,
}
/// Log-probability data stored as parallel vectors.
///
/// NOTE(review): the four vectors are presumably index-aligned, one entry per
/// token; nothing here enforces equal lengths.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogProbs {
    pub tokens: Vec<String>,
    pub token_logprobs: Vec<Option<f32>>,
    pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
    pub text_offset: Vec<u32>,
}
/// Chat log-probabilities, either structured or passed through as raw JSON.
///
/// `#[serde(untagged)]` tries `Detailed` first; input that does not fit it
/// falls back to `Raw`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatLogProbs {
    /// Structured form; `content` is omitted from output when `None`.
    Detailed {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<Vec<ChatLogProbsContent>>,
    },
    /// Any other JSON value, kept verbatim.
    Raw(Value),
}

/// Log-probability entry for one token, with its top alternatives.
///
/// `bytes` has no skip attribute, so `None` serializes as `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatLogProbsContent {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Vec<TopLogProb>,
}

/// One alternative token in [`ChatLogProbsContent::top_logprobs`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopLogProb {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}
/// Error envelope: `{"error": {...}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
    pub error: ErrorDetail,
}

/// Error body; `error_type` is the wire field `type`.
///
/// `param` and `code` are omitted from output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorDetail {
    pub message: String,
    #[serde(rename = "type")]
    pub error_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub param: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}
/// Pre-tokenized input: one id list, or a batch of id lists.
///
/// Untagged, so `Single` is tried first. A flat array (including `[]`) is
/// `Single`; an array of arrays is `Batch`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum InputIds {
    Single(Vec<i32>),
    Batch(Vec<Vec<i32>>),
}

/// LoRA adapter path(s): one optional path, or one optional path per batch item.
///
/// Untagged. A string or `null` is `Single`; an array is `Batch`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum LoRAPath {
    Single(Option<String>),
    Batch(Vec<Option<String>>),
}