1use std::collections::HashMap;
2use std::fmt;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use secrecy::SecretString;
9use serde::{Deserialize, Serialize};
10
11use crate::chat::Message;
12use crate::tools::Tool;
13
/// How the model is allowed to use the tools attached to a chat request.
///
/// Serializes to the lowercase strings / function object expected by
/// OpenAI-style chat APIs (see the `From<ToolChoice> for serde_json::Value`
/// impl for the exact wire shape of `Function`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ToolChoice {
    /// Model decides whether to call a tool; serialized as `"auto"`.
    #[serde(rename = "auto")]
    Auto,
    /// Tool calling disabled; serialized as `"none"`.
    #[serde(rename = "none")]
    None,
    /// Model is required to use tools; serialized as `"required"`.
    #[serde(rename = "required")]
    Required,
    /// Force a call to one specific function.
    Function {
        /// Name of the function the model must call.
        name: String,
    },
}
52
53impl fmt::Display for ToolChoice {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 match self {
56 ToolChoice::Auto => write!(f, "auto"),
57 ToolChoice::None => write!(f, "none"),
58 ToolChoice::Required => write!(f, "required"),
59 ToolChoice::Function { name } => write!(f, "{}", name),
60 }
61 }
62}
63
64impl From<ToolChoice> for serde_json::Value {
65 fn from(tool_choice: ToolChoice) -> Self {
66 match tool_choice {
67 ToolChoice::Auto => serde_json::Value::String("auto".to_string()),
68 ToolChoice::None => serde_json::Value::String("none".to_string()),
69 ToolChoice::Required => serde_json::Value::String("required".to_string()),
70 ToolChoice::Function { name } => {
71 serde_json::json!({
72 "type": "function",
73 "function": {
74 "name": name
75 }
76 })
77 }
78 }
79 }
80}
81
/// Why the model stopped generating, as reported by the provider.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy)]
#[non_exhaustive]
pub enum FinishReason {
    /// Generation ended naturally (or hit a stop sequence).
    #[serde(rename = "stop")]
    Stop,
    /// Generation was cut off by the token limit.
    #[serde(rename = "length")]
    Length,
    /// The model stopped to emit tool calls.
    #[serde(rename = "tool_calls")]
    ToolCalls,
    /// Output was suppressed by the provider's content filter.
    #[serde(rename = "content_filter")]
    ContentFilter,
    /// Provider-side model failure.
    /// NOTE(review): not a standard OpenAI finish reason — presumably an
    /// internal extension; confirm which providers emit it.
    #[serde(rename = "model_error")]
    ModelError,
}
115
116impl fmt::Display for FinishReason {
117 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118 match self {
119 FinishReason::Stop => write!(f, "stop"),
120 FinishReason::Length => write!(f, "length"),
121 FinishReason::ToolCalls => write!(f, "tool_calls"),
122 FinishReason::ContentFilter => write!(f, "content_filter"),
123 FinishReason::ModelError => write!(f, "model_error"),
124 }
125 }
126}
127
128impl FromStr for FinishReason {
129 type Err = anyhow::Error;
130
131 fn from_str(s: &str) -> Result<Self, Self::Err> {
132 match s {
133 "stop" => Ok(FinishReason::Stop),
134 "length" => Ok(FinishReason::Length),
135 "tool_calls" => Ok(FinishReason::ToolCalls),
136 "content_filter" => Ok(FinishReason::ContentFilter),
137 "model_error" => Ok(FinishReason::ModelError),
138 _ => anyhow::bail!("Unknown finish reason: {}", s),
139 }
140 }
141}
142
/// Exponential-backoff retry policy for provider requests.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts after the initial try.
    pub max_retries: usize,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Ceiling applied to the growing backoff delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each attempt.
    pub backoff_multiplier: f64,
    /// Whether to randomize delays to avoid thundering-herd retries.
    pub jitter: bool,
}
176
177impl Default for RetryConfig {
178 fn default() -> Self {
179 Self {
180 max_retries: 3,
181 initial_delay: Duration::from_millis(1000),
182 max_delay: Duration::from_secs(30),
183 backoff_multiplier: 2.0,
184 jitter: true,
185 }
186 }
187}
188
/// Token accounting for a single model call.
///
/// The serde aliases let this deserialize from providers that report
/// `input_tokens`/`output_tokens` as well as those using the
/// `prompt_tokens`/`completion_tokens` names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt (also accepted as `input_tokens`).
    #[serde(alias = "input_tokens")]
    pub prompt_tokens: u32,
    /// Tokens produced by the completion (also accepted as `output_tokens`).
    #[serde(alias = "output_tokens")]
    pub completion_tokens: u32,
    /// Total tokens reported by the provider.
    pub total_tokens: u32,
    /// Provider-reported cost, if available.
    /// NOTE(review): units are not specified here — presumably USD; confirm.
    pub cost: Option<f64>,
    /// Input-token breakdown; omitted from serialized JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens_details: Option<InputTokensDetails>,
    /// Output-token breakdown; omitted from serialized JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens_details: Option<OutputTokensDetails>,
}
217
/// Breakdown of prompt-side token usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputTokensDetails {
    /// Prompt tokens served from the provider's cache.
    pub cached_tokens: u32,
}
229
/// Breakdown of completion-side token usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputTokensDetails {
    /// Tokens spent on reasoning/"thinking" output.
    pub reasoning_tokens: u32,
}
241
/// A provider-agnostic chat completion request.
///
/// Build with [`ChatRequest::new`], refine via the `with_*` builder methods,
/// and call [`ChatRequest::validate`] before dispatching.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatRequest {
    /// Conversation history; `Arc<[Message]>` keeps clones cheap.
    pub messages: Arc<[Message]>,
    /// Target model name; `None` defers to provider/config defaults.
    pub model: Option<String>,
    /// Sampling temperature; `validate` enforces 0.0..=2.0.
    pub temperature: Option<f32>,
    /// Cap on generated tokens.
    pub max_tokens: Option<u32>,
    /// Nucleus-sampling mass; `validate` enforces 0.0..=1.0.
    pub top_p: Option<f32>,
    /// Frequency penalty; `validate` enforces -2.0..=2.0.
    pub frequency_penalty: Option<f32>,
    /// Presence penalty; `validate` enforces -2.0..=2.0.
    pub presence_penalty: Option<f32>,
    /// Sequences that terminate generation when emitted.
    pub stop: Option<Vec<String>>,
    /// Tool definitions the model may call.
    pub tools: Option<Vec<Tool>>,
    /// Constraint on how the tools may be used.
    pub tool_choice: Option<ToolChoice>,
    /// Whether the response should be streamed as chunks.
    pub stream: bool,
    /// End-user identifier forwarded to the provider, if any.
    pub user: Option<String>,
    /// Opt-in flag for provider "thinking"/reasoning output.
    pub enable_thinking: Option<bool>,
    /// Free-form request extras.
    pub metadata: HashMap<String, serde_json::Value>,
}
298
299impl fmt::Display for ChatRequest {
300 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
301 match serde_json::to_string(self) {
302 Ok(json) => write!(f, "{}", json),
303 Err(_) => write!(f, "Error serializing ChatRequest to JSON"),
304 }
305 }
306}
307
308impl ChatRequest {
309 pub fn new(messages: impl Into<Arc<[Message]>>) -> Self {
328 Self {
329 messages: messages.into(),
330 model: None,
331 temperature: None,
332 max_tokens: None,
333 top_p: None,
334 frequency_penalty: None,
335 presence_penalty: None,
336 stop: None,
337 tools: None,
338 tool_choice: None,
339 stream: false,
340 user: None,
341 enable_thinking: None,
342 metadata: HashMap::new(),
343 }
344 }
345}
346
347impl From<(&Config, Vec<Message>)> for ChatRequest {
348 fn from((config, messages): (&Config, Vec<Message>)) -> Self {
349 Self {
350 messages: messages.into(),
351 model: Some(config.model.clone()),
352 temperature: config.temperature,
353 max_tokens: config.max_tokens,
354 top_p: config.top_p,
355 frequency_penalty: config.frequency_penalty,
356 presence_penalty: config.presence_penalty,
357 stop: config.stop_sequences.clone(),
358 tools: None,
359 tool_choice: None,
360 stream: false,
361 user: None,
362 enable_thinking: None,
363 metadata: HashMap::new(),
364 }
365 }
366}
367
368impl From<(&Config, Arc<[Message]>)> for ChatRequest {
369 fn from((config, messages): (&Config, Arc<[Message]>)) -> Self {
370 Self {
371 messages,
372 model: Some(config.model.clone()),
373 temperature: config.temperature,
374 max_tokens: config.max_tokens,
375 top_p: config.top_p,
376 frequency_penalty: config.frequency_penalty,
377 presence_penalty: config.presence_penalty,
378 stop: config.stop_sequences.clone(),
379 tools: None,
380 tool_choice: None,
381 stream: false,
382 user: None,
383 enable_thinking: None,
384 metadata: HashMap::new(),
385 }
386 }
387}
388
389impl ChatRequest {
390 pub fn with_model(mut self, model: impl Into<String>) -> Self {
396 self.model = Some(model.into());
397 self
398 }
399
400 pub fn with_temperature(mut self, temperature: f32) -> Self {
408 self.temperature = Some(temperature);
409 self
410 }
411
412 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
418 self.max_tokens = Some(max_tokens);
419 self
420 }
421
422 pub fn with_top_p(mut self, top_p: f32) -> Self {
430 self.top_p = Some(top_p);
431 self
432 }
433
434 pub fn with_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
442 self.frequency_penalty = Some(frequency_penalty);
443 self
444 }
445
446 pub fn with_presence_penalty(mut self, presence_penalty: f32) -> Self {
454 self.presence_penalty = Some(presence_penalty);
455 self
456 }
457
458 pub fn with_stop_sequences(
464 mut self,
465 stop_sequences: impl IntoIterator<Item = impl Into<String>>,
466 ) -> Self {
467 self.stop = Some(stop_sequences.into_iter().map(|s| s.into()).collect());
468 self
469 }
470
471 pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
477 self.tools = Some(tools);
478 self
479 }
480
481 pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
487 self.tool_choice = Some(tool_choice);
488 self
489 }
490
491 pub fn with_streaming(mut self, stream: bool) -> Self {
497 self.stream = stream;
498 self
499 }
500
501 pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
507 self.metadata = metadata;
508 self
509 }
510
511 pub fn with_thinking(mut self, enable_thinking: bool) -> Self {
517 self.enable_thinking = Some(enable_thinking);
518 self
519 }
520
521 pub fn validate_has_messages(&self) -> anyhow::Result<()> {
527 if self.messages.is_empty() {
528 anyhow::bail!("Chat request must have at least one message");
529 }
530 Ok(())
531 }
532
533 pub fn validate(&self) -> anyhow::Result<()> {
535 self.validate_has_messages()?;
536
537 if let Some(temp) = self.temperature
538 && !(0.0..=2.0).contains(&temp)
539 {
540 anyhow::bail!("Temperature must be between 0.0 and 2.0, got {}", temp);
541 }
542
543 if let Some(top_p) = self.top_p
544 && !(0.0..=1.0).contains(&top_p)
545 {
546 anyhow::bail!("top_p must be between 0.0 and 1.0, got {}", top_p);
547 }
548
549 if let Some(freq_penalty) = self.frequency_penalty
550 && !(-2.0..=2.0).contains(&freq_penalty)
551 {
552 anyhow::bail!(
553 "frequency_penalty must be between -2.0 and 2.0, got {}",
554 freq_penalty
555 );
556 }
557
558 if let Some(pres_penalty) = self.presence_penalty
559 && !(-2.0..=2.0).contains(&pres_penalty)
560 {
561 anyhow::bail!(
562 "presence_penalty must be between -2.0 and 2.0, got {}",
563 pres_penalty
564 );
565 }
566
567 Ok(())
568 }
569
570 pub fn has_tools(&self) -> bool {
576 self.tools.as_ref().is_some_and(|t| !t.is_empty())
577 }
578
579 pub fn is_streaming(&self) -> bool {
585 self.stream
586 }
587}
588
/// A complete (non-streamed) chat completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// The assistant message produced by the model.
    pub message: Message,
    /// Model that produced the response.
    pub model: String,
    /// Token usage for the call, if reported.
    pub usage: Option<Usage>,
    /// Why generation stopped, if reported.
    pub finish_reason: Option<FinishReason>,
    /// When the response was created.
    pub created_at: DateTime<Utc>,
    /// Provider-assigned response identifier, if any.
    pub response_id: Option<String>,
    /// Free-form response extras.
    pub metadata: HashMap<String, serde_json::Value>,
}
633
/// One incremental piece of a streamed chat response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChunk {
    /// Model that produced the chunk.
    pub model: String,
    /// Newly streamed content text, if any.
    pub delta_content: Option<String>,
    /// Role announced for the message being streamed, if any.
    pub delta_role: Option<crate::chat::MessageRole>,
    /// Incremental tool-call fragments, if any.
    pub delta_tool_calls: Option<Vec<crate::tools::ToolCall>>,
    /// Why generation stopped; set when the provider reports it.
    pub finish_reason: Option<FinishReason>,
    /// Token usage totals, if present on this chunk.
    /// NOTE(review): presumably only on the final chunk — provider-dependent.
    pub usage: Option<Usage>,
    /// Provider-assigned response identifier, if any.
    pub response_id: Option<String>,
    /// When the chunk was created.
    pub created_at: DateTime<Utc>,
    /// Free-form chunk extras.
    pub metadata: HashMap<String, serde_json::Value>,
}
660
661impl fmt::Display for ChatResponse {
662 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
663 match serde_json::to_string(self) {
664 Ok(json) => write!(f, "{}", json),
665 Err(_) => write!(f, "Error serializing ChatResponse to JSON"),
666 }
667 }
668}
669
/// Provider connection settings plus default sampling parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Provider identifier (e.g. "ollama", "openai").
    pub provider: String,
    /// Default model name.
    pub model: String,
    /// Override for the provider's base URL, if any.
    pub base_url: Option<String>,
    /// API key; never serialized, and defaults to `None` when deserializing.
    #[serde(skip_serializing, default)]
    pub api_key: Option<SecretString>,
    /// Provider organization/account identifier, if any.
    pub organization: Option<String>,
    /// Request timeout in seconds, if any.
    pub timeout_seconds: Option<u64>,
    /// Retry policy; skipped by serde entirely, so deserialized configs
    /// always carry `RetryConfig::default()`.
    #[serde(skip)]
    pub retry_config: RetryConfig,
    /// Default sampling temperature (validated to 0.0..=2.0).
    pub temperature: Option<f32>,
    /// Default cap on generated tokens.
    pub max_tokens: Option<u32>,
    /// Default nucleus-sampling mass (validated to 0.0..=1.0).
    pub top_p: Option<f32>,
    /// Default frequency penalty (validated to -2.0..=2.0).
    pub frequency_penalty: Option<f32>,
    /// Default presence penalty (validated to -2.0..=2.0).
    pub presence_penalty: Option<f32>,
    /// Default stop sequences.
    pub stop_sequences: Option<Vec<String>>,
    /// Free-form config extras, copied into requests built from this config.
    pub metadata: HashMap<String, serde_json::Value>,
}
727
728impl Default for Config {
729 fn default() -> Self {
730 Self {
731 provider: "ollama".to_string(),
732 model: "gpt-oss:20b".to_string(),
733 base_url: None,
734 api_key: None,
735 organization: None,
736 timeout_seconds: None,
737 retry_config: RetryConfig::default(),
738 temperature: None,
739 max_tokens: None,
740 top_p: None,
741 frequency_penalty: None,
742 presence_penalty: None,
743 stop_sequences: None,
744 metadata: HashMap::new(),
745 }
746 }
747}
748
749impl Config {
750 pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
767 Self {
768 provider: provider.into(),
769 model: model.into(),
770 ..Default::default()
771 }
772 }
773
774 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
780 self.base_url = Some(base_url.into());
781 self
782 }
783
784 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
792 self.api_key = Some(SecretString::new(api_key.into().into()));
793 self
794 }
795
796 pub fn with_organization(mut self, organization: impl Into<String>) -> Self {
802 self.organization = Some(organization.into());
803 self
804 }
805
806 pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
812 self.timeout_seconds = Some(timeout_seconds);
813 self
814 }
815
816 pub fn with_temperature(mut self, temperature: f32) -> Self {
822 self.temperature = Some(temperature);
823 self
824 }
825
826 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
832 self.max_tokens = Some(max_tokens);
833 self
834 }
835
836 pub fn with_top_p(mut self, top_p: f32) -> Self {
842 self.top_p = Some(top_p);
843 self
844 }
845
846 pub fn with_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
852 self.frequency_penalty = Some(frequency_penalty);
853 self
854 }
855
856 pub fn with_presence_penalty(mut self, presence_penalty: f32) -> Self {
862 self.presence_penalty = Some(presence_penalty);
863 self
864 }
865
866 pub fn with_stop_sequences(
872 mut self,
873 stop_sequences: impl IntoIterator<Item = impl Into<String>>,
874 ) -> Self {
875 self.stop_sequences = Some(stop_sequences.into_iter().map(|s| s.into()).collect());
876 self
877 }
878
879 pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
885 self.metadata = metadata;
886 self
887 }
888
889 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
895 self.retry_config = retry_config;
896 self
897 }
898}
899
900impl From<(Config, Vec<Message>)> for ChatRequest {
901 fn from((config, messages): (Config, Vec<Message>)) -> Self {
902 let mut request = ChatRequest::new(messages).with_model(&config.model);
903
904 if let Some(temperature) = config.temperature {
905 request = request.with_temperature(temperature);
906 }
907
908 if let Some(max_tokens) = config.max_tokens {
909 request = request.with_max_tokens(max_tokens);
910 }
911
912 if let Some(top_p) = config.top_p {
913 request.top_p = Some(top_p);
914 }
915
916 if let Some(frequency_penalty) = config.frequency_penalty {
917 request.frequency_penalty = Some(frequency_penalty);
918 }
919
920 if let Some(presence_penalty) = config.presence_penalty {
921 request.presence_penalty = Some(presence_penalty);
922 }
923
924 if let Some(stop_sequences) = config.stop_sequences {
925 request.stop = Some(stop_sequences);
926 }
927
928 request.metadata = config.metadata;
929
930 request
931 }
932}
933
934impl Config {
935 pub fn into_chat_request(self, messages: Vec<Message>) -> ChatRequest {
955 (self, messages).into()
956 }
957
958 pub fn validate(&self) -> anyhow::Result<()> {
970 if let Some(temp) = self.temperature
971 && !(0.0..=2.0).contains(&temp)
972 {
973 anyhow::bail!("Temperature must be between 0.0 and 2.0, got {}", temp);
974 }
975
976 if let Some(top_p) = self.top_p
977 && !(0.0..=1.0).contains(&top_p)
978 {
979 anyhow::bail!("top_p must be between 0.0 and 1.0, got {}", top_p);
980 }
981
982 if let Some(freq_penalty) = self.frequency_penalty
983 && !(-2.0..=2.0).contains(&freq_penalty)
984 {
985 anyhow::bail!(
986 "frequency_penalty must be between -2.0 and 2.0, got {}",
987 freq_penalty
988 );
989 }
990
991 if let Some(pres_penalty) = self.presence_penalty
992 && !(-2.0..=2.0).contains(&pres_penalty)
993 {
994 anyhow::bail!(
995 "presence_penalty must be between -2.0 and 2.0, got {}",
996 pres_penalty
997 );
998 }
999
1000 Ok(())
1001 }
1002}
1003
#[cfg(test)]
mod proptests {
    //! Property-based tests (via `proptest`) covering parameter validation
    //! and builder ergonomics for `Config` and `ChatRequest`.
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // validate() accepts a temperature iff it lies in 0.0..=2.0.
        #[test]
        fn temperature_validation(temp in -10.0f32..10.0f32) {
            let config = Config::new("openai", "gpt-4").with_temperature(temp);
            let is_valid = (0.0..=2.0).contains(&temp);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        // validate() accepts top_p iff it lies in 0.0..=1.0.
        #[test]
        fn top_p_validation(top_p in -5.0f32..5.0f32) {
            let config = Config::new("openai", "gpt-4").with_top_p(top_p);
            let is_valid = (0.0..=1.0).contains(&top_p);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        // validate() accepts frequency_penalty iff it lies in -2.0..=2.0.
        #[test]
        fn frequency_penalty_validation(penalty in -10.0f32..10.0f32) {
            let config = Config::new("openai", "gpt-4").with_frequency_penalty(penalty);
            let is_valid = (-2.0..=2.0).contains(&penalty);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        // validate() accepts presence_penalty iff it lies in -2.0..=2.0.
        #[test]
        fn presence_penalty_validation(penalty in -10.0f32..10.0f32) {
            let config = Config::new("openai", "gpt-4").with_presence_penalty(penalty);
            let is_valid = (-2.0..=2.0).contains(&penalty);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        // max_tokens is never range-checked, so any value validates.
        #[test]
        fn max_tokens_validation(tokens in 0u32..1000000u32) {
            let config = Config::new("openai", "gpt-4").with_max_tokens(tokens);
            assert!(config.validate().is_ok());
        }

        // Builders accept &str arguments and store the full owned values.
        #[test]
        fn config_builder_with_string_slice(
            provider in ".*",
            model in ".*",
            base_url in ".*",
        ) {
            let config = Config::new(provider.as_str(), model.as_str())
                .with_base_url(base_url.as_str());

            assert_eq!(config.provider, provider);
            assert_eq!(config.model, model);
            assert_eq!(config.base_url, Some(base_url));
        }

        // Builders equally accept owned Strings.
        #[test]
        fn config_builder_with_owned_string(
            provider in ".*",
            model in ".*",
        ) {
            let config = Config::new(provider.clone(), model.clone());

            assert_eq!(config.provider, provider);
            assert_eq!(config.model, model);
        }

        // with_stop_sequences accepts Vec<String>, Vec<&str>, and small
        // collections alike, always producing the same owned Vec<String>.
        #[test]
        fn stop_sequences_accepts_various_types(
            sequences in prop::collection::vec(".*", 0..10),
        ) {
            let config1 = Config::new("openai", "gpt-4")
                .with_stop_sequences(sequences.clone());
            assert_eq!(config1.stop_sequences, Some(sequences.clone()));

            let str_refs: Vec<&str> = sequences.iter().map(|s| s.as_str()).collect();
            let config2 = Config::new("openai", "gpt-4")
                .with_stop_sequences(str_refs);
            assert_eq!(config2.stop_sequences, Some(sequences.clone()));

            if sequences.len() <= 3 {
                let arr: Vec<&str> = sequences.iter().map(|s| s.as_str()).collect();
                let config3 = Config::new("openai", "gpt-4")
                    .with_stop_sequences(arr);
                assert_eq!(config3.stop_sequences, Some(sequences));
            }
        }

        // Chained builder calls preserve every value set along the way.
        #[test]
        fn builder_chain_preserves_all_values(
            provider in ".*",
            model in ".*",
            temp in 0.0f32..2.0f32,
            max_tokens in 0u32..100000u32,
        ) {
            let config = Config::new(provider.as_str(), model.as_str())
                .with_temperature(temp)
                .with_max_tokens(max_tokens);

            assert_eq!(config.provider, provider);
            assert_eq!(config.model, model);
            assert_eq!(config.temperature, Some(temp));
            assert_eq!(config.max_tokens, Some(max_tokens));
            assert!(config.validate().is_ok());
        }

        // ChatRequest::validate applies the same temperature range rule as
        // Config::validate (given at least one message).
        #[test]
        fn chat_request_temperature_validation(
            temp in -10.0f32..10.0f32,
            msg_count in 1usize..10,
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;

            let messages: Vec<Message> = (0..msg_count)
                .map(|i| Message::new(Uuid::new_v4(), MessageRole::User, format!("message {}", i)))
                .collect();

            let request = ChatRequest::new(messages).with_temperature(temp);
            let is_valid = (0.0..=2.0).contains(&temp);
            assert_eq!(request.validate().is_ok(), is_valid);
        }

        // with_model accepts both &str and owned String.
        #[test]
        fn chat_request_with_string_types(
            model in ".*",
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;

            let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");

            let request1 = ChatRequest::new(vec![msg.clone()])
                .with_model(model.as_str());
            assert_eq!(request1.model, Some(model.clone()));

            let request2 = ChatRequest::new(vec![msg])
                .with_model(model.clone());
            assert_eq!(request2.model, Some(model));
        }

        // ChatRequest::with_stop_sequences mirrors the Config builder's
        // flexibility over item types.
        #[test]
        fn chat_request_stop_sequences_ergonomics(
            sequences in prop::collection::vec(".*", 1..5),
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;

            let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");

            let request1 = ChatRequest::new(vec![msg.clone()])
                .with_stop_sequences(sequences.clone());
            assert_eq!(request1.stop, Some(sequences.clone()));

            let str_refs: Vec<&str> = sequences.iter().map(|s| s.as_str()).collect();
            let request2 = ChatRequest::new(vec![msg])
                .with_stop_sequences(str_refs);
            assert_eq!(request2.stop, Some(sequences));
        }

        // A full builder chain on ChatRequest preserves every value and
        // validates when all parameters are in range.
        #[test]
        fn chat_request_builder_chain(
            model in ".*",
            temp in 0.0f32..2.0f32,
            max_tokens in 0u32..100000u32,
            top_p in 0.0f32..1.0f32,
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;

            let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");
            let request = ChatRequest::new(vec![msg])
                .with_model(model.as_str())
                .with_temperature(temp)
                .with_max_tokens(max_tokens)
                .with_top_p(top_p);

            assert_eq!(request.model, Some(model));
            assert_eq!(request.temperature, Some(temp));
            assert_eq!(request.max_tokens, Some(max_tokens));
            assert_eq!(request.top_p, Some(top_p));
            assert!(request.validate().is_ok());
        }
    }

    // An empty message list fails both validate() and validate_has_messages().
    #[test]
    fn chat_request_validates_empty_messages() {
        let request = ChatRequest::new(vec![]);
        assert!(request.validate().is_err());
        assert!(request.validate_has_messages().is_err());
    }

    // has_tools() is false for absent or empty tool lists, true otherwise.
    #[test]
    fn chat_request_has_tools() {
        use crate::chat::{Message, MessageRole};
        use crate::tools::{Function, Tool};
        use uuid::Uuid;

        let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");

        let request_no_tools = ChatRequest::new(vec![msg.clone()]);
        assert!(!request_no_tools.has_tools());

        let request_empty_tools = ChatRequest::new(vec![msg.clone()]).with_tools(vec![]);
        assert!(!request_empty_tools.has_tools());

        let function = Function {
            name: "test_function".to_string(),
            description: "A test function".to_string(),
            parameters: serde_json::json!({}),
        };
        let tool = Tool::builder().function(function).build();
        let request_with_tools = ChatRequest::new(vec![msg]).with_tools(vec![tool]);
        assert!(request_with_tools.has_tools());
    }
}