1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator::Validate;
6
7use super::{
8 common::{
9 default_model, default_true, validate_stop, ChatLogProbs, ContentPart, Function,
10 FunctionCall, FunctionChoice, GenerationRequest, ResponseFormat, StreamOptions,
11 StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, ToolReference,
12 Usage,
13 },
14 sampling_params::{validate_top_k_value, validate_top_p_value},
15};
16use crate::{
17 builders::{ChatCompletionResponseBuilder, ChatCompletionStreamResponseBuilder},
18 validated::Normalizable,
19};
20
/// One message of a chat conversation.
///
/// Serialized as an internally tagged enum: the `role` field selects the
/// variant (`"system"`, `"user"`, `"assistant"`, `"tool"`, `"function"` or
/// `"developer"`). Every `Option` field is left out of the serialized output
/// when it is `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "role")]
pub enum ChatMessage {
    /// Message with `role: "system"`.
    #[serde(rename = "system")]
    System {
        content: MessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Message with `role: "user"`.
    #[serde(rename = "user")]
    User {
        content: MessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Message with `role: "assistant"`; all of its fields are optional.
    #[serde(rename = "assistant")]
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<MessageContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<ToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    /// Message with `role: "tool"`; `tool_call_id` is required.
    #[serde(rename = "tool")]
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
    /// Message with `role: "function"`; unlike the other variants its content
    /// is a plain `String` rather than a [`MessageContent`].
    #[serde(rename = "function")]
    Function { content: String, name: String },
    /// Message with `role: "developer"`; may carry its own `tools` list.
    #[serde(rename = "developer")]
    Developer {
        content: MessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        tools: Option<Vec<Tool>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
68
/// Body of a chat message: either a plain string or a list of content parts.
///
/// The enum is `untagged`, so on deserialization serde tries the variants in
/// declaration order (`Text` first, then `Parts`) and on serialization the
/// inner value is written without any wrapper.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Content given as a single string.
    Text(String),
    /// Content given as a sequence of [`ContentPart`] values.
    Parts(Vec<ContentPart>),
}
75
76impl MessageContent {
77 pub fn to_simple_string(&self) -> String {
82 match self {
83 MessageContent::Text(text) => text.clone(),
84 MessageContent::Parts(parts) => {
85 let mut result = String::new();
87 let mut first = true;
88 for part in parts {
89 if let ContentPart::Text { text } = part {
90 if !first {
91 result.push(' ');
92 }
93 result.push_str(text);
94 first = false;
95 }
96 }
97 result
98 }
99 }
100 }
101
102 #[inline]
105 pub fn append_text_to(&self, buffer: &mut String) -> bool {
106 match self {
107 MessageContent::Text(text) => {
108 if !text.is_empty() {
109 buffer.push_str(text);
110 true
111 } else {
112 false
113 }
114 }
115 MessageContent::Parts(parts) => {
116 let mut appended = false;
117 for part in parts {
118 if let ContentPart::Text { text } = part {
119 if !text.is_empty() {
120 if appended {
121 buffer.push(' ');
122 }
123 buffer.push_str(text);
124 appended = true;
125 }
126 }
127 }
128 appended
129 }
130 }
131 }
132
133 #[inline]
135 pub fn has_text(&self) -> bool {
136 match self {
137 MessageContent::Text(text) => !text.is_empty(),
138 MessageContent::Parts(parts) => parts
139 .iter()
140 .any(|part| matches!(part, ContentPart::Text { text } if !text.is_empty())),
141 }
142 }
143}
144
/// Request body for a chat completion.
///
/// Serialization: every `Option` field is omitted from the output when it is
/// `None`.
///
/// Validation: the per-field `#[validate(...)]` rules below are complemented by
/// the struct-level check `validate_chat_cross_parameters`, which enforces
/// constraints that span several fields.
///
/// Note that the derived `Default` uses each field type's own default (for
/// example `false` and the empty string), which is not necessarily the value
/// produced by the serde `default = "..."` functions used during
/// deserialization.
#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)]
#[validate(schema(function = "validate_chat_cross_parameters"))]
pub struct ChatCompletionRequest {
    /// Conversation messages; checked by `validate_messages`.
    #[validate(custom(function = "validate_messages"))]
    pub messages: Vec<ChatMessage>,

    /// Model identifier; filled in by `default_model` when absent from the input.
    #[serde(default = "default_model")]
    pub model: String,

    /// Must lie within `-2.0..=2.0` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = -2.0, max = 2.0))]
    pub frequency_penalty: Option<f32>,

    /// Deprecated in favor of `tool_choice`; `normalize` converts it into
    /// `tool_choice` when that field is unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[deprecated(note = "Use tool_choice instead")]
    pub function_call: Option<FunctionCall>,

    /// Deprecated in favor of `tools`; `normalize` converts it into `tools`
    /// when that field is unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[deprecated(note = "Use tools instead")]
    pub functions: Option<Vec<Function>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f32>>,

    /// Defaults to `false` when absent. `top_logprobs` may only be set when
    /// this is `true`.
    #[serde(default)]
    pub logprobs: bool,

    /// Deprecated in favor of `max_completion_tokens`; must be at least 1 when
    /// set. `normalize` moves it into `max_completion_tokens` when that field
    /// is unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[deprecated(note = "Use max_completion_tokens instead")]
    #[validate(range(min = 1))]
    pub max_tokens: Option<u32>,

    /// Must be at least 1 when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 1))]
    pub max_completion_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, String>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,

    /// Must lie within `1..=10` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 1, max = 10))]
    pub n: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,

    /// Must lie within `-2.0..=2.0` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = -2.0, max = 2.0))]
    pub presence_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,

    /// The JSON formats (`JsonObject`, `JsonSchema`) cannot be combined with
    /// `regex` or `ebnf`, and a `JsonSchema` must have a non-empty name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub safety_identifier: Option<String>,

    /// Marked deprecated ("Legacy mode").
    #[serde(skip_serializing_if = "Option::is_none")]
    #[deprecated(note = "This feature is in Legacy mode")]
    pub seed: Option<i64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,

    /// Checked by `validate_stop` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(custom(function = "validate_stop"))]
    pub stop: Option<StringOrArray>,

    /// Defaults to `false` when absent. `stream_options` may only be set when
    /// this is `true`.
    #[serde(default)]
    pub stream: bool,

    /// Only allowed together with `stream: true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,

    /// Must lie within `0.0..=2.0` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 0.0, max = 2.0))]
    pub temperature: Option<f32>,

    /// Any choice other than the `None` value requires a non-empty `tools`
    /// list, and every function it names must be present in `tools`.
    /// `normalize` fills it in (`Auto` for a non-empty `tools` list, `None`
    /// for an empty one) when it is unset and `tools` is present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,

    /// Must lie within `0..=20` when set; only allowed together with
    /// `logprobs: true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 0, max = 20))]
    pub top_logprobs: Option<u32>,

    /// Checked by `validate_top_p_value` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(custom(function = "validate_top_p_value"))]
    pub top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub verbosity: Option<i32>,

    /// Checked by `validate_top_k_value` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(custom(function = "validate_top_k_value"))]
    pub top_k: Option<i32>,

    /// Must lie within `0.0..=1.0` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 0.0, max = 1.0))]
    pub min_p: Option<f32>,

    /// Must be at least 1 when set, and must not exceed
    /// `max_completion_tokens` when both are present.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 1))]
    pub min_tokens: Option<u32>,

    /// Must lie within `0.0..=2.0` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 0.0, max = 2.0))]
    pub repetition_penalty: Option<f32>,

    /// Structured-output constraint; cannot be combined with `ebnf` or with a
    /// JSON `response_format`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,

    /// Structured-output constraint; cannot be combined with `regex` or with a
    /// JSON `response_format`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ebnf: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_token_ids: Option<Vec<u32>>,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub no_stop_trim: bool,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub ignore_eos: bool,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub continue_final_message: bool,

    /// Filled in by `default_true` when absent from the input.
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub lora_path: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_params: Option<HashMap<String, Value>>,

    /// Filled in by `default_true` when absent from the input.
    #[serde(default = "default_true")]
    pub separate_reasoning: bool,

    /// Filled in by `default_true` when absent from the input.
    #[serde(default = "default_true")]
    pub stream_reasoning: bool,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub chat_template_kwargs: Option<HashMap<String, Value>>,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub return_hidden_states: bool,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub sampling_seed: Option<u64>,
}
363
364fn validate_messages(messages: &[ChatMessage]) -> Result<(), validator::ValidationError> {
370 if messages.is_empty() {
371 return Err(validator::ValidationError::new("messages cannot be empty"));
372 }
373
374 for msg in messages.iter() {
375 if let ChatMessage::User { content, .. } = msg {
376 match content {
377 MessageContent::Text(text) if text.is_empty() => {
378 return Err(validator::ValidationError::new(
379 "message content cannot be empty",
380 ));
381 }
382 MessageContent::Parts(parts) if parts.is_empty() => {
383 return Err(validator::ValidationError::new(
384 "message content parts cannot be empty",
385 ));
386 }
387 _ => {}
388 }
389 }
390 }
391 Ok(())
392}
393
394fn validate_chat_cross_parameters(
396 req: &ChatCompletionRequest,
397) -> Result<(), validator::ValidationError> {
398 if req.top_logprobs.is_some() && !req.logprobs {
400 let mut e = validator::ValidationError::new("top_logprobs_requires_logprobs");
401 e.message = Some("top_logprobs is only allowed when logprobs is enabled".into());
402 return Err(e);
403 }
404
405 if req.stream_options.is_some() && !req.stream {
407 let mut e = validator::ValidationError::new("stream_options_requires_stream");
408 e.message =
409 Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into());
410 return Err(e);
411 }
412
413 if let (Some(min), Some(max)) = (req.min_tokens, req.max_completion_tokens) {
415 if min > max {
416 let mut e = validator::ValidationError::new("min_tokens_exceeds_max");
417 e.message = Some("min_tokens cannot exceed max_tokens/max_completion_tokens".into());
418 return Err(e);
419 }
420 }
421
422 let has_json_format = matches!(
424 req.response_format,
425 Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
426 );
427
428 if has_json_format && req.regex.is_some() {
429 let mut e = validator::ValidationError::new("regex_conflicts_with_json");
430 e.message = Some("cannot use regex constraint with JSON response format".into());
431 return Err(e);
432 }
433
434 if has_json_format && req.ebnf.is_some() {
435 let mut e = validator::ValidationError::new("ebnf_conflicts_with_json");
436 e.message = Some("cannot use EBNF constraint with JSON response format".into());
437 return Err(e);
438 }
439
440 let constraint_count = [
442 req.regex.is_some(),
443 req.ebnf.is_some(),
444 matches!(req.response_format, Some(ResponseFormat::JsonSchema { .. })),
445 ]
446 .iter()
447 .filter(|&&x| x)
448 .count();
449
450 if constraint_count > 1 {
451 let mut e = validator::ValidationError::new("multiple_constraints");
452 e.message = Some("only one structured output constraint (regex, ebnf, or json_schema) can be active at a time".into());
453 return Err(e);
454 }
455
456 if let Some(ResponseFormat::JsonSchema { json_schema }) = &req.response_format {
458 if json_schema.name.is_empty() {
459 let mut e = validator::ValidationError::new("json_schema_name_empty");
460 e.message = Some("JSON schema name cannot be empty".into());
461 return Err(e);
462 }
463 }
464
465 if let Some(ref tool_choice) = req.tool_choice {
467 let has_tools = req.tools.as_ref().is_some_and(|t| !t.is_empty());
468
469 let is_some_choice = !matches!(tool_choice, ToolChoice::Value(ToolChoiceValue::None));
471
472 if is_some_choice && !has_tools {
473 let mut e = validator::ValidationError::new("tool_choice_requires_tools");
474 e.message = Some("Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.".into());
475 return Err(e);
476 }
477
478 if has_tools {
480 let tools = req.tools.as_ref().unwrap();
481
482 match tool_choice {
483 ToolChoice::Function { function, .. } => {
484 let function_exists = tools.iter().any(|tool| {
486 tool.tool_type == "function" && tool.function.name == function.name
487 });
488
489 if !function_exists {
490 let mut e =
491 validator::ValidationError::new("tool_choice_function_not_found");
492 e.message = Some(
493 format!(
494 "Invalid value for 'tool_choice': function '{}' not found in 'tools'.",
495 function.name
496 )
497 .into(),
498 );
499 return Err(e);
500 }
501 }
502 ToolChoice::AllowedTools {
503 mode,
504 tools: allowed_tools,
505 ..
506 } => {
507 if mode != "auto" && mode != "required" {
509 let mut e = validator::ValidationError::new("tool_choice_invalid_mode");
510 e.message = Some(format!(
511 "Invalid value for 'tool_choice.mode': must be 'auto' or 'required', got '{}'.",
512 mode
513 ).into());
514 return Err(e);
515 }
516
517 for tool_ref in allowed_tools {
519 match tool_ref {
520 ToolReference::Function { name } => {
521 let tool_exists = tools.iter().any(|tool| {
523 tool.tool_type == "function" && tool.function.name == *name
524 });
525
526 if !tool_exists {
527 let mut e = validator::ValidationError::new(
528 "tool_choice_tool_not_found",
529 );
530 e.message = Some(
531 format!(
532 "Invalid value for 'tool_choice.tools': tool '{}' not found in 'tools'.",
533 name
534 )
535 .into(),
536 );
537 return Err(e);
538 }
539 }
540 _ => {
541 let mut e = validator::ValidationError::new(
543 "tool_choice_invalid_tool_type",
544 );
545 e.message = Some(
546 format!(
547 "Invalid value for 'tool_choice.tools': Chat Completion API only supports function tools, got '{}'.",
548 tool_ref.identifier()
549 )
550 .into(),
551 );
552 return Err(e);
553 }
554 }
555 }
556 }
557 _ => {}
558 }
559 }
560 }
561
562 Ok(())
563}
564
565impl Normalizable for ChatCompletionRequest {
570 fn normalize(&mut self) {
575 #[allow(deprecated)]
577 if self.max_completion_tokens.is_none() && self.max_tokens.is_some() {
578 self.max_completion_tokens = self.max_tokens;
579 self.max_tokens = None; }
581
582 #[allow(deprecated)]
584 if self.tools.is_none() && self.functions.is_some() {
585 tracing::warn!("functions is deprecated, use tools instead");
586 self.tools = self.functions.as_ref().map(|functions| {
587 functions
588 .iter()
589 .map(|func| Tool {
590 tool_type: "function".to_string(),
591 function: func.clone(),
592 })
593 .collect()
594 });
595 self.functions = None; }
597
598 #[allow(deprecated)]
600 if self.tool_choice.is_none() && self.function_call.is_some() {
601 tracing::warn!("function_call is deprecated, use tool_choice instead");
602 self.tool_choice = self.function_call.as_ref().map(|fc| match fc {
603 FunctionCall::None => ToolChoice::Value(ToolChoiceValue::None),
604 FunctionCall::Auto => ToolChoice::Value(ToolChoiceValue::Auto),
605 FunctionCall::Function { name } => ToolChoice::Function {
606 tool_type: "function".to_string(),
607 function: FunctionChoice { name: name.clone() },
608 },
609 });
610 self.function_call = None; }
612
613 if self.tool_choice.is_none() {
615 if let Some(tools) = &self.tools {
616 let choice_value = if !tools.is_empty() {
617 ToolChoiceValue::Auto
618 } else {
619 ToolChoiceValue::None
620 };
621 self.tool_choice = Some(ToolChoice::Value(choice_value));
622 }
623 }
625 }
626}
627
628impl GenerationRequest for ChatCompletionRequest {
633 fn is_stream(&self) -> bool {
634 self.stream
635 }
636
637 fn get_model(&self) -> Option<&str> {
638 Some(&self.model)
639 }
640
641 fn extract_text_for_routing(&self) -> String {
642 let mut buffer = String::new();
645 let mut has_content = false;
646
647 for msg in &self.messages {
648 match msg {
649 ChatMessage::System { content, .. }
650 | ChatMessage::User { content, .. }
651 | ChatMessage::Tool { content, .. }
652 | ChatMessage::Developer { content, .. } => {
653 if has_content && content.has_text() {
654 buffer.push(' ');
655 }
656 if content.append_text_to(&mut buffer) {
657 has_content = true;
658 }
659 }
660 ChatMessage::Assistant {
661 content,
662 reasoning_content,
663 ..
664 } => {
665 if let Some(c) = content {
667 if has_content && c.has_text() {
668 buffer.push(' ');
669 }
670 if c.append_text_to(&mut buffer) {
671 has_content = true;
672 }
673 }
674 if let Some(reasoning) = reasoning_content {
676 if !reasoning.is_empty() {
677 if has_content {
678 buffer.push(' ');
679 }
680 buffer.push_str(reasoning);
681 has_content = true;
682 }
683 }
684 }
685 ChatMessage::Function { content, .. } => {
686 if !content.is_empty() {
687 if has_content {
688 buffer.push(' ');
689 }
690 buffer.push_str(content);
691 has_content = true;
692 }
693 }
694 }
695 }
696
697 buffer
698 }
699}
700
/// Non-streaming chat completion response.
///
/// `usage` and `system_fingerprint` are omitted from the serialized output
/// when they are `None`; all other fields are always written.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    /// Generated choices; see [`ChatChoice`].
    pub choices: Vec<ChatChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}
717
impl ChatCompletionResponse {
    /// Starts a [`ChatCompletionResponseBuilder`] for the given `id` and `model`.
    ///
    /// Both arguments are forwarded unchanged to
    /// `ChatCompletionResponseBuilder::new`.
    pub fn builder(
        id: impl Into<String>,
        model: impl Into<String>,
    ) -> ChatCompletionResponseBuilder {
        ChatCompletionResponseBuilder::new(id, model)
    }
}
727
728#[derive(Debug, Clone, Deserialize, Serialize)]
730pub struct ChatCompletionMessage {
731 pub role: String, #[serde(skip_serializing_if = "Option::is_none")]
733 pub content: Option<String>,
734 #[serde(skip_serializing_if = "Option::is_none")]
735 pub tool_calls: Option<Vec<ToolCall>>,
736 pub reasoning_content: Option<String>,
737 }
740
/// One choice of a non-streaming [`ChatCompletionResponse`].
///
/// `logprobs`, `matched_stop` and `hidden_states` are omitted from the
/// serialized output when they are `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatCompletionMessage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    /// Always serialized; written as `null` when `None`.
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hidden_states: Option<Vec<f32>>,
}
755
/// One chunk of a streaming chat completion response.
///
/// `system_fingerprint` and `usage` are omitted from the serialized output
/// when they are `None`; all other fields are always written.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Per-choice deltas; see [`ChatStreamChoice`].
    pub choices: Vec<ChatStreamChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
