1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator::Validate;
6
7use super::{
8 common::{
9 default_model, default_true, validate_stop, ChatLogProbs, ContentPart, Function,
10 FunctionCall, FunctionChoice, GenerationRequest, ResponseFormat, StreamOptions,
11 StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, ToolReference,
12 Usage,
13 },
14 sampling_params::{validate_top_k_value, validate_top_p_value},
15};
16use crate::{
17 builders::{ChatCompletionResponseBuilder, ChatCompletionStreamResponseBuilder},
18 validated::Normalizable,
19};
20
21#[serde_with::skip_serializing_none]
26#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
27#[serde(tag = "role")]
28pub enum ChatMessage {
29 #[serde(rename = "system")]
30 System {
31 content: MessageContent,
32 name: Option<String>,
33 },
34 #[serde(rename = "user")]
35 User {
36 content: MessageContent,
37 name: Option<String>,
38 },
39 #[serde(rename = "assistant")]
40 Assistant {
41 content: Option<MessageContent>,
42 name: Option<String>,
43 tool_calls: Option<Vec<ToolCall>>,
44 reasoning_content: Option<String>,
46 },
47 #[serde(rename = "tool")]
48 Tool {
49 content: MessageContent,
50 tool_call_id: String,
51 },
52 #[serde(rename = "function")]
53 Function { content: String, name: String },
54 #[serde(rename = "developer")]
55 Developer {
56 content: MessageContent,
57 tools: Option<Vec<Tool>>,
58 name: Option<String>,
59 },
60}
61
/// Body of a chat message: either one plain string or a list of content parts.
///
/// The enum is `#[serde(untagged)]`, so no discriminator is written; when
/// deserializing, serde tries `Text` first and falls back to `Parts`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)]
#[serde(untagged)]
pub enum MessageContent {
    /// Content given as a single string.
    Text(String),
    /// Content given as an ordered list of parts (text and non-text).
    Parts(Vec<ContentPart>),
}
68
69impl MessageContent {
70 pub fn to_simple_string(&self) -> String {
75 match self {
76 MessageContent::Text(text) => text.clone(),
77 MessageContent::Parts(parts) => {
78 let mut result = String::new();
80 let mut first = true;
81 for part in parts {
82 if let ContentPart::Text { text } = part {
83 if !first {
84 result.push(' ');
85 }
86 result.push_str(text);
87 first = false;
88 }
89 }
90 result
91 }
92 }
93 }
94
95 #[inline]
98 pub fn append_text_to(&self, buffer: &mut String) -> bool {
99 match self {
100 MessageContent::Text(text) => {
101 if text.is_empty() {
102 false
103 } else {
104 buffer.push_str(text);
105 true
106 }
107 }
108 MessageContent::Parts(parts) => {
109 let mut appended = false;
110 for part in parts {
111 if let ContentPart::Text { text } = part {
112 if !text.is_empty() {
113 if appended {
114 buffer.push(' ');
115 }
116 buffer.push_str(text);
117 appended = true;
118 }
119 }
120 }
121 appended
122 }
123 }
124 }
125
126 #[inline]
128 pub fn has_text(&self) -> bool {
129 match self {
130 MessageContent::Text(text) => !text.is_empty(),
131 MessageContent::Parts(parts) => parts
132 .iter()
133 .any(|part| matches!(part, ContentPart::Text { text } if !text.is_empty())),
134 }
135 }
136}
137
/// Request body for a chat completion.
///
/// `Option` fields that are `None` are omitted when serializing. Besides the
/// per-field validators below, the whole struct is checked by
/// `validate_chat_cross_parameters`.
///
/// NOTE(review): the derived `Default` sets every `bool` to `false`, including
/// the fields whose *deserialization* default comes from `default_true`, so
/// `ChatCompletionRequest::default()` and deserializing `{}`-like input can
/// disagree on those flags.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate, schemars::JsonSchema)]
#[validate(schema(function = "validate_chat_cross_parameters"))]
pub struct ChatCompletionRequest {
    /// Conversation messages; checked by `validate_messages` (must be non-empty).
    #[validate(custom(function = "validate_messages"))]
    pub messages: Vec<ChatMessage>,

    /// Model identifier; filled from `default_model` when absent from the input.
    #[serde(default = "default_model")]
    pub model: String,

    /// Frequency penalty, validated to lie in `[-2.0, 2.0]`.
    #[validate(range(min = -2.0, max = 2.0))]
    pub frequency_penalty: Option<f32>,

    /// Deprecated predecessor of `tool_choice`; `normalize` converts it.
    #[deprecated(note = "Use tool_choice instead")]
    pub function_call: Option<FunctionCall>,

