1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator::Validate;
6
7use super::{
8 common::{
9 default_true, deserialize_null_as_false, validate_stop, ChatLogProbs, ContentPart,
10 Function, FunctionCall, FunctionChoice, GenerationRequest, ResponseFormat, StreamOptions,
11 StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, ToolReference,
12 Usage,
13 },
14 sampling_params::{validate_top_k_value, validate_top_p_value},
15};
16use crate::{
17 builders::{ChatCompletionResponseBuilder, ChatCompletionStreamResponseBuilder},
18 validated::Normalizable,
19};
20
/// A single message in a chat-completion conversation.
///
/// Serialized as an internally tagged object: the variant is chosen by the `"role"` field
/// (`"system"`, `"user"`, `"assistant"`, `"tool"`, `"function"` or `"developer"`).
/// `None`-valued optional fields are omitted when serializing.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(tag = "role")]
pub enum ChatMessage {
    /// Message with role `"system"`.
    #[serde(rename = "system")]
    System {
        content: MessageContent,
        /// Optional participant name.
        name: Option<String>,
    },
    /// Message with role `"user"`.
    #[serde(rename = "user")]
    User {
        content: MessageContent,
        /// Optional participant name.
        name: Option<String>,
    },
    /// Message with role `"assistant"`; unlike the other roles, `content` is optional here.
    #[serde(rename = "assistant")]
    Assistant {
        content: Option<MessageContent>,
        /// Optional participant name.
        name: Option<String>,
        /// Tool calls issued by this assistant turn, if any.
        tool_calls: Option<Vec<ToolCall>>,
        /// Optional reasoning text attached to this assistant turn.
        reasoning_content: Option<String>,
    },
    /// Message with role `"tool"`.
    #[serde(rename = "tool")]
    Tool {
        content: MessageContent,
        // NOTE(review): presumably the id of the assistant `ToolCall` this message answers.
        tool_call_id: String,
    },
    /// Message with role `"function"`; `content` is a plain string, not [`MessageContent`].
    // NOTE(review): presumably the legacy counterpart of the deprecated `functions` /
    // `function_call` request fields.
    #[serde(rename = "function")]
    Function { content: String, name: String },
    /// Message with role `"developer"`; may carry its own tool list.
    #[serde(rename = "developer")]
    Developer {
        content: MessageContent,
        tools: Option<Vec<Tool>>,
        /// Optional participant name.
        name: Option<String>,
    },
}
61
/// Message content: either a plain string or a list of content parts.
///
/// Because of `#[serde(untagged)]`, a JSON string deserializes to `Text` and a JSON array to
/// `Parts`. Each variant serializes back to that same JSON shape.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain string content.
    Text(String),
    /// Ordered list of content parts; only `ContentPart::Text` parts carry text.
    Parts(Vec<ContentPart>),
}
68
69impl MessageContent {
70 pub fn to_simple_string(&self) -> String {
75 match self {
76 MessageContent::Text(text) => text.clone(),
77 MessageContent::Parts(parts) => {
78 let mut result = String::new();
80 let mut first = true;
81 for part in parts {
82 if let ContentPart::Text { text } = part {
83 if !first {
84 result.push(' ');
85 }
86 result.push_str(text);
87 first = false;
88 }
89 }
90 result
91 }
92 }
93 }
94
95 #[inline]
98 pub fn append_text_to(&self, buffer: &mut String) -> bool {
99 match self {
100 MessageContent::Text(text) => {
101 if text.is_empty() {
102 false
103 } else {
104 buffer.push_str(text);
105 true
106 }
107 }
108 MessageContent::Parts(parts) => {
109 let mut appended = false;
110 for part in parts {
111 if let ContentPart::Text { text } = part {
112 if !text.is_empty() {
113 if appended {
114 buffer.push(' ');
115 }
116 buffer.push_str(text);
117 appended = true;
118 }
119 }
120 }
121 appended
122 }
123 }
124 }
125
126 #[inline]
128 pub fn has_text(&self) -> bool {
129 match self {
130 MessageContent::Text(text) => !text.is_empty(),
131 MessageContent::Parts(parts) => parts
132 .iter()
133 .any(|part| matches!(part, ContentPart::Text { text } if !text.is_empty())),
134 }
135 }
136}
137
/// Request body for the chat-completions endpoint.
///
/// Per-field constraints are declared with `#[validate(...)]`. Constraints spanning several
/// fields are checked by `validate_chat_cross_parameters`. `None`-valued optional fields are
/// omitted when serializing.
// NOTE(review): `#[derive(Default)]` sets `skip_special_tokens`, `separate_reasoning` and
// `stream_reasoning` to `false`, whereas deserialization defaults them via `default_true`.
// Code that builds requests with `Default::default()` gets different values than a parsed
// request — confirm this is intended.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate, schemars::JsonSchema)]
#[validate(schema(function = "validate_chat_cross_parameters"))]
pub struct ChatCompletionRequest {
    /// Conversation so far; must be non-empty (see `validate_messages`).
    #[validate(custom(function = "validate_messages"))]
    pub messages: Vec<ChatMessage>,

