use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use validator::Validate;
use super::{
common::{
default_model, default_true, validate_stop, ChatLogProbs, ContentPart, Function,
FunctionCall, FunctionChoice, GenerationRequest, ResponseFormat, StreamOptions,
StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, ToolReference,
Usage,
},
sampling_params::{validate_top_k_value, validate_top_p_value},
};
use crate::{
builders::{ChatCompletionResponseBuilder, ChatCompletionStreamResponseBuilder},
validated::Normalizable,
};
/// One message of a chat conversation.
///
/// Serialized as an internally tagged JSON object: the `"role"` field selects the variant
/// (`"system"`, `"user"`, `"assistant"`, `"tool"`, `"function"` or `"developer"`) and the
/// remaining keys are that variant's fields. Optional fields marked `skip_serializing_if`
/// are omitted from the output when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "role")]
pub enum ChatMessage {
    /// `"system"` message.
    #[serde(rename = "system")]
    System {
        content: MessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `"user"` message. `validate_messages` rejects empty text and empty part lists here.
    #[serde(rename = "user")]
    User {
        content: MessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `"assistant"` message. Every field is optional, so a message may, for example, carry
    /// only `tool_calls`.
    #[serde(rename = "assistant")]
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<MessageContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<ToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    /// `"tool"` message.
    /// NOTE(review): `tool_call_id` presumably identifies the tool call being answered — confirm.
    #[serde(rename = "tool")]
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
    /// `"function"` message. Unlike the other roles, `content` is a plain string and `name`
    /// is required.
    #[serde(rename = "function")]
    Function { content: String, name: String },
    /// `"developer"` message; may carry its own `tools` list.
    #[serde(rename = "developer")]
    Developer {
        content: MessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        tools: Option<Vec<Tool>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
/// Content of a chat message: either a plain string or a list of content parts.
///
/// Because the enum is `#[serde(untagged)]`, a JSON string deserializes to
/// [`MessageContent::Text`] and a JSON array to [`MessageContent::Parts`]. Serialization
/// writes the inner value directly, with no variant wrapper.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain-string content.
    Text(String),
    /// Multi-part content; only `ContentPart::Text` parts contribute text in the helpers below.
    Parts(Vec<ContentPart>),
}
impl MessageContent {
pub fn to_simple_string(&self) -> String {
match self {
MessageContent::Text(text) => text.clone(),
MessageContent::Parts(parts) => {
let mut result = String::new();
let mut first = true;
for part in parts {
if let ContentPart::Text { text } = part {
if !first {
result.push(' ');
}
result.push_str(text);
first = false;
}
}
result
}
}
}
#[inline]
pub fn append_text_to(&self, buffer: &mut String) -> bool {
match self {
MessageContent::Text(text) => {
if !text.is_empty() {
buffer.push_str(text);
true
} else {
false
}
}
MessageContent::Parts(parts) => {
let mut appended = false;
for part in parts {
if let ContentPart::Text { text } = part {
if !text.is_empty() {
if appended {
buffer.push(' ');
}
buffer.push_str(text);
appended = true;
}
}
}
appended
}
}
}
#[inline]
pub fn has_text(&self) -> bool {
match self {
MessageContent::Text(text) => !text.is_empty(),
MessageContent::Parts(parts) => parts
.iter()
.any(|part| matches!(part, ContentPart::Text { text } if !text.is_empty())),
}
}
}
/// Request body for the chat completions endpoint.
///
/// Single-field constraints are declared with the `#[validate(...)]` attributes below.
/// Rules involving several fields live in `validate_chat_cross_parameters`. The deprecated
/// fields (`function_call`, `functions`, `max_tokens`) are still accepted; `normalize`
/// rewrites them into their replacements.
///
/// NOTE(review): the derived `Default` yields `false` for `skip_special_tokens`,
/// `separate_reasoning` and `stream_reasoning`, and an empty `model`. Deserialization instead
/// fills those fields from `default_true` / `default_model` when they are missing from the
/// JSON. A request built in code via `Default::default()` can therefore behave differently
/// from a deserialized one — confirm this is intended.
#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)]
#[validate(schema(function = "validate_chat_cross_parameters"))]
pub struct ChatCompletionRequest {
    /// Conversation history; checked by `validate_messages` (must be non-empty).
    #[validate(custom(function = "validate_messages"))]
    pub messages: Vec<ChatMessage>,
    /// Target model; `default_model()` supplies the value when the field is absent.
    #[serde(default = "default_model")]
    pub model: String,
    /// Must lie within `[-2.0, 2.0]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = -2.0, max = 2.0))]
    pub frequency_penalty: Option<f32>,
    /// Deprecated form of `tool_choice`; `normalize` converts it when `tool_choice` is unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[deprecated(note = "Use tool_choice instead")]
    pub function_call: Option<FunctionCall>,
    /// Deprecated form of `tools`; `normalize` converts it when `tools` is unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[deprecated(note = "Use tools instead")]
    pub functions: Option<Vec<Function>>,
    /// NOTE(review): keys are presumably token ids encoded as strings — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f32>>,
    /// `false` when absent; must be `true` whenever `top_logprobs` is set.
    #[serde(default)]
    pub logprobs: bool,
    /// Deprecated alias for `max_completion_tokens` (>= 1); `normalize` moves it there.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[deprecated(note = "Use max_completion_tokens instead")]
    #[validate(range(min = 1))]
    pub max_tokens: Option<u32>,
    /// Must be >= 1; `min_tokens` may not exceed it.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 1))]
    pub max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,
    /// Must lie within `1..=10`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 1, max = 10))]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Must lie within `[-2.0, 2.0]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = -2.0, max = 2.0))]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
    /// The JSON variants may not be combined with `regex` or `ebnf`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_identifier: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[deprecated(note = "This feature is in Legacy mode")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
    /// Checked by `validate_stop`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(custom(function = "validate_stop"))]
    pub stop: Option<StringOrArray>,
    /// `false` when absent.
    #[serde(default)]
    pub stream: bool,
    /// Only allowed when `stream` is `true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// Must lie within `[0.0, 2.0]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 0.0, max = 2.0))]
    pub temperature: Option<f32>,
    /// Must be consistent with `tools`. `normalize` fills in a default when `tools` is set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Must lie within `0..=20`; requires `logprobs`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 0, max = 20))]
    pub top_logprobs: Option<u32>,
    /// Checked by `validate_top_p_value`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(custom(function = "validate_top_p_value"))]
    pub top_p: Option<f32>,
    /// NOTE(review): typed as an integer — confirm this matches what clients send.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub verbosity: Option<i32>,
    /// Checked by `validate_top_k_value`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(custom(function = "validate_top_k_value"))]
    pub top_k: Option<i32>,
    /// Must lie within `[0.0, 1.0]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 0.0, max = 1.0))]
    pub min_p: Option<f32>,
    /// Must be >= 1 and must not exceed the maximum output-token limit.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 1))]
    pub min_tokens: Option<u32>,
    /// Must lie within `[0.0, 2.0]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 0.0, max = 2.0))]
    pub repetition_penalty: Option<f32>,
    /// Structured-output constraint; exclusive with `ebnf` and the JSON response formats.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,
    /// Structured-output constraint; exclusive with `regex` and the JSON response formats.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ebnf: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_token_ids: Option<Vec<u32>>,
    /// `false` when absent.
    #[serde(default)]
    pub no_stop_trim: bool,
    /// `false` when absent.
    #[serde(default)]
    pub ignore_eos: bool,
    /// `false` when absent.
    #[serde(default)]
    pub continue_final_message: bool,
    /// Filled from `default_true()` when absent.
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_path: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_params: Option<HashMap<String, Value>>,
    /// Filled from `default_true()` when absent.
    #[serde(default = "default_true")]
    pub separate_reasoning: bool,
    /// Filled from `default_true()` when absent.
    #[serde(default = "default_true")]
    pub stream_reasoning: bool,
    /// NOTE(review): presumably extra keyword arguments for chat-template rendering — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chat_template_kwargs: Option<HashMap<String, Value>>,
    /// `false` when absent.
    #[serde(default)]
    pub return_hidden_states: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sampling_seed: Option<u64>,
}
/// Validates the `messages` array of a chat completion request.
///
/// The array must be non-empty. Every `user` message must carry content: either a
/// non-empty string or a non-empty list of parts. Messages with other roles are not
/// inspected here. The first offending message determines the error code.
fn validate_messages(messages: &[ChatMessage]) -> Result<(), validator::ValidationError> {
    if messages.is_empty() {
        return Err(validator::ValidationError::new("messages cannot be empty"));
    }

    let first_problem = messages.iter().find_map(|message| match message {
        ChatMessage::User {
            content: MessageContent::Text(text),
            ..
        } if text.is_empty() => Some("message content cannot be empty"),
        ChatMessage::User {
            content: MessageContent::Parts(parts),
            ..
        } if parts.is_empty() => Some("message content parts cannot be empty"),
        _ => None,
    });

    match first_problem {
        Some(code) => Err(validator::ValidationError::new(code)),
        None => Ok(()),
    }
}
fn validate_chat_cross_parameters(
req: &ChatCompletionRequest,
) -> Result<(), validator::ValidationError> {
if req.top_logprobs.is_some() && !req.logprobs {
let mut e = validator::ValidationError::new("top_logprobs_requires_logprobs");
e.message = Some("top_logprobs is only allowed when logprobs is enabled".into());
return Err(e);
}
if req.stream_options.is_some() && !req.stream {
let mut e = validator::ValidationError::new("stream_options_requires_stream");
e.message =
Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into());
return Err(e);
}
if let (Some(min), Some(max)) = (req.min_tokens, req.max_completion_tokens) {
if min > max {
let mut e = validator::ValidationError::new("min_tokens_exceeds_max");
e.message = Some("min_tokens cannot exceed max_tokens/max_completion_tokens".into());
return Err(e);
}
}
let has_json_format = matches!(
req.response_format,
Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
);
if has_json_format && req.regex.is_some() {
let mut e = validator::ValidationError::new("regex_conflicts_with_json");
e.message = Some("cannot use regex constraint with JSON response format".into());
return Err(e);
}
if has_json_format && req.ebnf.is_some() {
let mut e = validator::ValidationError::new("ebnf_conflicts_with_json");
e.message = Some("cannot use EBNF constraint with JSON response format".into());
return Err(e);
}
let constraint_count = [
req.regex.is_some(),
req.ebnf.is_some(),
matches!(req.response_format, Some(ResponseFormat::JsonSchema { .. })),
]
.iter()
.filter(|&&x| x)
.count();
if constraint_count > 1 {
let mut e = validator::ValidationError::new("multiple_constraints");
e.message = Some("only one structured output constraint (regex, ebnf, or json_schema) can be active at a time".into());
return Err(e);
}
if let Some(ResponseFormat::JsonSchema { json_schema }) = &req.response_format {
if json_schema.name.is_empty() {
let mut e = validator::ValidationError::new("json_schema_name_empty");
e.message = Some("JSON schema name cannot be empty".into());
return Err(e);
}
}
if let Some(ref tool_choice) = req.tool_choice {
let has_tools = req.tools.as_ref().is_some_and(|t| !t.is_empty());
let is_some_choice = !matches!(tool_choice, ToolChoice::Value(ToolChoiceValue::None));
if is_some_choice && !has_tools {
let mut e = validator::ValidationError::new("tool_choice_requires_tools");
e.message = Some("Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.".into());
return Err(e);
}
if has_tools {
let tools = req.tools.as_ref().unwrap();
match tool_choice {
ToolChoice::Function { function, .. } => {
let function_exists = tools.iter().any(|tool| {
tool.tool_type == "function" && tool.function.name == function.name
});
if !function_exists {
let mut e =
validator::ValidationError::new("tool_choice_function_not_found");
e.message = Some(
format!(
"Invalid value for 'tool_choice': function '{}' not found in 'tools'.",
function.name
)
.into(),
);
return Err(e);
}
}
ToolChoice::AllowedTools {
mode,
tools: allowed_tools,
..
} => {
if mode != "auto" && mode != "required" {
let mut e = validator::ValidationError::new("tool_choice_invalid_mode");
e.message = Some(format!(
"Invalid value for 'tool_choice.mode': must be 'auto' or 'required', got '{}'.",
mode
).into());
return Err(e);
}
for tool_ref in allowed_tools {
match tool_ref {
ToolReference::Function { name } => {
let tool_exists = tools.iter().any(|tool| {
tool.tool_type == "function" && tool.function.name == *name
});
if !tool_exists {
let mut e = validator::ValidationError::new(
"tool_choice_tool_not_found",
);
e.message = Some(
format!(
"Invalid value for 'tool_choice.tools': tool '{}' not found in 'tools'.",
name
)
.into(),
);
return Err(e);
}
}
_ => {
let mut e = validator::ValidationError::new(
"tool_choice_invalid_tool_type",
);
e.message = Some(
format!(
"Invalid value for 'tool_choice.tools': Chat Completion API only supports function tools, got '{}'.",
tool_ref.identifier()
)
.into(),
);
return Err(e);
}
}
}
}
_ => {}
}
}
}
Ok(())
}
impl Normalizable for ChatCompletionRequest {
    /// Rewrites deprecated request fields into their current equivalents and fills in a
    /// default `tool_choice`.
    ///
    /// - `max_tokens` moves to `max_completion_tokens`, but only when the latter is unset.
    /// - `functions` becomes `tools`, each function wrapped as a `"function"` tool, but only
    ///   when `tools` is unset.
    /// - `function_call` becomes `tool_choice`, but only when `tool_choice` is unset.
    /// - If `tool_choice` is still unset and `tools` is present, it becomes `auto` for a
    ///   non-empty tool list and `none` for an empty one.
    ///
    /// Each migrated deprecated field is cleared. A deprecated field is left untouched when
    /// its replacement was already set.
    // Reading and rewriting the deprecated fields is the purpose of this method.
    #[allow(deprecated)]
    fn normalize(&mut self) {
        if self.max_completion_tokens.is_none() {
            // `take` moves the value across and clears the deprecated field in one step.
            self.max_completion_tokens = self.max_tokens.take();
        }

        if self.tools.is_none() {
            if let Some(functions) = self.functions.take() {
                tracing::warn!("functions is deprecated, use tools instead");
                // The deprecated list is cleared here anyway, so move each function into
                // its tool instead of cloning it.
                self.tools = Some(
                    functions
                        .into_iter()
                        .map(|function| Tool {
                            tool_type: "function".to_string(),
                            function,
                        })
                        .collect(),
                );
            }
        }

        if self.tool_choice.is_none() {
            if let Some(function_call) = self.function_call.take() {
                tracing::warn!("function_call is deprecated, use tool_choice instead");
                self.tool_choice = Some(match function_call {
                    FunctionCall::None => ToolChoice::Value(ToolChoiceValue::None),
                    FunctionCall::Auto => ToolChoice::Value(ToolChoiceValue::Auto),
                    FunctionCall::Function { name } => ToolChoice::Function {
                        tool_type: "function".to_string(),
                        function: FunctionChoice { name },
                    },
                });
            }
        }

        if self.tool_choice.is_none() {
            if let Some(tools) = &self.tools {
                let default_choice = if tools.is_empty() {
                    ToolChoiceValue::None
                } else {
                    ToolChoiceValue::Auto
                };
                self.tool_choice = Some(ToolChoice::Value(default_choice));
            }
        }
    }
}
impl GenerationRequest for ChatCompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
let mut buffer = String::new();
let mut has_content = false;
for msg in &self.messages {
match msg {
ChatMessage::System { content, .. }
| ChatMessage::User { content, .. }
| ChatMessage::Tool { content, .. }
| ChatMessage::Developer { content, .. } => {
if has_content && content.has_text() {
buffer.push(' ');
}
if content.append_text_to(&mut buffer) {
has_content = true;
}
}
ChatMessage::Assistant {
content,
reasoning_content,
..
} => {
if let Some(c) = content {
if has_content && c.has_text() {
buffer.push(' ');
}
if c.append_text_to(&mut buffer) {
has_content = true;
}
}
if let Some(reasoning) = reasoning_content {
if !reasoning.is_empty() {
if has_content {
buffer.push(' ');
}
buffer.push_str(reasoning);
has_content = true;
}
}
}
ChatMessage::Function { content, .. } => {
if !content.is_empty() {
if has_content {
buffer.push(' ');
}
buffer.push_str(content);
has_content = true;
}
}
}
}
buffer
}
}
/// Non-streaming chat completion response.
///
/// `usage` and `system_fingerprint` are left out of the serialized output when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    /// NOTE(review): presumably the object-type tag (e.g. `"chat.completion"`) — confirm.
    pub object: String,
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm.
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}