768
impl ChatCompletionStreamResponse {
    /// Starts a [`ChatCompletionStreamResponseBuilder`] for the given `id` and
    /// `model`.
    ///
    /// Both arguments are forwarded unchanged to
    /// `ChatCompletionStreamResponseBuilder::new`.
    pub fn builder(
        id: impl Into<String>,
        model: impl Into<String>,
    ) -> ChatCompletionStreamResponseBuilder {
        ChatCompletionStreamResponseBuilder::new(id, model)
    }
}
778
779#[derive(Debug, Clone, Deserialize, Serialize)]
781pub struct ChatMessageDelta {
782 #[serde(skip_serializing_if = "Option::is_none")]
783 pub role: Option<String>,
784 #[serde(skip_serializing_if = "Option::is_none")]
785 pub content: Option<String>,
786 #[serde(skip_serializing_if = "Option::is_none")]
787 pub tool_calls: Option<Vec<ToolCallDelta>>,
788 pub reasoning_content: Option<String>,
789}
790
/// One choice inside a [`ChatCompletionStreamResponse`] chunk.
///
/// Only `matched_stop` is omitted from the serialized output when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatStreamChoice {
    pub index: u32,
    pub delta: ChatMessageDelta,
    /// Always serialized; written as `null` when `None`.
    pub logprobs: Option<ChatLogProbs>,
    /// Always serialized; written as `null` when `None`.
    pub finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>,
}