    /// Model identifier requested by the client.
    pub model: String,

    /// Frequency penalty, validated to lie in `[-2.0, 2.0]`.
    #[validate(range(min = -2.0, max = 2.0))]
    pub frequency_penalty: Option<f32>,

    /// Deprecated; migrated into `tool_choice` during normalization.
    #[deprecated(note = "Use tool_choice instead")]
    pub function_call: Option<FunctionCall>,

    /// Deprecated; migrated into `tools` during normalization.
    #[deprecated(note = "Use tools instead")]
    pub functions: Option<Vec<Function>>,

    /// Optional map of string keys to `f32` bias values.
    pub logit_bias: Option<HashMap<String, f32>>,

    /// Whether log-probabilities are requested; required for `top_logprobs`.
    /// Missing values default to `false` and are parsed with `deserialize_null_as_false`.
    #[serde(default, deserialize_with = "deserialize_null_as_false")]
    pub logprobs: bool,

    /// Deprecated token limit; moved into `max_completion_tokens` during normalization when that
    /// field is unset. Must be at least 1.
    #[deprecated(note = "Use max_completion_tokens instead")]
    #[validate(range(min = 1))]
    pub max_tokens: Option<u32>,

    /// Maximum number of completion tokens; must be at least 1.
    #[validate(range(min = 1))]
    pub max_completion_tokens: Option<u32>,

    /// Arbitrary string key/value metadata.
    pub metadata: Option<HashMap<String, String>>,

    /// Requested output modalities.
    pub modalities: Option<Vec<String>>,

    /// Number of choices to generate, validated to lie in `[1, 10]`.
    #[validate(range(min = 1, max = 10))]
    pub n: Option<u32>,

    /// Whether the model may issue several tool calls in one turn.
    pub parallel_tool_calls: Option<bool>,

    /// Presence penalty, validated to lie in `[-2.0, 2.0]`.
    #[validate(range(min = -2.0, max = 2.0))]
    pub presence_penalty: Option<f32>,

    /// Optional prompt-cache key.
    pub prompt_cache_key: Option<String>,

    /// Optional reasoning-effort hint (free-form string, not validated here).
    pub reasoning_effort: Option<String>,

    /// Requested response format. JSON formats conflict with `regex`/`ebnf`, and a JSON-schema
    /// format must have a non-empty name.
    pub response_format: Option<ResponseFormat>,

    /// Optional end-user safety identifier.
    pub safety_identifier: Option<String>,

    /// Legacy sampling seed.
    #[deprecated(note = "This feature is in Legacy mode")]
    pub seed: Option<i64>,

    /// Optional service tier (free-form string, not validated here).
    pub service_tier: Option<String>,

    /// Stop sequence(s), checked by `validate_stop`.
    #[validate(custom(function = "validate_stop"))]
    pub stop: Option<StringOrArray>,

    /// Whether to stream the response. Missing values default to `false` and are parsed with
    /// `deserialize_null_as_false`.
    #[serde(default, deserialize_with = "deserialize_null_as_false")]
    pub stream: bool,

    /// Streaming options; only allowed when `stream` is `true`.
    pub stream_options: Option<StreamOptions>,

