1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator::Validate;
6
7use super::{
8 common::{
9 default_model, default_true, validate_stop, ChatLogProbs, ContentPart, Function,
10 FunctionCall, FunctionChoice, GenerationRequest, ResponseFormat, StreamOptions,
11 StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, ToolReference,
12 Usage,
13 },
14 sampling_params::{validate_top_k_value, validate_top_p_value},
15};
16use crate::{
17 builders::{ChatCompletionResponseBuilder, ChatCompletionStreamResponseBuilder},
18 validated::Normalizable,
19};
20
/// One message of a chat conversation, discriminated by its `role`.
///
/// Serde's internally tagged representation is used (`tag = "role"`): the
/// `role` string selects the variant and the remaining keys map onto that
/// variant's fields. `serde_with::skip_serializing_none` is applied to the
/// whole enum (presumably so that `None` fields are left out of serialized
/// output — see the serde_with documentation).
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "role")]
pub enum ChatMessage {
    /// Message with `role: "system"`.
    #[serde(rename = "system")]
    System {
        content: MessageContent,
        name: Option<String>,
    },
    /// Message with `role: "user"`. This is the only variant whose content is
    /// checked for emptiness by `validate_messages`.
    #[serde(rename = "user")]
    User {
        content: MessageContent,
        name: Option<String>,
    },
    /// Message with `role: "assistant"`. Every field is optional, so an
    /// assistant message may carry only tool calls or only reasoning text.
    #[serde(rename = "assistant")]
    Assistant {
        content: Option<MessageContent>,
        name: Option<String>,
        tool_calls: Option<Vec<ToolCall>>,
        reasoning_content: Option<String>,
    },
    /// Message with `role: "tool"`; `tool_call_id` is required.
    #[serde(rename = "tool")]
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
    /// Message with `role: "function"`. Unlike the other variants its content
    /// is a plain `String`, not a `MessageContent`.
    #[serde(rename = "function")]
    Function { content: String, name: String },
    /// Message with `role: "developer"`; may carry its own `tools` list.
    #[serde(rename = "developer")]
    Developer {
        content: MessageContent,
        tools: Option<Vec<Tool>>,
        name: Option<String>,
    },
}
61
/// Body of a chat message: either a plain string or a list of content parts.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// a JSON string deserializes as `Text`, an array of parts as `Parts`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text content.
    Text(String),
    /// Structured content; only `ContentPart::Text` parts contribute to the
    /// text helpers implemented on this type.
    Parts(Vec<ContentPart>),
}
68
69impl MessageContent {
70 pub fn to_simple_string(&self) -> String {
75 match self {
76 MessageContent::Text(text) => text.clone(),
77 MessageContent::Parts(parts) => {
78 let mut result = String::new();
80 let mut first = true;
81 for part in parts {
82 if let ContentPart::Text { text } = part {
83 if !first {
84 result.push(' ');
85 }
86 result.push_str(text);
87 first = false;
88 }
89 }
90 result
91 }
92 }
93 }
94
95 #[inline]
98 pub fn append_text_to(&self, buffer: &mut String) -> bool {
99 match self {
100 MessageContent::Text(text) => {
101 if !text.is_empty() {
102 buffer.push_str(text);
103 true
104 } else {
105 false
106 }
107 }
108 MessageContent::Parts(parts) => {
109 let mut appended = false;
110 for part in parts {
111 if let ContentPart::Text { text } = part {
112 if !text.is_empty() {
113 if appended {
114 buffer.push(' ');
115 }
116 buffer.push_str(text);
117 appended = true;
118 }
119 }
120 }
121 appended
122 }
123 }
124 }
125
126 #[inline]
128 pub fn has_text(&self) -> bool {
129 match self {
130 MessageContent::Text(text) => !text.is_empty(),
131 MessageContent::Parts(parts) => parts
132 .iter()
133 .any(|part| matches!(part, ContentPart::Text { text } if !text.is_empty())),
134 }
135 }
136}
137
/// Request body for a chat completion.
///
/// Field-level rules are declared with `#[validate(...)]`; rules that span
/// several fields live in `validate_chat_cross_parameters`, which is wired in
/// as the schema-level validator. `Normalizable::normalize` folds the
/// deprecated fields (`max_tokens`, `functions`, `function_call`) into their
/// replacements.
///
/// NOTE(review): `Default` is derived, so `ChatCompletionRequest::default()`
/// yields `false` / an empty string for fields whose *serde* default comes
/// from `default_true` / `default_model` — presumably a different value;
/// confirm before relying on `Default` to mimic an empty JSON request.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)]
#[validate(schema(function = "validate_chat_cross_parameters"))]
pub struct ChatCompletionRequest {
    /// Conversation messages; checked by `validate_messages` (must be
    /// non-empty, and user messages must not have empty content).
    #[validate(custom(function = "validate_messages"))]
    pub messages: Vec<ChatMessage>,