impl ChatCompletionResponse {
    /// Returns a [`ChatCompletionResponseBuilder`] seeded with the given `id` and `model`.
    pub fn builder(
        id: impl Into<String>,
        model: impl Into<String>,
    ) -> ChatCompletionResponseBuilder {
        ChatCompletionResponseBuilder::new(id, model)
    }
}
/// The message inside a non-streaming [`ChatChoice`].
///
/// `content` and `tool_calls` are omitted when `None`. `reasoning_content` has no
/// `skip_serializing_if`, so it is always serialized, as `null` when absent.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionMessage {
    pub role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    pub reasoning_content: Option<String>,
}
/// One choice in a non-streaming [`ChatCompletionResponse`].
///
/// `finish_reason` is always serialized (as `null` when `None`). `logprobs`,
/// `matched_stop` and `hidden_states` are omitted when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatCompletionMessage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    pub finish_reason: Option<String>,
    /// NOTE(review): presumably the stop string or token that ended generation — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hidden_states: Option<Vec<f32>>,
}
/// One chunk of a streaming chat completion response.
///
/// `system_fingerprint` and `usage` are left out of the serialized output when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    /// NOTE(review): presumably the object-type tag (e.g. `"chat.completion.chunk"`) — confirm.
    pub object: String,
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm.
    pub created: u64,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    pub choices: Vec<ChatStreamChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}

impl ChatCompletionStreamResponse {
    /// Returns a [`ChatCompletionStreamResponseBuilder`] seeded with the given `id` and
    /// `model`.
    pub fn builder(
        id: impl Into<String>,
        model: impl Into<String>,
    ) -> ChatCompletionStreamResponseBuilder {
        ChatCompletionStreamResponseBuilder::new(id, model)
    }
}
/// Incremental message data carried by one [`ChatStreamChoice`].
///
/// `role`, `content` and `tool_calls` are omitted when `None`. `reasoning_content` has no
/// `skip_serializing_if`, so every delta serializes it, as `null` when absent.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
    pub reasoning_content: Option<String>,
}
/// One choice within a streaming [`ChatCompletionStreamResponse`] chunk.
///
/// `logprobs` and `finish_reason` are always serialized (as `null` when `None`);
/// `matched_stop` is omitted when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatStreamChoice {
    pub index: u32,
    pub delta: ChatMessageDelta,
    pub logprobs: Option<ChatLogProbs>,
    pub finish_reason: Option<String>,
    /// NOTE(review): presumably the stop string or token that ended generation — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>,
}