    /// Sampling temperature, validated to lie in `[0.0, 2.0]`.
    #[validate(range(min = 0.0, max = 2.0))]
    pub temperature: Option<f32>,

    /// Tool selection. Anything other than `"none"` requires a non-empty `tools`. Defaulted
    /// during normalization when `tools` is present.
    pub tool_choice: Option<ToolChoice>,

    /// Tools the model may call.
    pub tools: Option<Vec<Tool>>,

    /// Number of top log-probabilities per token, in `[0, 20]`; requires `logprobs`.
    #[validate(range(min = 0, max = 20))]
    pub top_logprobs: Option<u32>,

    /// Nucleus-sampling probability, checked by `validate_top_p_value`.
    #[validate(custom(function = "validate_top_p_value"))]
    pub top_p: Option<f32>,

    /// Verbosity level (integer, not validated here).
    pub verbosity: Option<i32>,

    // ---- Extension parameters (not part of the OpenAI Chat Completions schema) ----

    /// Top-k sampling value, checked by `validate_top_k_value`.
    #[validate(custom(function = "validate_top_k_value"))]
    pub top_k: Option<i32>,

    /// Min-p sampling threshold, validated to lie in `[0.0, 1.0]`.
    #[validate(range(min = 0.0, max = 1.0))]
    pub min_p: Option<f32>,

    /// Minimum number of generated tokens; at least 1 and not above the effective max-token
    /// limit.
    #[validate(range(min = 1))]
    pub min_tokens: Option<u32>,

    /// Repetition penalty, validated to lie in `[0.0, 2.0]`.
    #[validate(range(min = 0.0, max = 2.0))]
    pub repetition_penalty: Option<f32>,

    /// Regex output constraint. Mutually exclusive with `ebnf` and with JSON response formats.
    pub regex: Option<String>,

    /// EBNF grammar output constraint. Mutually exclusive with `regex` and with JSON response
    /// formats.
    pub ebnf: Option<String>,

    /// Additional stop token ids.
    pub stop_token_ids: Option<Vec<u32>>,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub no_stop_trim: bool,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub ignore_eos: bool,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub continue_final_message: bool,

    /// Deserialization default comes from `default_true`.
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,

    /// Optional LoRA adapter path.
    pub lora_path: Option<String>,

    /// Arbitrary session parameters (JSON values).
    pub session_params: Option<HashMap<String, Value>>,

    /// Deserialization default comes from `default_true`.
    #[serde(default = "default_true")]
    pub separate_reasoning: bool,

    /// Deserialization default comes from `default_true`.
    #[serde(default = "default_true")]
    pub stream_reasoning: bool,

    /// Extra keyword arguments for the chat template (JSON values).
    pub chat_template_kwargs: Option<HashMap<String, Value>>,

    /// Defaults to `false` when absent.
    #[serde(default)]
    pub return_hidden_states: bool,

