1use async_trait::async_trait;
28use serde::{Deserialize, Serialize};
29use serde_json::Value as JsonValue;
30use std::collections::HashMap;
31
32use crate::error::Result;
33
34use futures::stream::BoxStream;
35
/// An OpenAI-style tool definition advertised to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool kind; currently always `"function"` (serialized as `"type"`).
    #[serde(rename = "type")]
    pub tool_type: String,

    /// The function exposed to the model.
    pub function: FunctionDefinition,
}
50
51impl ToolDefinition {
52 pub fn function(
54 name: impl Into<String>,
55 description: impl Into<String>,
56 parameters: JsonValue,
57 ) -> Self {
58 Self {
59 tool_type: "function".to_string(),
60 function: FunctionDefinition {
61 name: name.into(),
62 description: description.into(),
63 parameters,
64 strict: Some(true),
65 },
66 }
67 }
68}
69
/// Schema of a single callable function inside a [`ToolDefinition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Name the model uses to reference this function in tool calls.
    pub name: String,

    /// Human-readable description presented to the model.
    pub description: String,

    /// JSON Schema describing the accepted arguments.
    pub parameters: JsonValue,

    /// When `Some(true)`, requests strict schema validation from the
    /// provider; omitted from the wire format when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
86
/// A tool invocation emitted by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned id, echoed back in the matching tool result.
    pub id: String,

    /// Call kind; mirrors the tool type (serialized as `"type"`).
    #[serde(rename = "type")]
    pub call_type: String,

    /// The function being invoked, with its raw JSON arguments.
    pub function: FunctionCall,
}
100
101impl ToolCall {
102 pub fn parse_arguments<T: serde::de::DeserializeOwned>(&self) -> Result<T> {
104 serde_json::from_str(&self.function.arguments).map_err(|e| {
105 crate::error::LlmError::InvalidRequest(format!("Failed to parse tool arguments: {}", e))
106 })
107 }
108
109 pub fn name(&self) -> &str {
111 &self.function.name
112 }
113
114 pub fn arguments(&self) -> &str {
116 &self.function.arguments
117 }
118}
119
/// Function name plus raw argument payload inside a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,

    /// Arguments as a raw JSON string (parse via `ToolCall::parse_arguments`).
    pub arguments: String,
}
129
/// How the model should select tools. Untagged so the string variants
/// serialize as bare JSON strings ("auto"/"required"/"none") and the
/// function variant as an object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// Plain-string choice; carries "auto" (or "none", see `ToolChoice::none`).
    Auto(String),

    /// Plain-string choice carrying "required".
    Required(String),

    /// Force a specific named function.
    Function {
        #[serde(rename = "type")]
        choice_type: String,
        function: ToolChoiceFunction,
    },
}
147
148impl ToolChoice {
149 pub fn auto() -> Self {
151 ToolChoice::Auto("auto".to_string())
152 }
153
154 pub fn required() -> Self {
156 ToolChoice::Required("required".to_string())
157 }
158
159 pub fn function(name: impl Into<String>) -> Self {
161 ToolChoice::Function {
162 choice_type: "function".to_string(),
163 function: ToolChoiceFunction { name: name.into() },
164 }
165 }
166
167 pub fn none() -> Self {
169 ToolChoice::Auto("none".to_string())
170 }
171}
172
/// Target function name for `ToolChoice::Function`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    /// Name of the function the model must call.
    pub name: String,
}
179
/// The outcome of executing a tool call, sent back to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Id of the [`ToolCall`] this result answers.
    pub tool_call_id: String,

    /// Message role; always "tool" for results.
    pub role: String,

    /// Tool output (or an "Error: …" string, see `ToolResult::error`).
    pub content: String,
}
192
193impl ToolResult {
194 pub fn new(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
196 Self {
197 tool_call_id: tool_call_id.into(),
198 role: "tool".to_string(),
199 content: content.into(),
200 }
201 }
202
203 pub fn error(tool_call_id: impl Into<String>, error: impl std::fmt::Display) -> Self {
205 Self {
206 tool_call_id: tool_call_id.into(),
207 role: "tool".to_string(),
208 content: format!("Error: {}", error),
209 }
210 }
211}
212
/// A single increment of a streamed response.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A fragment of the assistant's visible text.
    Content(String),

    /// A fragment of the model's thinking/reasoning stream.
    ThinkingContent {
        /// Thinking text fragment.
        text: String,
        /// Thinking tokens consumed so far, if the provider reports it.
        tokens_used: Option<usize>,
        /// Total thinking-token budget, if the provider reports it.
        budget_total: Option<usize>,
    },

    /// Incremental piece of a streamed tool call; fields are `None` when the
    /// provider did not include them in this delta.
    ToolCallDelta {
        /// Index of the tool call this delta belongs to.
        index: usize,
        /// Tool-call id, when present in this delta.
        id: Option<String>,
        /// Function name, when present in this delta.
        function_name: Option<String>,
        /// Fragment of the JSON argument string, when present.
        function_arguments: Option<String>,
    },

    /// Terminal chunk signalling the end of the stream.
    Finished {
        /// Provider-reported finish reason (e.g. "stop").
        reason: String,
        /// Time to first token in milliseconds, when measured.
        #[allow(dead_code)]
        ttft_ms: Option<f64>,
    },
}
263
/// A complete (non-streamed) model response with usage accounting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
    /// Assistant text content (may be empty when only tool calls are returned).
    pub content: String,

    /// Tokens consumed by the prompt.
    pub prompt_tokens: usize,

    /// Tokens generated in the completion.
    pub completion_tokens: usize,

    /// Sum of prompt and completion tokens.
    pub total_tokens: usize,

    /// Model identifier that produced this response.
    pub model: String,

    /// Provider-reported finish reason, if any.
    pub finish_reason: Option<String>,

    /// Tool calls requested by the model; omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,

    /// Provider-specific extras keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,

    /// Prompt tokens served from the provider's cache, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_hit_tokens: Option<usize>,

    /// Tokens spent on thinking/reasoning, when reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thinking_tokens: Option<usize>,

    /// Raw thinking/reasoning text, when the provider exposes it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thinking_content: Option<String>,
}
325
326impl LLMResponse {
327 pub fn new(content: impl Into<String>, model: impl Into<String>) -> Self {
329 Self {
330 content: content.into(),
331 prompt_tokens: 0,
332 completion_tokens: 0,
333 total_tokens: 0,
334 model: model.into(),
335 finish_reason: None,
336 tool_calls: Vec::new(),
337 metadata: HashMap::new(),
338 cache_hit_tokens: None,
339 thinking_tokens: None,
340 thinking_content: None,
341 }
342 }
343
344 pub fn with_usage(mut self, prompt: usize, completion: usize) -> Self {
346 self.prompt_tokens = prompt;
347 self.completion_tokens = completion;
348 self.total_tokens = prompt + completion;
349 self
350 }
351
352 pub fn with_finish_reason(mut self, reason: impl Into<String>) -> Self {
354 self.finish_reason = Some(reason.into());
355 self
356 }
357
358 pub fn with_tool_calls(mut self, calls: Vec<ToolCall>) -> Self {
360 self.tool_calls = calls;
361 self
362 }
363
364 pub fn with_cache_hit_tokens(mut self, tokens: usize) -> Self {
376 self.cache_hit_tokens = Some(tokens);
377 self
378 }
379
380 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
388 self.metadata.insert(key.into(), value);
389 self
390 }
391
392 pub fn with_thinking_tokens(mut self, tokens: usize) -> Self {
403 self.thinking_tokens = Some(tokens);
404 self
405 }
406
407 pub fn with_thinking_content(mut self, content: impl Into<String>) -> Self {
419 self.thinking_content = Some(content.into());
420 self
421 }
422
423 pub fn has_tool_calls(&self) -> bool {
425 !self.tool_calls.is_empty()
426 }
427
428 pub fn has_thinking(&self) -> bool {
432 self.thinking_tokens.is_some() || self.thinking_content.is_some()
433 }
434}
435
/// Optional generation parameters; unset fields fall back to provider defaults.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CompletionOptions {
    /// Maximum tokens to generate.
    pub max_tokens: Option<usize>,

    /// Sampling temperature.
    pub temperature: Option<f32>,

    /// Nucleus-sampling probability mass.
    pub top_p: Option<f32>,

    /// Stop sequences that end generation.
    pub stop: Option<Vec<String>>,

    /// Frequency penalty.
    pub frequency_penalty: Option<f32>,

    /// Presence penalty.
    pub presence_penalty: Option<f32>,

    /// Response format hint, e.g. "json_object" (see `json_mode`).
    pub response_format: Option<String>,

    /// System prompt override.
    pub system_prompt: Option<String>,
}
463
464impl CompletionOptions {
465 pub fn with_temperature(temperature: f32) -> Self {
467 Self {
468 temperature: Some(temperature),
469 ..Default::default()
470 }
471 }
472
473 pub fn json_mode() -> Self {
475 Self {
476 response_format: Some("json_object".to_string()),
477 ..Default::default()
478 }
479 }
480}
481
/// A chat/completion model backend.
///
/// Required methods cover identification and the basic completion/chat
/// calls; tool calling and streaming have conservative default
/// implementations (fall back to plain chat, or report `NotSupported`).
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Provider name (e.g. for logging/diagnostics).
    fn name(&self) -> &str;

    /// Model identifier this provider targets.
    fn model(&self) -> &str;

    /// Maximum context window, in tokens.
    fn max_context_length(&self) -> usize;

    /// Complete a single prompt with default options.
    async fn complete(&self, prompt: &str) -> Result<LLMResponse>;

    /// Complete a single prompt with explicit options.
    async fn complete_with_options(
        &self,
        prompt: &str,
        options: &CompletionOptions,
    ) -> Result<LLMResponse>;

    /// Multi-turn chat completion.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse>;

    /// Chat with tool definitions.
    ///
    /// Default implementation ignores the tools and delegates to [`chat`],
    /// so providers without function calling still respond.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> Result<LLMResponse> {
        // Explicitly discard unused params to avoid unused-variable warnings.
        let _ = (tools, tool_choice);
        self.chat(messages, options).await
    }

    /// Stream a completion as text fragments.
    ///
    /// # Errors
    /// Default implementation returns `NotSupported`.
    async fn stream(&self, _prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
        Err(crate::error::LlmError::NotSupported(
            "Streaming not supported".to_string(),
        ))
    }

    /// Stream a chat-with-tools completion as [`StreamChunk`]s.
    ///
    /// # Errors
    /// Default implementation returns `NotSupported`.
    async fn chat_with_tools_stream(
        &self,
        _messages: &[ChatMessage],
        _tools: &[ToolDefinition],
        _tool_choice: Option<ToolChoice>,
        _options: Option<&CompletionOptions>,
    ) -> Result<BoxStream<'static, Result<StreamChunk>>> {
        Err(crate::error::LlmError::NotSupported(
            "Streaming tool calls not supported by this provider".to_string(),
        ))
    }

    /// Whether [`stream`] is implemented (defaults to `false`).
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Whether [`chat_with_tools_stream`] is implemented (defaults to `false`).
    fn supports_tool_streaming(&self) -> bool {
        false
    }

    /// Whether JSON-object response format is supported (defaults to `false`).
    fn supports_json_mode(&self) -> bool {
        false
    }

    /// Whether native function calling is supported (defaults to `false`).
    fn supports_function_calling(&self) -> bool {
        false
    }

    /// Owned model name, or `None` when [`model`] returns an empty string.
    fn model_name(&self) -> Option<String> {
        let m = self.model();
        if m.is_empty() {
            None
        } else {
            Some(m.to_string())
        }
    }
}
607
/// An image attached to a chat message: either inline base64 data or a
/// remote URL (marked by the sentinel MIME type "url").
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ImageData {
    /// Base64 payload, or the URL itself when `mime_type == "url"`.
    pub data: String,

    /// MIME type (e.g. "image/png"), or the sentinel "url".
    pub mime_type: String,

    /// Optional detail hint (e.g. "low"/"high"); omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
641
642impl ImageData {
643 pub fn new(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
645 Self {
646 data: data.into(),
647 mime_type: mime_type.into(),
648 detail: None,
649 }
650 }
651
652 pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
654 self.detail = Some(detail.into());
655 self
656 }
657
658 pub fn to_data_uri(&self) -> String {
662 format!("data:{};base64,{}", self.mime_type, self.data)
663 }
664
665 pub fn from_url(url: impl Into<String>) -> Self {
677 Self {
678 data: url.into(),
679 mime_type: "url".to_string(),
680 detail: None,
681 }
682 }
683
684 pub fn is_url(&self) -> bool {
686 self.mime_type == "url"
687 }
688
689 pub fn to_api_url(&self) -> String {
691 if self.is_url() {
692 self.data.clone()
693 } else {
694 self.to_data_uri()
695 }
696 }
697
698 pub fn is_supported_mime(&self) -> bool {
700 matches!(
701 self.mime_type.as_str(),
702 "image/png" | "image/jpeg" | "image/gif" | "image/webp" | "url"
703 )
704 }
705}
706
/// Prompt-cache marker attached to a message (Anthropic-style).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct CacheControl {
    /// Cache strategy; currently "ephemeral" (serialized as `"type"`).
    #[serde(rename = "type")]
    pub cache_type: String,
}
726
727impl CacheControl {
728 pub fn ephemeral() -> Self {
733 Self {
734 cache_type: "ephemeral".to_string(),
735 }
736 }
737}
738
/// One turn in a chat conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Speaker role.
    pub role: ChatRole,

    /// Message text.
    pub content: String,

    /// Optional participant name; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Tool calls issued by an assistant turn; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,

    /// Id of the tool call a `Tool`-role turn answers; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    /// Prompt-cache marker for this message; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<CacheControl>,

    /// Images attached to the message; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub images: Option<Vec<ImageData>>,
}
790
791impl ChatMessage {
792 pub fn system(content: impl Into<String>) -> Self {
794 Self {
795 role: ChatRole::System,
796 content: content.into(),
797 name: None,
798 tool_calls: None,
799 tool_call_id: None,
800 cache_control: None,
801 images: None,
802 }
803 }
804
805 pub fn user(content: impl Into<String>) -> Self {
807 Self {
808 role: ChatRole::User,
809 content: content.into(),
810 name: None,
811 tool_calls: None,
812 tool_call_id: None,
813 cache_control: None,
814 images: None,
815 }
816 }
817
818 pub fn user_with_images(content: impl Into<String>, images: Vec<ImageData>) -> Self {
822 Self {
823 role: ChatRole::User,
824 content: content.into(),
825 name: None,
826 tool_calls: None,
827 tool_call_id: None,
828 cache_control: None,
829 images: if images.is_empty() {
830 None
831 } else {
832 Some(images)
833 },
834 }
835 }
836
837 pub fn assistant(content: impl Into<String>) -> Self {
839 Self {
840 role: ChatRole::Assistant,
841 content: content.into(),
842 name: None,
843 tool_calls: None,
844 tool_call_id: None,
845 cache_control: None,
846 images: None,
847 }
848 }
849
850 pub fn assistant_with_tools(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
852 Self {
853 role: ChatRole::Assistant,
854 content: content.into(),
855 name: None,
856 tool_calls: if tool_calls.is_empty() {
857 None
858 } else {
859 Some(tool_calls)
860 },
861 tool_call_id: None,
862 cache_control: None,
863 images: None,
864 }
865 }
866
867 pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
869 Self {
870 role: ChatRole::Tool,
871 content: content.into(),
872 name: None,
873 tool_calls: None,
874 tool_call_id: Some(tool_call_id.into()),
875 cache_control: None,
876 images: None,
877 }
878 }
879
880 pub fn has_images(&self) -> bool {
882 self.images.as_ref().map(|v| !v.is_empty()).unwrap_or(false)
883 }
884}
885
/// Speaker role for a chat turn; serializes as a lowercase string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// System/instruction turn.
    System,
    /// End-user turn.
    User,
    /// Model turn.
    Assistant,
    /// Tool-result turn.
    Tool,
    /// Legacy function-result turn.
    Function,
}
901
902impl ChatRole {
903 pub fn as_str(&self) -> &'static str {
905 match self {
906 ChatRole::System => "system",
907 ChatRole::User => "user",
908 ChatRole::Assistant => "assistant",
909 ChatRole::Tool => "tool",
910 ChatRole::Function => "function",
911 }
912 }
913}
914
/// A text-embedding backend.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Provider name.
    fn name(&self) -> &str;

    /// Embedding model identifier.
    fn model(&self) -> &str;

    /// Dimensionality of the produced vectors.
    fn dimension(&self) -> usize;

    /// Maximum input length in tokens.
    fn max_tokens(&self) -> usize;

    /// Embed a batch of texts, one vector per input.
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;

    /// Embed a single text.
    ///
    /// # Errors
    /// Returns `Unknown` when the batch call yields no vectors.
    async fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
        let results = self.embed(&[text.to_string()]).await?;
        results
            .into_iter()
            .next()
            .ok_or_else(|| crate::error::LlmError::Unknown("Empty embedding result".to_string()))
    }
}
942
#[cfg(test)]
mod tests {
    use super::*;