    /// Model identifier; filled from `default_model()` when absent from the
    /// payload.
    #[serde(default = "default_model")]
    pub model: String,

    /// Must lie in `[-2.0, 2.0]` when present.
    #[validate(range(min = -2.0, max = 2.0))]
    pub frequency_penalty: Option<f32>,

    /// Deprecated; `normalize` converts it into `tool_choice` when that field
    /// is unset.
    #[deprecated(note = "Use tool_choice instead")]
    pub function_call: Option<FunctionCall>,

    /// Deprecated; `normalize` converts it into `tools` when that field is
    /// unset.
    #[deprecated(note = "Use tools instead")]
    pub functions: Option<Vec<Function>>,

    // Not validated or read anywhere in this part of the file.
    pub logit_bias: Option<HashMap<String, f32>>,

    /// Defaults to `false`; must be `true` for `top_logprobs` to be accepted.
    #[serde(default)]
    pub logprobs: bool,

    /// Deprecated; must be at least 1. `normalize` moves it into
    /// `max_completion_tokens` when that field is unset.
    #[deprecated(note = "Use max_completion_tokens instead")]
    #[validate(range(min = 1))]
    pub max_tokens: Option<u32>,

    /// Must be at least 1; also the upper bound checked against `min_tokens`.
    #[validate(range(min = 1))]
    pub max_completion_tokens: Option<u32>,

    // Not validated or read anywhere in this part of the file.
    pub metadata: Option<HashMap<String, String>>,

    // Not validated or read anywhere in this part of the file.
    pub modalities: Option<Vec<String>>,

    /// Must lie in `[1, 10]` when present.
    #[validate(range(min = 1, max = 10))]
    pub n: Option<u32>,

    // Not validated or read anywhere in this part of the file.
    pub parallel_tool_calls: Option<bool>,

    /// Must lie in `[-2.0, 2.0]` when present.
    #[validate(range(min = -2.0, max = 2.0))]
    pub presence_penalty: Option<f32>,

    // Not validated or read anywhere in this part of the file.
    pub prompt_cache_key: Option<String>,

    // Not validated or read anywhere in this part of the file.
    pub reasoning_effort: Option<String>,

    /// Output format; JSON formats are mutually exclusive with `regex` and
    /// `ebnf`, and a JSON schema must have a non-empty name.
    pub response_format: Option<ResponseFormat>,

    // Not validated or read anywhere in this part of the file.
    pub safety_identifier: Option<String>,

    /// Deprecated ("Legacy mode" per the attribute note).
    #[deprecated(note = "This feature is in Legacy mode")]
    pub seed: Option<i64>,

    // Not validated or read anywhere in this part of the file.
    pub service_tier: Option<String>,

    /// Checked by `validate_stop`.
    #[validate(custom(function = "validate_stop"))]
    pub stop: Option<StringOrArray>,

    /// Defaults to `false`; must be `true` for `stream_options` to be accepted.
    #[serde(default)]
    pub stream: bool,

    /// Only allowed when `stream` is enabled.
    pub stream_options: Option<StreamOptions>,