    /// Deprecated predecessor of `tools`; `normalize` converts it.
    #[deprecated(note = "Use tools instead")]
    pub functions: Option<Vec<Function>>,

    /// Per-key logit bias values.
    /// NOTE(review): keys are presumably token ids encoded as strings — confirm.
    pub logit_bias: Option<HashMap<String, f32>>,

    /// Whether log probabilities are requested; defaults to `false`.
    /// Required to be `true` when `top_logprobs` is set.
    #[serde(default)]
    pub logprobs: bool,

    /// Deprecated token limit (at least 1); `normalize` moves it into
    /// `max_completion_tokens` when that field is unset.
    #[deprecated(note = "Use max_completion_tokens instead")]
    #[validate(range(min = 1))]
    pub max_tokens: Option<u32>,

    /// Completion token limit, validated to be at least 1.
    #[validate(range(min = 1))]
    pub max_completion_tokens: Option<u32>,

    /// Free-form string key/value metadata.
    pub metadata: Option<HashMap<String, String>>,

    /// Requested output modalities, as strings.
    pub modalities: Option<Vec<String>>,

    /// Number of choices, validated to lie in `[1, 10]`.
    #[validate(range(min = 1, max = 10))]
    pub n: Option<u32>,

    /// Whether parallel tool calls are enabled.
    pub parallel_tool_calls: Option<bool>,

    /// Presence penalty, validated to lie in `[-2.0, 2.0]`.
    #[validate(range(min = -2.0, max = 2.0))]
    pub presence_penalty: Option<f32>,

    /// Prompt cache key.
    pub prompt_cache_key: Option<String>,

    /// Reasoning effort setting, passed as a free-form string.
    pub reasoning_effort: Option<String>,

    /// Output format; the JSON variants conflict with `regex` and `ebnf`
    /// (see `validate_chat_cross_parameters`).
    pub response_format: Option<ResponseFormat>,

    /// Safety identifier.
    pub safety_identifier: Option<String>,

    /// Deprecated seed value.
    #[deprecated(note = "This feature is in Legacy mode")]
    pub seed: Option<i64>,

    /// Service tier, passed as a free-form string.
    pub service_tier: Option<String>,

    /// Stop sequence(s), checked by `validate_stop`.
    #[validate(custom(function = "validate_stop"))]
    pub stop: Option<StringOrArray>,

    /// Whether the response is streamed; defaults to `false`.
    #[serde(default)]
    pub stream: bool,

    /// Streaming options; only accepted when `stream` is `true`.
    pub stream_options: Option<StreamOptions>,

    /// Sampling temperature, validated to lie in `[0.0, 2.0]`.
    #[validate(range(min = 0.0, max = 2.0))]
    pub temperature: Option<f32>,

    /// Tool selection policy; any value other than "none" requires a
    /// non-empty `tools` list.
    pub tool_choice: Option<ToolChoice>,

    /// Tools available to the model.
    pub tools: Option<Vec<Tool>>,

    /// Number of top log probabilities, validated to lie in `[0, 20]`;
    /// only accepted when `logprobs` is `true`.
    #[validate(range(min = 0, max = 20))]
    pub top_logprobs: Option<u32>,

    /// Nucleus sampling parameter, checked by `validate_top_p_value`.
    #[validate(custom(function = "validate_top_p_value"))]
    pub top_p: Option<f32>,

    /// Verbosity level.
    pub verbosity: Option<i32>,

    /// Top-k sampling parameter, checked by `validate_top_k_value`.
    #[validate(custom(function = "validate_top_k_value"))]
    pub top_k: Option<i32>,

    /// Min-p sampling parameter, validated to lie in `[0.0, 1.0]`.
    #[validate(range(min = 0.0, max = 1.0))]
    pub min_p: Option<f32>,

    /// Minimum number of tokens (at least 1); must not exceed
    /// `max_completion_tokens`.
    #[validate(range(min = 1))]
    pub min_tokens: Option<u32>,

    /// Repetition penalty, validated to lie in `[0.0, 2.0]`.
    #[validate(range(min = 0.0, max = 2.0))]
    pub repetition_penalty: Option<f32>,

    /// Regex structured-output constraint; mutually exclusive with `ebnf`
    /// and with JSON response formats.
    pub regex: Option<String>,

    /// EBNF structured-output constraint; mutually exclusive with `regex`
    /// and with JSON response formats.
    pub ebnf: Option<String>,

    /// Token ids that stop generation.
    pub stop_token_ids: Option<Vec<u32>>,