    // --- LLMResponse builder & usage accounting ---

    #[test]
    fn test_llm_response_builder() {
        let response = LLMResponse::new("Hello, world!", "gpt-4")
            .with_usage(10, 5)
            .with_finish_reason("stop");

        assert_eq!(response.content, "Hello, world!");
        assert_eq!(response.model, "gpt-4");
        assert_eq!(response.prompt_tokens, 10);
        assert_eq!(response.completion_tokens, 5);
        assert_eq!(response.total_tokens, 15);
        assert_eq!(response.finish_reason, Some("stop".to_string()));
    }

    #[test]
    fn test_llm_response_with_cache_hit_tokens() {
        let response = LLMResponse::new("cached response", "gemini-pro")
            .with_usage(1000, 50)
            .with_cache_hit_tokens(800);

        assert_eq!(response.cache_hit_tokens, Some(800));
        assert_eq!(response.prompt_tokens, 1000);
        // 800 of 1000 prompt tokens cached -> 80% hit rate.
        let cache_rate = response.cache_hit_tokens.unwrap() as f64 / response.prompt_tokens as f64;
        assert!((cache_rate - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_llm_response_no_cache_hit_tokens() {
        let response = LLMResponse::new("no cache", "gpt-4").with_usage(100, 20);

        assert_eq!(response.cache_hit_tokens, None);
    }

    // --- ChatMessage constructors ---

    #[test]
    fn test_chat_message_constructors() {
        let system = ChatMessage::system("You are helpful");
        assert_eq!(system.role, ChatRole::System);

        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, ChatRole::User);

        let assistant = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant.role, ChatRole::Assistant);
    }

    // --- CacheControl serialization ---

    #[test]
    fn test_cache_control_ephemeral() {
        let cache = CacheControl::ephemeral();
        assert_eq!(cache.cache_type, "ephemeral");
    }

    #[test]
    fn test_cache_control_serialization() {
        let cache = CacheControl::ephemeral();
        let json = serde_json::to_value(&cache).unwrap();

        // Field is renamed to "type" on the wire; the Rust name must not leak.
        assert_eq!(json["type"], "ephemeral");
        assert!(!json.as_object().unwrap().contains_key("cache_type"));
    }

    #[test]
    fn test_message_with_cache_control() {
        let mut msg = ChatMessage::system("System prompt");
        msg.cache_control = Some(CacheControl::ephemeral());

        let json = serde_json::to_value(&msg).unwrap();

        assert!(json.as_object().unwrap().contains_key("cache_control"));
        assert_eq!(json["cache_control"]["type"], "ephemeral");
    }

    #[test]
    fn test_message_without_cache_control() {
        let msg = ChatMessage::user("Hello");

        let json = serde_json::to_value(&msg).unwrap();

        // skip_serializing_if must drop the None field entirely.
        assert!(!json.as_object().unwrap().contains_key("cache_control"));
    }

    #[test]
    fn test_cache_control_roundtrip() {
        let original = CacheControl {
            cache_type: "ephemeral".to_string(),
        };

        let json_str = serde_json::to_string(&original).unwrap();

        let deserialized: CacheControl = serde_json::from_str(&json_str).unwrap();

        assert_eq!(original.cache_type, deserialized.cache_type);
    }

    // --- ImageData ---

    #[test]
    fn test_image_data_new() {
        let image = ImageData::new("iVBORw0KGgo...", "image/png");
        assert_eq!(image.mime_type, "image/png");
        assert_eq!(image.data, "iVBORw0KGgo...");
        assert_eq!(image.detail, None);
    }

    #[test]
    fn test_image_data_with_detail() {
        let image = ImageData::new("data123", "image/jpeg").with_detail("high");
        assert_eq!(image.detail, Some("high".to_string()));
    }

    #[test]
    fn test_image_data_to_data_uri() {
        let image = ImageData::new("base64data", "image/png");
        assert_eq!(image.to_data_uri(), "data:image/png;base64,base64data");
    }

    #[test]
    fn test_image_data_supported_mime() {
        assert!(ImageData::new("", "image/png").is_supported_mime());
        assert!(ImageData::new("", "image/jpeg").is_supported_mime());
        assert!(ImageData::new("", "image/gif").is_supported_mime());
        assert!(ImageData::new("", "image/webp").is_supported_mime());
        assert!(!ImageData::new("", "image/bmp").is_supported_mime());
        assert!(!ImageData::new("", "text/plain").is_supported_mime());
    }

    #[test]
    fn test_chat_message_user_with_images() {
        let images = vec![ImageData::new("data1", "image/png")];
        let msg = ChatMessage::user_with_images("What's this?", images);

        assert_eq!(msg.role, ChatRole::User);
        assert_eq!(msg.content, "What's this?");
        assert!(msg.has_images());
        assert_eq!(msg.images.as_ref().unwrap().len(), 1);
    }

    #[test]
    fn test_chat_message_user_with_empty_images() {
        let msg = ChatMessage::user_with_images("Hello", vec![]);

        // Empty vec is normalized to None, not Some(vec![]).
        assert!(!msg.has_images());
        assert!(msg.images.is_none());
    }

    #[test]
    fn test_image_data_serialization() {
        let image = ImageData::new("base64", "image/png").with_detail("low");
        let json = serde_json::to_value(&image).unwrap();

        assert_eq!(json["data"], "base64");
        assert_eq!(json["mime_type"], "image/png");
        assert_eq!(json["detail"], "low");
    }

    // --- Tool definitions, calls, and choices ---

    #[test]
    fn test_tool_definition_function_constructor() {
        let tool = ToolDefinition::function(
            "my_func",
            "Does something",
            serde_json::json!({"type": "object"}),
        );
        assert_eq!(tool.tool_type, "function");
        assert_eq!(tool.function.name, "my_func");
        assert_eq!(tool.function.description, "Does something");
        assert_eq!(tool.function.strict, Some(true));
    }

    #[test]
    fn test_tool_definition_serialization() {
        let tool = ToolDefinition::function(
            "search",
            "Search the web",
            serde_json::json!({"type": "object", "properties": {}}),
        );
        let json = serde_json::to_value(&tool).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "search");
    }