    /// Must lie in `[0.0, 2.0]` when present.
    #[validate(range(min = 0.0, max = 2.0))]
    pub temperature: Option<f32>,

    /// Any value other than "none" requires a non-empty `tools` list; named
    /// functions and allowed tools must exist in `tools`.
    pub tool_choice: Option<ToolChoice>,

    /// Tools the request declares; `tool_choice` is validated against them.
    pub tools: Option<Vec<Tool>>,

    /// Must lie in `[0, 20]`; only allowed when `logprobs` is `true`.
    #[validate(range(min = 0, max = 20))]
    pub top_logprobs: Option<u32>,

    /// Checked by `validate_top_p_value`.
    #[validate(custom(function = "validate_top_p_value"))]
    pub top_p: Option<f32>,

    // Not validated or read anywhere in this part of the file.
    pub verbosity: Option<i32>,

    /// Checked by `validate_top_k_value`.
    #[validate(custom(function = "validate_top_k_value"))]
    pub top_k: Option<i32>,

    /// Must lie in `[0.0, 1.0]` when present.
    #[validate(range(min = 0.0, max = 1.0))]
    pub min_p: Option<f32>,

    /// Must be at least 1 and must not exceed `max_completion_tokens`.
    #[validate(range(min = 1))]
    pub min_tokens: Option<u32>,

    /// Must lie in `[0.0, 2.0]` when present.
    #[validate(range(min = 0.0, max = 2.0))]
    pub repetition_penalty: Option<f32>,

    /// Regex output constraint; conflicts with JSON response formats and with
    /// `ebnf`.
    pub regex: Option<String>,

    /// EBNF output constraint; conflicts with JSON response formats and with
    /// `regex`.
    pub ebnf: Option<String>,

    // Not validated or read anywhere in this part of the file.
    pub stop_token_ids: Option<Vec<u32>>,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub no_stop_trim: bool,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub ignore_eos: bool,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub continue_final_message: bool,

    /// Deserialization default comes from `default_true()`.
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,

    // Not validated or read anywhere in this part of the file.
    pub lora_path: Option<String>,

    // Not validated or read anywhere in this part of the file.
    pub session_params: Option<HashMap<String, Value>>,

    /// Deserialization default comes from `default_true()`.
    #[serde(default = "default_true")]
    pub separate_reasoning: bool,

    /// Deserialization default comes from `default_true()`.
    #[serde(default = "default_true")]
    pub stream_reasoning: bool,

    // Not validated or read anywhere in this part of the file.
    pub chat_template_kwargs: Option<HashMap<String, Value>>,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub return_hidden_states: bool,