    /// Defaults to `false` when absent from the input.
    #[serde(default)]
    pub no_stop_trim: bool,

    /// Defaults to `false` when absent from the input.
    #[serde(default)]
    pub ignore_eos: bool,

    /// Defaults to `false` when absent from the input.
    #[serde(default)]
    pub continue_final_message: bool,

    /// Deserialization default is supplied by `default_true`.
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,

    /// Path of a LoRA adapter.
    pub lora_path: Option<String>,

    /// Arbitrary JSON session parameters.
    pub session_params: Option<HashMap<String, Value>>,

    /// Deserialization default is supplied by `default_true`.
    #[serde(default = "default_true")]
    pub separate_reasoning: bool,

    /// Deserialization default is supplied by `default_true`.
    #[serde(default = "default_true")]
    pub stream_reasoning: bool,

    /// Arbitrary JSON keyword arguments for the chat template.
    pub chat_template_kwargs: Option<HashMap<String, Value>>,

    /// Whether hidden states are returned; defaults to `false`.
    #[serde(default)]
    pub return_hidden_states: bool,

    /// Seed used for sampling.
    pub sampling_seed: Option<u64>,
}
321
322fn validate_messages(messages: &[ChatMessage]) -> Result<(), validator::ValidationError> {
328 if messages.is_empty() {
329 return Err(validator::ValidationError::new("messages cannot be empty"));
330 }
331
332 for msg in messages {
333 if let ChatMessage::User { content, .. } = msg {
334 match content {
335 MessageContent::Text(text) if text.is_empty() => {
336 return Err(validator::ValidationError::new(
337 "message content cannot be empty",
338 ));
339 }
340 MessageContent::Parts(parts) if parts.is_empty() => {
341 return Err(validator::ValidationError::new(
342 "message content parts cannot be empty",
343 ));
344 }
345 _ => {}
346 }
347 }
348 }
349 Ok(())
350}
351
352fn validate_chat_cross_parameters(
354 req: &ChatCompletionRequest,
355) -> Result<(), validator::ValidationError> {
356 if req.top_logprobs.is_some() && !req.logprobs {
358 let mut e = validator::ValidationError::new("top_logprobs_requires_logprobs");
359 e.message = Some("top_logprobs is only allowed when logprobs is enabled".into());
360 return Err(e);
361 }
362
363 if req.stream_options.is_some() && !req.stream {
365 let mut e = validator::ValidationError::new("stream_options_requires_stream");
366 e.message =
367 Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into());
368 return Err(e);
369 }
370
371 if let (Some(min), Some(max)) = (req.min_tokens, req.max_completion_tokens) {
373 if min > max {
374 let mut e = validator::ValidationError::new("min_tokens_exceeds_max");
375 e.message = Some("min_tokens cannot exceed max_tokens/max_completion_tokens".into());
376 return Err(e);
377 }
378 }
379
380 let has_json_format = matches!(
382 req.response_format,
383 Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
384 );
385
386 if has_json_format && req.regex.is_some() {
387 let mut e = validator::ValidationError::new("regex_conflicts_with_json");
388 e.message = Some("cannot use regex constraint with JSON response format".into());
389 return Err(e);
390 }
391
392 if has_json_format && req.ebnf.is_some() {
393 let mut e = validator::ValidationError::new("ebnf_conflicts_with_json");
394 e.message = Some("cannot use EBNF constraint with JSON response format".into());
395 return Err(e);
396 }
397
398 let constraint_count = [
400 req.regex.is_some(),
401 req.ebnf.is_some(),
402 matches!(req.response_format, Some(ResponseFormat::JsonSchema { .. })),
403 ]
404 .iter()
405 .filter(|&&x| x)
406 .count();
407
408 if constraint_count > 1 {
409 let mut e = validator::ValidationError::new("multiple_constraints");
410 e.message = Some("only one structured output constraint (regex, ebnf, or json_schema) can be active at a time".into());
411 return Err(e);
412 }
413
414 if let Some(ResponseFormat::JsonSchema { json_schema }) = &req.response_format {
416 if json_schema.name.is_empty() {
417 let mut e = validator::ValidationError::new("json_schema_name_empty");
418 e.message = Some("JSON schema name cannot be empty".into());
419 return Err(e);
420 }
421 }
422
423 if let Some(ref tool_choice) = req.tool_choice {
425 let has_tools = req.tools.as_ref().is_some_and(|t| !t.is_empty());
426
427 let is_some_choice = !matches!(tool_choice, ToolChoice::Value(ToolChoiceValue::None));
429
430 if is_some_choice && !has_tools {
431 let mut e = validator::ValidationError::new("tool_choice_requires_tools");
432 e.message = Some("Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.".into());
433 return Err(e);
434 }
435
436 if let Some(tools) = req.tools.as_ref().filter(|t| !t.is_empty()) {
438 match tool_choice {
439 ToolChoice::Function { function, .. } => {
440 let function_exists = tools.iter().any(|tool| {
442 tool.tool_type == "function" && tool.function.name == function.name
443 });
444
445 if !function_exists {
446 let mut e =
447 validator::ValidationError::new("tool_choice_function_not_found");
448 e.message = Some(
449 format!(
450 "Invalid value for 'tool_choice': function '{}' not found in 'tools'.",
451 function.name
452 )
453 .into(),
454 );
455 return Err(e);
456 }
457 }
458 ToolChoice::AllowedTools {
459 mode,
460 tools: allowed_tools,
461 ..
462 } => {
463 if mode != "auto" && mode != "required" {
465 let mut e = validator::ValidationError::new("tool_choice_invalid_mode");
466 e.message = Some(format!(
467 "Invalid value for 'tool_choice.mode': must be 'auto' or 'required', got '{mode}'."
468 ).into());
469 return Err(e);
470 }
471
472 for tool_ref in allowed_tools {
474 match tool_ref {
475 ToolReference::Function { name } => {
476 let tool_exists = tools.iter().any(|tool| {
478 tool.tool_type == "function" && tool.function.name == *name
479 });
480
481 if !tool_exists {
482 let mut e = validator::ValidationError::new(
483 "tool_choice_tool_not_found",
484 );
485 e.message = Some(
486 format!(
487 "Invalid value for 'tool_choice.tools': tool '{name}' not found in 'tools'."
488 )
489 .into(),
490 );
491 return Err(e);
492 }
493 }
494 _ => {
495 let mut e = validator::ValidationError::new(
497 "tool_choice_invalid_tool_type",
498 );
499 e.message = Some(
500 format!(
501 "Invalid value for 'tool_choice.tools': Chat Completion API only supports function tools, got '{}'.",
502 tool_ref.identifier()
503 )
504 .into(),
505 );
506 return Err(e);
507 }
508 }
509 }
510 }
511 ToolChoice::Value(_) => {}
512 }
513 }
514 }
515
516 Ok(())
517}
518
519impl Normalizable for ChatCompletionRequest {
524 fn normalize(&mut self) {
529 #[expect(deprecated)]
531 if self.max_completion_tokens.is_none() && self.max_tokens.is_some() {
532 self.max_completion_tokens = self.max_tokens;
533 self.max_tokens = None; }
535
536 #[expect(deprecated)]
538 if self.tools.is_none() && self.functions.is_some() {
539 tracing::warn!("functions is deprecated, use tools instead");
540 self.tools = self.functions.as_ref().map(|functions| {
541 functions
542 .iter()
543 .map(|func| Tool {
544 tool_type: "function".to_string(),
545 function: func.clone(),
546 })
547 .collect()
548 });
549 self.functions = None; }
551
552 #[expect(deprecated)]
554 if self.tool_choice.is_none() && self.function_call.is_some() {
555 tracing::warn!("function_call is deprecated, use tool_choice instead");
556 self.tool_choice = self.function_call.as_ref().map(|fc| match fc {
557 FunctionCall::None => ToolChoice::Value(ToolChoiceValue::None),
558 FunctionCall::Auto => ToolChoice::Value(ToolChoiceValue::Auto),
559 FunctionCall::Function { name } => ToolChoice::Function {
560 tool_type: "function".to_string(),
561 function: FunctionChoice { name: name.clone() },
562 },
563 });
564 self.function_call = None; }
566
567 if self.tool_choice.is_none() {
569 if let Some(tools) = &self.tools {
570 let choice_value = if tools.is_empty() {
571 ToolChoiceValue::None
572 } else {
573 ToolChoiceValue::Auto
574 };
575 self.tool_choice = Some(ToolChoice::Value(choice_value));
576 }
577 }
579 }
580}
581
582impl GenerationRequest for ChatCompletionRequest {
587 fn is_stream(&self) -> bool {
588 self.stream
589 }
590
591 fn get_model(&self) -> Option<&str> {
592 Some(&self.model)
593 }
594
595 fn extract_text_for_routing(&self) -> String {
596 let mut buffer = String::new();
599 let mut has_content = false;
600
601 for msg in &self.messages {
602 match msg {
603 ChatMessage::System { content, .. }
604 | ChatMessage::User { content, .. }
605 | ChatMessage::Tool { content, .. }
606 | ChatMessage::Developer { content, .. } => {
607 if has_content && content.has_text() {
608 buffer.push(' ');
609 }
610 if content.append_text_to(&mut buffer) {
611 has_content = true;
612 }
613 }
614 ChatMessage::Assistant {
615 content,
616 reasoning_content,
617 ..
618 } => {
619 if let Some(c) = content {
621 if has_content && c.has_text() {
622 buffer.push(' ');
623 }
624 if c.append_text_to(&mut buffer) {
625 has_content = true;
626 }
627 }
628 if let Some(reasoning) = reasoning_content {
630 if !reasoning.is_empty() {
631 if has_content {
632 buffer.push(' ');
633 }
634 buffer.push_str(reasoning);
635 has_content = true;
636 }
637 }
638 }
639 ChatMessage::Function { content, .. } => {
640 if !content.is_empty() {
641 if has_content {
642 buffer.push(' ');
643 }
644 buffer.push_str(content);
645 has_content = true;
646 }
647 }
648 }
649 }
650
651 buffer
652 }
653}
654
655#[serde_with::skip_serializing_none]
660#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
661pub struct ChatCompletionResponse {
662 pub id: String,
663 pub object: String, pub created: u64,
665 pub model: String,
666 pub choices: Vec<ChatChoice>,
667 pub usage: Option<Usage>,
668 pub system_fingerprint: Option<String>,
669}
670
impl ChatCompletionResponse {
    /// Starts a [`ChatCompletionResponseBuilder`] for a response with the
    /// given `id` and `model`.
    ///
    /// Both arguments accept anything convertible into a `String` and are
    /// forwarded unchanged to `ChatCompletionResponseBuilder::new`.
    pub fn builder(
        id: impl Into<String>,
        model: impl Into<String>,
    ) -> ChatCompletionResponseBuilder {
        ChatCompletionResponseBuilder::new(id, model)
    }
}
680
/// Message carried by a non-streaming chat completion choice.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatCompletionMessage {
    /// Role of the message author, as a plain string.
    pub role: String,
    /// Text content; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls requested by the model; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Reasoning text. Has no `skip_serializing_if`, so `None` is serialized
    /// as `null`.
    /// NOTE(review): differs from the sibling fields — confirm this is intended.
    pub reasoning_content: Option<String>,
}
693
/// One choice of a non-streaming chat completion response.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatChoice {
    /// Position of this choice in the response.
    pub index: u32,
    /// The generated message.
    pub message: ChatCompletionMessage,
    /// Log probability information; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    /// Why generation stopped. Has no `skip_serializing_if`, so `None` is
    /// serialized as `null`.
    pub finish_reason: Option<String>,
    /// The stop condition that matched, as arbitrary JSON; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>,
    /// Hidden state values; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hidden_states: Option<Vec<f32>>,
}
708
709#[serde_with::skip_serializing_none]
710#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
711pub struct ChatCompletionStreamResponse {
712 pub id: String,
713 pub object: String, pub created: u64,
715 pub model: String,
716 pub system_fingerprint: Option<String>,
717 pub choices: Vec<ChatStreamChoice>,
718 pub usage: Option<Usage>,
719}
720
impl ChatCompletionStreamResponse {
    /// Starts a [`ChatCompletionStreamResponseBuilder`] for a chunk with the
    /// given `id` and `model`.
    ///
    /// Both arguments accept anything convertible into a `String` and are
    /// forwarded unchanged to `ChatCompletionStreamResponseBuilder::new`.
    pub fn builder(
        id: impl Into<String>,
        model: impl Into<String>,
    ) -> ChatCompletionStreamResponseBuilder {
        ChatCompletionStreamResponseBuilder::new(id, model)
    }
}
730
/// Incremental message update carried by a streaming choice.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatMessageDelta {
    /// Role of the message author; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// Text content fragment; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool call fragments; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
    /// Reasoning text fragment. Has no `skip_serializing_if`, so `None` is
    /// serialized as `null`.
    /// NOTE(review): differs from the sibling fields — confirm this is intended.
    pub reasoning_content: Option<String>,
}
742
/// One choice within a streaming chat completion chunk.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatStreamChoice {
    /// Position of this choice in the response.
    pub index: u32,
    /// Incremental message content for this chunk.
    pub delta: ChatMessageDelta,
    /// Log probability information. Has no `skip_serializing_if`, so `None`
    /// is serialized as `null`.
    pub logprobs: Option<ChatLogProbs>,
    /// Why generation stopped. Has no `skip_serializing_if`, so `None` is
    /// serialized as `null`.
    pub finish_reason: Option<String>,
    /// The stop condition that matched, as arbitrary JSON; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>,
}