    /// Optional unsigned sampling seed.
    pub sampling_seed: Option<u64>,
}
320
321fn validate_messages(messages: &[ChatMessage]) -> Result<(), validator::ValidationError> {
327 if messages.is_empty() {
328 return Err(validator::ValidationError::new("messages cannot be empty"));
329 }
330
331 for msg in messages {
332 if let ChatMessage::User { content, .. } = msg {
333 match content {
334 MessageContent::Text(text) if text.is_empty() => {
335 return Err(validator::ValidationError::new(
336 "message content cannot be empty",
337 ));
338 }
339 MessageContent::Parts(parts) if parts.is_empty() => {
340 return Err(validator::ValidationError::new(
341 "message content parts cannot be empty",
342 ));
343 }
344 _ => {}
345 }
346 }
347 }
348 Ok(())
349}
350
351fn validate_chat_cross_parameters(
353 req: &ChatCompletionRequest,
354) -> Result<(), validator::ValidationError> {
355 if req.top_logprobs.is_some() && !req.logprobs {
357 let mut e = validator::ValidationError::new("top_logprobs_requires_logprobs");
358 e.message = Some("top_logprobs is only allowed when logprobs is enabled".into());
359 return Err(e);
360 }
361
362 if req.stream_options.is_some() && !req.stream {
364 let mut e = validator::ValidationError::new("stream_options_requires_stream");
365 e.message =
366 Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into());
367 return Err(e);
368 }
369
370 if let (Some(min), Some(max)) = (req.min_tokens, req.max_completion_tokens) {
372 if min > max {
373 let mut e = validator::ValidationError::new("min_tokens_exceeds_max");
374 e.message = Some("min_tokens cannot exceed max_tokens/max_completion_tokens".into());
375 return Err(e);
376 }
377 }
378
379 let has_json_format = matches!(
381 req.response_format,
382 Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
383 );
384
385 if has_json_format && req.regex.is_some() {
386 let mut e = validator::ValidationError::new("regex_conflicts_with_json");
387 e.message = Some("cannot use regex constraint with JSON response format".into());
388 return Err(e);
389 }
390
391 if has_json_format && req.ebnf.is_some() {
392 let mut e = validator::ValidationError::new("ebnf_conflicts_with_json");
393 e.message = Some("cannot use EBNF constraint with JSON response format".into());
394 return Err(e);
395 }
396
397 let constraint_count = [
399 req.regex.is_some(),
400 req.ebnf.is_some(),
401 matches!(req.response_format, Some(ResponseFormat::JsonSchema { .. })),
402 ]
403 .iter()
404 .filter(|&&x| x)
405 .count();
406
407 if constraint_count > 1 {
408 let mut e = validator::ValidationError::new("multiple_constraints");
409 e.message = Some("only one structured output constraint (regex, ebnf, or json_schema) can be active at a time".into());
410 return Err(e);
411 }
412
413 if let Some(ResponseFormat::JsonSchema { json_schema }) = &req.response_format {
415 if json_schema.name.is_empty() {
416 let mut e = validator::ValidationError::new("json_schema_name_empty");
417 e.message = Some("JSON schema name cannot be empty".into());
418 return Err(e);
419 }
420 }
421
422 if let Some(ref tool_choice) = req.tool_choice {
424 let has_tools = req.tools.as_ref().is_some_and(|t| !t.is_empty());
425
426 let is_some_choice = !matches!(tool_choice, ToolChoice::Value(ToolChoiceValue::None));
428
429 if is_some_choice && !has_tools {
430 let mut e = validator::ValidationError::new("tool_choice_requires_tools");
431 e.message = Some("Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.".into());
432 return Err(e);
433 }
434
435 if let Some(tools) = req.tools.as_ref().filter(|t| !t.is_empty()) {
437 match tool_choice {
438 ToolChoice::Function { function, .. } => {
439 let function_exists = tools.iter().any(|tool| {
441 tool.tool_type == "function" && tool.function.name == function.name
442 });
443
444 if !function_exists {
445 let mut e =
446 validator::ValidationError::new("tool_choice_function_not_found");
447 e.message = Some(
448 format!(
449 "Invalid value for 'tool_choice': function '{}' not found in 'tools'.",
450 function.name
451 )
452 .into(),
453 );
454 return Err(e);
455 }
456 }
457 ToolChoice::AllowedTools {
458 mode,
459 tools: allowed_tools,
460 ..
461 } => {
462 if mode != "auto" && mode != "required" {
464 let mut e = validator::ValidationError::new("tool_choice_invalid_mode");
465 e.message = Some(format!(
466 "Invalid value for 'tool_choice.mode': must be 'auto' or 'required', got '{mode}'."
467 ).into());
468 return Err(e);
469 }
470
471 for tool_ref in allowed_tools {
473 match tool_ref {
474 ToolReference::Function { name } => {
475 let tool_exists = tools.iter().any(|tool| {
477 tool.tool_type == "function" && tool.function.name == *name
478 });
479
480 if !tool_exists {
481 let mut e = validator::ValidationError::new(
482 "tool_choice_tool_not_found",
483 );
484 e.message = Some(
485 format!(
486 "Invalid value for 'tool_choice.tools': tool '{name}' not found in 'tools'."
487 )
488 .into(),
489 );
490 return Err(e);
491 }
492 }
493 _ => {
494 let mut e = validator::ValidationError::new(
496 "tool_choice_invalid_tool_type",
497 );
498 e.message = Some(
499 format!(
500 "Invalid value for 'tool_choice.tools': Chat Completion API only supports function tools, got '{}'.",
501 tool_ref.identifier()
502 )
503 .into(),
504 );
505 return Err(e);
506 }
507 }
508 }
509 }
510 ToolChoice::Value(_) => {}
511 }
512 }
513 }
514
515 Ok(())
516}
517
518impl Normalizable for ChatCompletionRequest {
523 fn normalize(&mut self) {
528 #[expect(deprecated)]
530 if self.max_completion_tokens.is_none() && self.max_tokens.is_some() {
531 self.max_completion_tokens = self.max_tokens;
532 self.max_tokens = None; }
534
535 #[expect(deprecated)]
537 if self.tools.is_none() && self.functions.is_some() {
538 tracing::warn!("functions is deprecated, use tools instead");
539 self.tools = self.functions.as_ref().map(|functions| {
540 functions
541 .iter()
542 .map(|func| Tool {
543 tool_type: "function".to_string(),
544 function: func.clone(),
545 })
546 .collect()
547 });
548 self.functions = None; }
550
551 #[expect(deprecated)]
553 if self.tool_choice.is_none() && self.function_call.is_some() {
554 tracing::warn!("function_call is deprecated, use tool_choice instead");
555 self.tool_choice = self.function_call.as_ref().map(|fc| match fc {
556 FunctionCall::None => ToolChoice::Value(ToolChoiceValue::None),
557 FunctionCall::Auto => ToolChoice::Value(ToolChoiceValue::Auto),
558 FunctionCall::Function { name } => ToolChoice::Function {
559 tool_type: "function".to_string(),
560 function: FunctionChoice { name: name.clone() },
561 },
562 });
563 self.function_call = None; }
565
566 if self.tool_choice.is_none() {
568 if let Some(tools) = &self.tools {
569 let choice_value = if tools.is_empty() {
570 ToolChoiceValue::None
571 } else {
572 ToolChoiceValue::Auto
573 };
574 self.tool_choice = Some(ToolChoice::Value(choice_value));
575 }
576 }
578 }
579}
580
581impl GenerationRequest for ChatCompletionRequest {
586 fn is_stream(&self) -> bool {
587 self.stream
588 }
589
590 fn get_model(&self) -> Option<&str> {
591 Some(&self.model)
592 }
593
594 fn extract_text_for_routing(&self) -> String {
595 let mut buffer = String::new();
598 let mut has_content = false;
599
600 for msg in &self.messages {
601 match msg {
602 ChatMessage::System { content, .. }
603 | ChatMessage::User { content, .. }
604 | ChatMessage::Tool { content, .. }
605 | ChatMessage::Developer { content, .. } => {
606 if has_content && content.has_text() {
607 buffer.push(' ');
608 }
609 if content.append_text_to(&mut buffer) {
610 has_content = true;
611 }
612 }
613 ChatMessage::Assistant {
614 content,
615 reasoning_content,
616 ..
617 } => {
618 if let Some(c) = content {
620 if has_content && c.has_text() {
621 buffer.push(' ');
622 }
623 if c.append_text_to(&mut buffer) {
624 has_content = true;
625 }
626 }
627 if let Some(reasoning) = reasoning_content {
629 if !reasoning.is_empty() {
630 if has_content {
631 buffer.push(' ');
632 }
633 buffer.push_str(reasoning);
634 has_content = true;
635 }
636 }
637 }
638 ChatMessage::Function { content, .. } => {
639 if !content.is_empty() {
640 if has_content {
641 buffer.push(' ');
642 }
643 buffer.push_str(content);
644 has_content = true;
645 }
646 }
647 }
648 }
649
650 buffer
651 }
652}
653
/// Complete (non-streamed) chat-completion response body.
///
/// `None`-valued optional fields (`usage`, `system_fingerprint`) are omitted when serializing.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatCompletionResponse {
    /// Response identifier.
    pub id: String,
    // NOTE(review): presumably always "chat.completion" — set by the builder; confirm.
    /// Object type tag.
    pub object: String,
    // NOTE(review): presumably a Unix timestamp in seconds — confirm with the builder.
    /// Creation time.
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<ChatChoice>,
    /// Token usage, when available.
    pub usage: Option<Usage>,
    /// Optional system fingerprint.
    pub system_fingerprint: Option<String>,
}
669
impl ChatCompletionResponse {
    /// Starts building a response with the given `id` and `model`.
    ///
    /// Convenience shorthand for `ChatCompletionResponseBuilder::new(id, model)`.
    pub fn builder(
        id: impl Into<String>,
        model: impl Into<String>,
    ) -> ChatCompletionResponseBuilder {
        ChatCompletionResponseBuilder::new(id, model)
    }
}
679
/// The message carried by one [`ChatChoice`] in a non-streaming response.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatCompletionMessage {
    // NOTE(review): presumably always "assistant" for responses — confirm.
    /// Message role.
    pub role: String,
    /// Generated text; omitted from the JSON output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls made by the model; omitted from the JSON output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Reasoning text, if any. Unlike `content`/`tool_calls` this has no skip attribute, so it
    /// is always serialized (as `null` when `None`).
    pub reasoning_content: Option<String>,
}
692
/// One choice in a non-streaming [`ChatCompletionResponse`].
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatChoice {
    /// Index of this choice.
    pub index: u32,
    /// The generated message.
    pub message: ChatCompletionMessage,
    /// Log-probability details; omitted from the JSON output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    /// Why generation finished; serialized as `null` when `None` (no skip attribute).
    pub finish_reason: Option<String>,
    // NOTE(review): presumably the stop string/token id that ended generation — confirm.
    /// Matched stop value as arbitrary JSON; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>,
    /// Hidden-state values, when requested; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hidden_states: Option<Vec<f32>>,
}
707
/// One chunk of a streamed chat-completion response.
///
/// `None`-valued optional fields (`system_fingerprint`, `usage`) are omitted when serializing.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatCompletionStreamResponse {
    /// Response identifier.
    pub id: String,
    // NOTE(review): presumably always "chat.completion.chunk" — set by the builder; confirm.
    /// Object type tag.
    pub object: String,
    // NOTE(review): presumably a Unix timestamp in seconds — confirm with the builder.
    /// Creation time.
    pub created: u64,
    /// Model that produced the chunk.
    pub model: String,
    /// Optional system fingerprint.
    pub system_fingerprint: Option<String>,
    /// Per-choice deltas carried by this chunk.
    pub choices: Vec<ChatStreamChoice>,
    /// Token usage, when included in this chunk.
    pub usage: Option<Usage>,
}
719
impl ChatCompletionStreamResponse {
    /// Starts building a stream chunk with the given `id` and `model`.
    ///
    /// Convenience shorthand for `ChatCompletionStreamResponseBuilder::new(id, model)`.
    pub fn builder(
        id: impl Into<String>,
        model: impl Into<String>,
    ) -> ChatCompletionStreamResponseBuilder {
        ChatCompletionStreamResponseBuilder::new(id, model)
    }
}
729
/// Incremental message content carried by a [`ChatStreamChoice`].
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatMessageDelta {
    /// Message role, if present in this delta; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// Text fragment, if any; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Partial tool-call updates, if any; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
    /// Reasoning text fragment. This has no skip attribute, so it is always serialized (as
    /// `null` when `None`).
    pub reasoning_content: Option<String>,
}
741
/// One choice within a [`ChatCompletionStreamResponse`] chunk.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ChatStreamChoice {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// Incremental message content.
    pub delta: ChatMessageDelta,
    /// Log-probability details; serialized as `null` when `None` (no skip attribute).
    pub logprobs: Option<ChatLogProbs>,
    /// Why generation finished; serialized as `null` when `None` (no skip attribute).
    pub finish_reason: Option<String>,
    // NOTE(review): presumably the stop string/token id that ended generation — confirm.
    /// Matched stop value as arbitrary JSON; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub matched_stop: Option<Value>,
}