    // Not validated or read anywhere in this part of the file.
    pub sampling_seed: Option<u64>,
}
321
322fn validate_messages(messages: &[ChatMessage]) -> Result<(), validator::ValidationError> {
328 if messages.is_empty() {
329 return Err(validator::ValidationError::new("messages cannot be empty"));
330 }
331
332 for msg in messages.iter() {
333 if let ChatMessage::User { content, .. } = msg {
334 match content {
335 MessageContent::Text(text) if text.is_empty() => {
336 return Err(validator::ValidationError::new(
337 "message content cannot be empty",
338 ));
339 }
340 MessageContent::Parts(parts) if parts.is_empty() => {
341 return Err(validator::ValidationError::new(
342 "message content parts cannot be empty",
343 ));
344 }
345 _ => {}
346 }
347 }
348 }
349 Ok(())
350}
351
352fn validate_chat_cross_parameters(
354 req: &ChatCompletionRequest,
355) -> Result<(), validator::ValidationError> {
356 if req.top_logprobs.is_some() && !req.logprobs {
358 let mut e = validator::ValidationError::new("top_logprobs_requires_logprobs");
359 e.message = Some("top_logprobs is only allowed when logprobs is enabled".into());
360 return Err(e);
361 }
362
363 if req.stream_options.is_some() && !req.stream {
365 let mut e = validator::ValidationError::new("stream_options_requires_stream");
366 e.message =
367 Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into());
368 return Err(e);
369 }
370
371 if let (Some(min), Some(max)) = (req.min_tokens, req.max_completion_tokens) {
373 if min > max {
374 let mut e = validator::ValidationError::new("min_tokens_exceeds_max");
375 e.message = Some("min_tokens cannot exceed max_tokens/max_completion_tokens".into());
376 return Err(e);
377 }
378 }
379
380 let has_json_format = matches!(
382 req.response_format,
383 Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
384 );
385
386 if has_json_format && req.regex.is_some() {
387 let mut e = validator::ValidationError::new("regex_conflicts_with_json");
388 e.message = Some("cannot use regex constraint with JSON response format".into());
389 return Err(e);
390 }
391
392 if has_json_format && req.ebnf.is_some() {
393 let mut e = validator::ValidationError::new("ebnf_conflicts_with_json");
394 e.message = Some("cannot use EBNF constraint with JSON response format".into());
395 return Err(e);
396 }
397
398 let constraint_count = [
400 req.regex.is_some(),
401 req.ebnf.is_some(),
402 matches!(req.response_format, Some(ResponseFormat::JsonSchema { .. })),
403 ]
404 .iter()
405 .filter(|&&x| x)
406 .count();
407
408 if constraint_count > 1 {
409 let mut e = validator::ValidationError::new("multiple_constraints");
410 e.message = Some("only one structured output constraint (regex, ebnf, or json_schema) can be active at a time".into());
411 return Err(e);
412 }
413
414 if let Some(ResponseFormat::JsonSchema { json_schema }) = &req.response_format {
416 if json_schema.name.is_empty() {
417 let mut e = validator::ValidationError::new("json_schema_name_empty");
418 e.message = Some("JSON schema name cannot be empty".into());
419 return Err(e);
420 }
421 }
422
423 if let Some(ref tool_choice) = req.tool_choice {
425 let has_tools = req.tools.as_ref().is_some_and(|t| !t.is_empty());
426
427 let is_some_choice = !matches!(tool_choice, ToolChoice::Value(ToolChoiceValue::None));
429
430 if is_some_choice && !has_tools {
431 let mut e = validator::ValidationError::new("tool_choice_requires_tools");
432 e.message = Some("Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.".into());
433 return Err(e);
434 }
435
436 if has_tools {
438 let tools = req.tools.as_ref().unwrap();
439
440 match tool_choice {
441 ToolChoice::Function { function, .. } => {
442 let function_exists = tools.iter().any(|tool| {
444 tool.tool_type == "function" && tool.function.name == function.name
445 });
446
447 if !function_exists {
448 let mut e =
449 validator::ValidationError::new("tool_choice_function_not_found");
450 e.message = Some(
451 format!(
452 "Invalid value for 'tool_choice': function '{}' not found in 'tools'.",
453 function.name
454 )
455 .into(),
456 );
457 return Err(e);
458 }
459 }
460 ToolChoice::AllowedTools {
461 mode,
462 tools: allowed_tools,
463 ..
464 } => {
465 if mode != "auto" && mode != "required" {
467 let mut e = validator::ValidationError::new("tool_choice_invalid_mode");
468 e.message = Some(format!(
469 "Invalid value for 'tool_choice.mode': must be 'auto' or 'required', got '{}'.",
470 mode
471 ).into());
472 return Err(e);
473 }
474
475 for tool_ref in allowed_tools {
477 match tool_ref {
478 ToolReference::Function { name } => {
479 let tool_exists = tools.iter().any(|tool| {
481 tool.tool_type == "function" && tool.function.name == *name
482 });
483
484 if !tool_exists {
485 let mut e = validator::ValidationError::new(
486 "tool_choice_tool_not_found",
487 );
488 e.message = Some(
489 format!(
490 "Invalid value for 'tool_choice.tools': tool '{}' not found in 'tools'.",
491 name
492 )
493 .into(),
494 );
495 return Err(e);
496 }
497 }
498 _ => {
499 let mut e = validator::ValidationError::new(
501 "tool_choice_invalid_tool_type",
502 );
503 e.message = Some(
504 format!(
505 "Invalid value for 'tool_choice.tools': Chat Completion API only supports function tools, got '{}'.",
506 tool_ref.identifier()
507 )
508 .into(),
509 );
510 return Err(e);
511 }
512 }
513 }
514 }
515 _ => {}
516 }
517 }
518 }
519
520 Ok(())
521}
522
523impl Normalizable for ChatCompletionRequest {
528 fn normalize(&mut self) {
533 #[allow(deprecated)]
535 if self.max_completion_tokens.is_none() && self.max_tokens.is_some() {
536 self.max_completion_tokens = self.max_tokens;
537 self.max_tokens = None; }
539
540 #[allow(deprecated)]
542 if self.tools.is_none() && self.functions.is_some() {
543 tracing::warn!("functions is deprecated, use tools instead");
544 self.tools = self.functions.as_ref().map(|functions| {
545 functions
546 .iter()
547 .map(|func| Tool {
548 tool_type: "function".to_string(),
549 function: func.clone(),
550 })
551 .collect()
552 });
553 self.functions = None; }
555
556 #[allow(deprecated)]
558 if self.tool_choice.is_none() && self.function_call.is_some() {
559 tracing::warn!("function_call is deprecated, use tool_choice instead");
560 self.tool_choice = self.function_call.as_ref().map(|fc| match fc {
561 FunctionCall::None => ToolChoice::Value(ToolChoiceValue::None),
562 FunctionCall::Auto => ToolChoice::Value(ToolChoiceValue::Auto),
563 FunctionCall::Function { name } => ToolChoice::Function {
564 tool_type: "function".to_string(),
565 function: FunctionChoice { name: name.clone() },
566 },
567 });
568 self.function_call = None; }
570
571 if self.tool_choice.is_none() {
573 if let Some(tools) = &self.tools {
574 let choice_value = if !tools.is_empty() {
575 ToolChoiceValue::Auto
576 } else {
577 ToolChoiceValue::None
578 };
579 self.tool_choice = Some(ToolChoice::Value(choice_value));
580 }
581 }
583 }
584}
585
586impl GenerationRequest for ChatCompletionRequest {
591 fn is_stream(&self) -> bool {
592 self.stream
593 }
594
595 fn get_model(&self) -> Option<&str> {
596 Some(&self.model)
597 }
598
599 fn extract_text_for_routing(&self) -> String {
600 let mut buffer = String::new();
603 let mut has_content = false;
604
605 for msg in &self.messages {
606 match msg {
607 ChatMessage::System { content, .. }
608 | ChatMessage::User { content, .. }
609 | ChatMessage::Tool { content, .. }
610 | ChatMessage::Developer { content, .. } => {
611 if has_content && content.has_text() {
612 buffer.push(' ');
613 }
614 if content.append_text_to(&mut buffer) {
615 has_content = true;
616 }
617 }
618 ChatMessage::Assistant {
619 content,
620 reasoning_content,
621 ..
622 } => {
623 if let Some(c) = content {
625 if has_content && c.has_text() {
626 buffer.push(' ');
627 }
628 if c.append_text_to(&mut buffer) {
629 has_content = true;
630 }
631 }
632 if let Some(reasoning) = reasoning_content {
634 if !reasoning.is_empty() {
635 if has_content {
636 buffer.push(' ');
637 }
638 buffer.push_str(reasoning);
639 has_content = true;
640 }
641 }
642 }
643 ChatMessage::Function { content, .. } => {
644 if !content.is_empty() {
645 if has_content {
646 buffer.push(' ');
647 }
648 buffer.push_str(content);
649 has_content = true;
650 }
651 }
652 }
653 }
654
655 buffer
656 }
657}
658
/// Non-streaming chat completion response.
///
/// `serde_with::skip_serializing_none` is applied, so the optional fields
/// (`usage`, `system_fingerprint`) are presumably omitted from the serialized
/// output when `None` — see the serde_with documentation.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    // Object-type tag; the expected value is not visible here — TODO confirm
    // against `ChatCompletionResponseBuilder`.
    pub object: String,
    // Creation timestamp; unit not visible here (presumably Unix seconds — confirm).
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    pub usage: Option<Usage>,
    pub system_fingerprint: Option<String>,
}
674
impl ChatCompletionResponse {
    /// Starts a [`ChatCompletionResponseBuilder`] for a response with the
    /// given `id` and `model`.
    ///
    /// Both arguments are forwarded unchanged to
    /// `ChatCompletionResponseBuilder::new`.
    pub fn builder(
        id: impl Into<String>,
        model: impl Into<String>,
    ) -> ChatCompletionResponseBuilder {
        ChatCompletionResponseBuilder::new(id, model)
    }
}
684
/// Message payload of a non-streaming response choice.
///
/// `content` and `tool_calls` are omitted from the serialized output when
/// `None`. `reasoning_content` carries no skip attribute, so with serde's
/// default handling it is written as `null` when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionMessage {
    // Role string; presumably always "assistant" for responses — confirm
    // against the code that builds responses.
    pub role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    // Always serialized (no skip attribute).
    pub reasoning_content: Option<String>,
}
697
/// One completion choice of a non-streaming response.
///
/// `logprobs`, `matched_stop` and `hidden_states` are omitted when `None`;
/// `finish_reason` carries no skip attribute and is therefore always
/// serialized (as `null` when `None`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatCompletionMessage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    // Always serialized; the set of allowed values is not visible here.
    pub finish_reason: Option<String>,
    // Arbitrary JSON value describing the matched stop condition (exact
    // shape not visible here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hidden_states: Option<Vec<f32>>,
}
712
/// One chunk of a streaming chat completion response.
///
/// `serde_with::skip_serializing_none` is applied, so the optional fields
/// (`system_fingerprint`, `usage`) are presumably omitted from the serialized
/// output when `None` — see the serde_with documentation.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    // Object-type tag; the expected value is not visible here — TODO confirm
    // against `ChatCompletionStreamResponseBuilder`.
    pub object: String,
    // Creation timestamp; unit not visible here (presumably Unix seconds — confirm).
    pub created: u64,
    pub model: String,
    pub system_fingerprint: Option<String>,
    pub choices: Vec<ChatStreamChoice>,
    pub usage: Option<Usage>,
}
724
impl ChatCompletionStreamResponse {
    /// Starts a [`ChatCompletionStreamResponseBuilder`] for a stream chunk
    /// with the given `id` and `model`.
    ///
    /// Both arguments are forwarded unchanged to
    /// `ChatCompletionStreamResponseBuilder::new`.
    pub fn builder(
        id: impl Into<String>,
        model: impl Into<String>,
    ) -> ChatCompletionStreamResponseBuilder {
        ChatCompletionStreamResponseBuilder::new(id, model)
    }
}
734
/// Incremental message payload carried by a streaming choice.
///
/// `role`, `content` and `tool_calls` are omitted from the serialized output
/// when `None`. `reasoning_content` carries no skip attribute, so with
/// serde's default handling it is written as `null` when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
    // Always serialized (no skip attribute).
    pub reasoning_content: Option<String>,
}
746
/// One choice inside a streaming response chunk.
///
/// Only `matched_stop` is omitted when `None`; `logprobs` and `finish_reason`
/// carry no skip attribute and are therefore always serialized (as `null`
/// when `None`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatStreamChoice {
    pub index: u32,
    pub delta: ChatMessageDelta,
    // Always serialized (no skip attribute).
    pub logprobs: Option<ChatLogProbs>,
    // Always serialized (no skip attribute).
    pub finish_reason: Option<String>,
    // Arbitrary JSON value describing the matched stop condition (exact
    // shape not visible here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>,
}