    #[test]
    fn test_tool_call_name_and_arguments() {
        let tc = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"city": "Paris"}"#.to_string(),
            },
        };
        assert_eq!(tc.name(), "get_weather");
        assert_eq!(tc.arguments(), r#"{"city": "Paris"}"#);
    }

    #[test]
    fn test_tool_call_parse_arguments() {
        let tc = ToolCall {
            id: "call_2".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "add".to_string(),
                arguments: r#"{"a": 1, "b": 2}"#.to_string(),
            },
        };
        let parsed: serde_json::Value = tc.parse_arguments().unwrap();
        assert_eq!(parsed["a"], 1);
        assert_eq!(parsed["b"], 2);
    }

    #[test]
    fn test_tool_call_parse_arguments_invalid() {
        let tc = ToolCall {
            id: "call_3".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "bad".to_string(),
                arguments: "not json".to_string(),
            },
        };
        let result: std::result::Result<serde_json::Value, _> = tc.parse_arguments();
        assert!(result.is_err());
    }

    // Untagged serialization: the string variants serialize as bare strings.

    #[test]
    fn test_tool_choice_auto() {
        let tc = ToolChoice::auto();
        let json = serde_json::to_value(&tc).unwrap();
        assert_eq!(json, "auto");
    }

    #[test]
    fn test_tool_choice_required() {
        let tc = ToolChoice::required();
        let json = serde_json::to_value(&tc).unwrap();
        assert_eq!(json, "required");
    }

    #[test]
    fn test_tool_choice_none() {
        let tc = ToolChoice::none();
        let json = serde_json::to_value(&tc).unwrap();
        assert_eq!(json, "none");
    }

    #[test]
    fn test_tool_choice_function() {
        let tc = ToolChoice::function("get_weather");
        if let ToolChoice::Function {
            choice_type,
            function,
        } = tc
        {
            assert_eq!(choice_type, "function");
            assert_eq!(function.name, "get_weather");
        } else {
            panic!("Expected ToolChoice::Function");
        }
    }

    #[test]
    fn test_tool_result_new() {
        let tr = ToolResult::new("call_1", "sunny, 20C");
        assert_eq!(tr.tool_call_id, "call_1");
        assert_eq!(tr.role, "tool");
        assert_eq!(tr.content, "sunny, 20C");
    }

    #[test]
    fn test_tool_result_error() {
        let tr = ToolResult::error("call_2", "City not found");
        assert_eq!(tr.tool_call_id, "call_2");
        assert_eq!(tr.content, "Error: City not found");
    }

    #[test]
    fn test_llm_response_with_tool_calls() {
        let tc = vec![ToolCall {
            id: "c1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: "{}".to_string(),
            },
        }];
        let resp = LLMResponse::new("", "gpt-4").with_tool_calls(tc);
        assert!(resp.has_tool_calls());
        assert_eq!(resp.tool_calls.len(), 1);
    }

    #[test]
    fn test_llm_response_no_tool_calls() {
        let resp = LLMResponse::new("hello", "gpt-4");
        assert!(!resp.has_tool_calls());
    }

    #[test]
    fn test_llm_response_with_metadata() {
        let resp =
            LLMResponse::new("hi", "gpt-4").with_metadata("id", serde_json::json!("resp_123"));
        assert_eq!(
            resp.metadata.get("id"),
            Some(&serde_json::json!("resp_123"))
        );
    }

    // --- Thinking/reasoning fields ---

    #[test]
    fn test_llm_response_with_thinking() {
        let resp = LLMResponse::new("answer", "claude-3")
            .with_thinking_tokens(500)
            .with_thinking_content("Let me think...");
        assert!(resp.has_thinking());
        assert_eq!(resp.thinking_tokens, Some(500));
        assert_eq!(resp.thinking_content, Some("Let me think...".to_string()));
    }

    #[test]
    fn test_llm_response_has_thinking_tokens_only() {
        let resp = LLMResponse::new("x", "o1").with_thinking_tokens(100);
        assert!(resp.has_thinking());
    }

    #[test]
    fn test_llm_response_has_thinking_content_only() {
        let resp = LLMResponse::new("x", "claude").with_thinking_content("hmm");
        assert!(resp.has_thinking());
    }

    #[test]
    fn test_llm_response_no_thinking() {
        let resp = LLMResponse::new("x", "gpt-4");
        assert!(!resp.has_thinking());
    }

    // --- CompletionOptions ---

    #[test]
    fn test_completion_options_default() {
        let opts = CompletionOptions::default();
        assert!(opts.max_tokens.is_none());
        assert!(opts.temperature.is_none());
        assert!(opts.response_format.is_none());
    }

    #[test]
    fn test_completion_options_with_temperature() {
        let opts = CompletionOptions::with_temperature(0.7);
        assert_eq!(opts.temperature, Some(0.7));
        assert!(opts.max_tokens.is_none());
    }

    #[test]
    fn test_completion_options_json_mode() {
        let opts = CompletionOptions::json_mode();
        assert_eq!(opts.response_format, Some("json_object".to_string()));
    }

    // --- ChatRole ---

    #[test]
    fn test_chat_role_as_str() {
        assert_eq!(ChatRole::System.as_str(), "system");
        assert_eq!(ChatRole::User.as_str(), "user");
        assert_eq!(ChatRole::Assistant.as_str(), "assistant");
        assert_eq!(ChatRole::Tool.as_str(), "tool");
        assert_eq!(ChatRole::Function.as_str(), "function");
    }

    #[test]
    fn test_chat_role_serialization() {
        let json = serde_json::to_value(ChatRole::User).unwrap();
        assert_eq!(json, "user");
        let json = serde_json::to_value(ChatRole::Tool).unwrap();
        assert_eq!(json, "tool");
    }

    #[test]
    fn test_chat_message_assistant_with_tools() {
        let tc = vec![ToolCall {
            id: "c1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: "{}".to_string(),
            },
        }];
        let msg = ChatMessage::assistant_with_tools("I'll search", tc);
        assert_eq!(msg.role, ChatRole::Assistant);
        assert!(msg.tool_calls.is_some());
        assert_eq!(msg.tool_calls.as_ref().unwrap().len(), 1);
    }

    #[test]
    fn test_chat_message_assistant_with_empty_tools() {
        let msg = ChatMessage::assistant_with_tools("just text", vec![]);
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_chat_message_tool_result() {
        let msg = ChatMessage::tool_result("call_1", "result data");
        assert_eq!(msg.role, ChatRole::Tool);
        assert_eq!(msg.tool_call_id, Some("call_1".to_string()));
        assert_eq!(msg.content, "result data");
    }

    #[test]
    fn test_chat_message_has_images_false() {
        let msg = ChatMessage::user("hello");
        assert!(!msg.has_images());
    }

    #[test]
    fn test_image_data_equality() {
        let a = ImageData::new("data", "image/png");
        let b = ImageData::new("data", "image/png");
        assert_eq!(a, b);

        let c = ImageData::new("data", "image/jpeg");
        assert_ne!(a, c);
    }

    // --- StreamChunk variants ---

    #[test]
    fn test_stream_chunk_content() {
        let chunk = StreamChunk::Content("hello".to_string());
        if let StreamChunk::Content(text) = chunk {
            assert_eq!(text, "hello");
        } else {
            panic!("Expected Content");
        }
    }

    #[test]
    fn test_stream_chunk_thinking() {
        let chunk = StreamChunk::ThinkingContent {
            text: "reasoning...".to_string(),
            tokens_used: Some(50),
            budget_total: Some(10000),
        };
        if let StreamChunk::ThinkingContent {
            text,
            tokens_used,
            budget_total,
        } = chunk
        {
            assert_eq!(text, "reasoning...");
            assert_eq!(tokens_used, Some(50));
            assert_eq!(budget_total, Some(10000));
        }
    }

    #[test]
    fn test_stream_chunk_finished() {
        let chunk = StreamChunk::Finished {
            reason: "stop".to_string(),
            ttft_ms: Some(120.5),
        };
        if let StreamChunk::Finished { reason, ttft_ms } = chunk {
            assert_eq!(reason, "stop");
            assert_eq!(ttft_ms, Some(120.5));
        }
    }

    #[test]
    fn test_stream_chunk_tool_call_delta() {
        let chunk = StreamChunk::ToolCallDelta {
            index: 0,
            id: Some("call_1".to_string()),
            function_name: Some("search".to_string()),
            function_arguments: Some(r#"{"q":"#.to_string()),
        };
        if let StreamChunk::ToolCallDelta {
            index,
            id,
            function_name,
            function_arguments,
        } = chunk
        {
            assert_eq!(index, 0);
            assert_eq!(id, Some("call_1".to_string()));
            assert_eq!(function_name, Some("search".to_string()));
            assert!(function_arguments.is_some());
        }
    }
}