1use std::collections::HashMap;
2use std::fmt;
3
4use crate::{error::LLMError, ToolCall};
5use async_trait::async_trait;
6use futures::stream::{Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use strum_macros::Display;
10
11#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Display)]
13pub enum ChatRole {
14 System,
16 User,
18 Assistant,
20 Tool,
22}
23
24#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
26#[non_exhaustive]
27pub enum ImageMime {
28 JPEG,
30 PNG,
32 GIF,
34 WEBP,
36}
37
38impl ImageMime {
39 pub fn mime_type(&self) -> &'static str {
40 match self {
41 ImageMime::JPEG => "image/jpeg",
42 ImageMime::PNG => "image/png",
43 ImageMime::GIF => "image/gif",
44 ImageMime::WEBP => "image/webp",
45 }
46 }
47}
48
49#[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize)]
51pub enum MessageType {
52 #[default]
54 Text,
55 Image((ImageMime, Vec<u8>)),
57 Pdf(Vec<u8>),
59 ImageURL(String),
61 ToolUse(Vec<ToolCall>),
63 ToolResult(Vec<ToolCall>),
65}
66
67#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
69pub enum ReasoningEffort {
70 Low,
72 Medium,
74 High,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ChatMessage {
81 pub role: ChatRole,
83 pub message_type: MessageType,
85 pub content: String,
87}
88
89#[derive(Debug, Clone, Serialize)]
91pub struct ParameterProperty {
92 #[serde(rename = "type")]
94 pub property_type: String,
95 pub description: String,
97 #[serde(skip_serializing_if = "Option::is_none")]
99 pub items: Option<Box<ParameterProperty>>,
100 #[serde(skip_serializing_if = "Option::is_none", rename = "enum")]
102 pub enum_list: Option<Vec<String>>,
103}
104
105#[derive(Debug, Clone, Serialize)]
107pub struct ParametersSchema {
108 #[serde(rename = "type")]
110 pub schema_type: String,
111 pub properties: HashMap<String, ParameterProperty>,
113 pub required: Vec<String>,
115}
116
117#[derive(Debug, Clone, Serialize)]
127pub struct FunctionTool {
128 pub name: String,
130 pub description: String,
132 pub parameters: Value,
134}
135
136#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
173
174pub struct StructuredOutputFormat {
175 pub name: String,
177 pub description: Option<String>,
179 pub schema: Option<Value>,
181 pub strict: Option<bool>,
183}
184
185#[derive(Debug, Clone, Serialize)]
187pub struct Tool {
188 #[serde(rename = "type")]
190 pub tool_type: String,
191 pub function: FunctionTool,
193}
194
/// Controls whether, and which, tool the model should call.
///
/// The hand-written `Serialize` impl emits `"required"` (`Any`), `"auto"` (`Auto`) and
/// `"none"` (`None`). It emits `{"type": "function", "function": {"name": ...}}` for `Tool`.
/// Derives `PartialEq`/`Eq` like the other public enums in this module, so choices can
/// be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ToolChoice {
    /// Require the model to use a tool (serialized as `"required"`).
    Any,

    /// Let the model decide whether to use a tool. This is the default.
    #[default]
    Auto,

    /// Force a call to the named tool.
    Tool(String),

    /// Disallow tool calls.
    None,
}
217
218impl Serialize for ToolChoice {
219 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
220 where
221 S: serde::Serializer,
222 {
223 match self {
224 ToolChoice::Any => serializer.serialize_str("required"),
225 ToolChoice::Auto => serializer.serialize_str("auto"),
226 ToolChoice::None => serializer.serialize_str("none"),
227 ToolChoice::Tool(name) => {
228 use serde::ser::SerializeMap;
229
230 let mut map = serializer.serialize_map(Some(2))?;
232 map.serialize_entry("type", "function")?;
233
234 let mut function_obj = std::collections::HashMap::new();
236 function_obj.insert("name", name.as_str());
237
238 map.serialize_entry("function", &function_obj)?;
239 map.end()
240 }
241 }
242 }
243}
244
245#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
247pub struct CompletionTokensDetails {
248 #[serde(skip_serializing_if = "Option::is_none")]
250 pub reasoning_tokens: Option<u32>,
251 #[serde(skip_serializing_if = "Option::is_none")]
253 pub audio_tokens: Option<u32>,
254}
255
256#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
258pub struct PromptTokensDetails {
259 #[serde(skip_serializing_if = "Option::is_none")]
261 pub cached_tokens: Option<u32>,
262 #[serde(skip_serializing_if = "Option::is_none")]
264 pub audio_tokens: Option<u32>,
265}
266
267#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
269pub struct Usage {
270 pub prompt_tokens: u32,
272 pub completion_tokens: u32,
274 pub total_tokens: u32,
276 #[serde(skip_serializing_if = "Option::is_none")]
278 pub completion_tokens_details: Option<CompletionTokensDetails>,
279 #[serde(skip_serializing_if = "Option::is_none")]
281 pub prompt_tokens_details: Option<PromptTokensDetails>,
282}
283
284#[derive(Debug, Clone, Serialize, Deserialize)]
286pub struct StreamResponse {
287 pub choices: Vec<StreamChoice>,
289 #[serde(skip_serializing_if = "Option::is_none")]
291 pub usage: Option<Usage>,
292}
293
294#[derive(Debug, Clone, Serialize, Deserialize)]
296pub struct StreamChoice {
297 pub delta: StreamDelta,
299}
300
301#[derive(Debug, Clone, Serialize, Deserialize)]
302pub struct StreamToolCallFunction {
303 pub arguments: String,
304 pub name: String,
305}
306
307#[derive(Debug, Clone, Serialize, Deserialize)]
308pub struct StreamToolCallDelta {
309 pub index: usize,
310 pub function: Option<StreamToolCallFunction>,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize)]
315pub struct StreamDelta {
316 #[serde(skip_serializing_if = "Option::is_none")]
318 pub content: Option<String>,
319 #[serde(skip_serializing_if = "Option::is_none")]
320 pub tool_calls: Option<Vec<StreamToolCallDelta>>,
321}
322
323pub trait ChatResponse: std::fmt::Debug + std::fmt::Display + Send + Sync {
324 fn text(&self) -> Option<String>;
325 fn tool_calls(&self) -> Option<Vec<ToolCall>>;
326 fn thinking(&self) -> Option<String> {
327 None
328 }
329
330 fn usage(&self) -> Option<Usage> {
331 None
332 }
333}
334
335#[async_trait]
337pub trait ChatProvider: Sync + Send {
338 async fn chat(
348 &self,
349 messages: &[ChatMessage],
350 _tools: Option<&[Tool]>,
351 json_schema: Option<StructuredOutputFormat>,
352 ) -> Result<Box<dyn ChatResponse>, LLMError>;
353
354 async fn chat_stream(
367 &self,
368 _messages: &[ChatMessage],
369 _tools: Option<&[Tool]>,
370 _json_schema: Option<StructuredOutputFormat>,
371 ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
372 {
373 Err(LLMError::Generic(
374 "Streaming not supported for this provider".to_string(),
375 ))
376 }
377
378 async fn chat_stream_struct(
391 &self,
392 _messages: &[ChatMessage],
393 _tools: Option<&[Tool]>,
394 _json_schema: Option<StructuredOutputFormat>,
395 ) -> Result<
396 std::pin::Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>,
397 LLMError,
398 > {
399 Err(LLMError::Generic(
400 "Structured streaming not supported for this provider".to_string(),
401 ))
402 }
403}
404
405impl fmt::Display for ReasoningEffort {
406 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407 match self {
408 ReasoningEffort::Low => write!(f, "low"),
409 ReasoningEffort::Medium => write!(f, "medium"),
410 ReasoningEffort::High => write!(f, "high"),
411 }
412 }
413}
414
415impl ChatMessage {
416 pub fn user() -> ChatMessageBuilder {
418 ChatMessageBuilder::new(ChatRole::User)
419 }
420
421 pub fn assistant() -> ChatMessageBuilder {
423 ChatMessageBuilder::new(ChatRole::Assistant)
424 }
425}
426
427#[derive(Debug)]
429pub struct ChatMessageBuilder {
430 role: ChatRole,
431 message_type: MessageType,
432 content: String,
433}
434
435impl ChatMessageBuilder {
436 pub fn new(role: ChatRole) -> Self {
438 Self {
439 role,
440 message_type: MessageType::default(),
441 content: String::new(),
442 }
443 }
444
445 pub fn content<S: Into<String>>(mut self, content: S) -> Self {
447 self.content = content.into();
448 self
449 }
450
451 pub fn image(mut self, image_mime: ImageMime, raw_bytes: Vec<u8>) -> Self {
453 self.message_type = MessageType::Image((image_mime, raw_bytes));
454 self
455 }
456
457 pub fn pdf(mut self, raw_bytes: Vec<u8>) -> Self {
459 self.message_type = MessageType::Pdf(raw_bytes);
460 self
461 }
462
463 pub fn image_url(mut self, url: impl Into<String>) -> Self {
465 self.message_type = MessageType::ImageURL(url.into());
466 self
467 }
468
469 pub fn tool_use(mut self, tools: Vec<ToolCall>) -> Self {
471 self.message_type = MessageType::ToolUse(tools);
472 self
473 }
474
475 pub fn tool_result(mut self, tools: Vec<ToolCall>) -> Self {
477 self.message_type = MessageType::ToolResult(tools);
478 self
479 }
480
481 pub fn build(self) -> ChatMessage {
483 ChatMessage {
484 role: self.role,
485 message_type: self.message_type,
486 content: self.content,
487 }
488 }
489}
490
491#[cfg(not(target_arch = "wasm32"))]
502#[allow(dead_code)]
503pub(crate) fn create_sse_stream<F>(
504 response: reqwest::Response,
505 parser: F,
506) -> std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
507where
508 F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
509{
510 let stream = response
511 .bytes_stream()
512 .map(move |chunk| match chunk {
513 Ok(bytes) => {
514 let text = String::from_utf8_lossy(&bytes);
515 parser(&text)
516 }
517 Err(e) => Err(LLMError::HttpError(e.to_string())),
518 })
519 .filter_map(|result| async move {
520 match result {
521 Ok(Some(content)) => Some(Ok(content)),
522 Ok(None) => None,
523 Err(e) => Some(Err(e)),
524 }
525 });
526
527 Box::pin(stream)
528}
529
530#[cfg(not(target_arch = "wasm32"))]
531pub mod utils {
532 use crate::error::LLMError;
533 use reqwest::Response;
534 pub async fn check_response_status(response: Response) -> Result<Response, LLMError> {
535 if !response.status().is_success() {
536 let status = response.status();
537 let error_text = response.text().await?;
538 return Err(LLMError::ResponseFormatError {
539 message: format!("API returned error status: {status}"),
540 raw_response: error_text,
541 });
542 }
543 Ok(response)
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes `value` into a `serde_json::Value`, so assertions compare whole JSON
    /// documents (independent of key order) instead of searching for substrings.
    fn to_json<T: Serialize>(value: &T) -> Value {
        serde_json::to_value(value).unwrap()
    }

    /// A `get_weather` tool call with the given JSON `arguments`, shared by the tool tests.
    fn weather_call(arguments: &str) -> crate::ToolCall {
        crate::ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: crate::FunctionCall {
                name: "get_weather".to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    #[test]
    fn test_chat_role_serialization() {
        let cases = [
            (ChatRole::User, "\"User\""),
            (ChatRole::Assistant, "\"Assistant\""),
            (ChatRole::System, "\"System\""),
            (ChatRole::Tool, "\"Tool\""),
        ];
        for (role, expected) in cases {
            assert_eq!(serde_json::to_string(&role).unwrap(), expected);
        }
    }

    #[test]
    fn test_chat_role_deserialization() {
        let cases = [
            ("\"User\"", ChatRole::User),
            ("\"Assistant\"", ChatRole::Assistant),
            ("\"System\"", ChatRole::System),
            ("\"Tool\"", ChatRole::Tool),
        ];
        for (raw, expected) in cases {
            let role: ChatRole = serde_json::from_str(raw).unwrap();
            assert_eq!(role, expected);
        }
    }

    #[test]
    fn test_image_mime_type() {
        assert_eq!(ImageMime::JPEG.mime_type(), "image/jpeg");
        assert_eq!(ImageMime::PNG.mime_type(), "image/png");
        assert_eq!(ImageMime::GIF.mime_type(), "image/gif");
        assert_eq!(ImageMime::WEBP.mime_type(), "image/webp");
    }

    #[test]
    fn test_message_type_default() {
        assert_eq!(MessageType::default(), MessageType::Text);
    }

    #[test]
    fn test_message_type_serialization() {
        assert_eq!(serde_json::to_string(&MessageType::Text).unwrap(), "\"Text\"");

        // Serde's default (externally tagged) form: {"Image": [mime, bytes]}.
        let image_type = MessageType::Image((ImageMime::JPEG, vec![1, 2, 3]));
        assert_eq!(to_json(&image_type), json!({ "Image": ["JPEG", [1, 2, 3]] }));

        let serialized = serde_json::to_string(&image_type).unwrap();
        let round_trip: MessageType = serde_json::from_str(&serialized).unwrap();
        assert_eq!(round_trip, image_type);
    }

    #[test]
    fn test_reasoning_effort_display() {
        assert_eq!(ReasoningEffort::Low.to_string(), "low");
        assert_eq!(ReasoningEffort::Medium.to_string(), "medium");
        assert_eq!(ReasoningEffort::High.to_string(), "high");
    }

    #[test]
    fn test_chat_message_builder_user() {
        let message = ChatMessage::user().content("Hello, world!").build();

        assert_eq!(message.role, ChatRole::User);
        assert_eq!(message.content, "Hello, world!");
        assert_eq!(message.message_type, MessageType::Text);
    }

    #[test]
    fn test_chat_message_builder_assistant() {
        let message = ChatMessage::assistant().content("Hi there!").build();

        assert_eq!(message.role, ChatRole::Assistant);
        assert_eq!(message.content, "Hi there!");
        assert_eq!(message.message_type, MessageType::Text);
    }

    #[test]
    fn test_chat_message_builder_image() {
        let image_data = vec![1, 2, 3, 4, 5];
        let message = ChatMessage::user()
            .content("Check this image")
            .image(ImageMime::PNG, image_data.clone())
            .build();

        assert_eq!(message.role, ChatRole::User);
        assert_eq!(message.content, "Check this image");
        assert_eq!(
            message.message_type,
            MessageType::Image((ImageMime::PNG, image_data))
        );
    }

    #[test]
    fn test_chat_message_builder_pdf() {
        // The "%PDF" magic bytes.
        let pdf_data = vec![0x25, 0x50, 0x44, 0x46];
        let message = ChatMessage::user()
            .content("Review this PDF")
            .pdf(pdf_data.clone())
            .build();

        assert_eq!(message.role, ChatRole::User);
        assert_eq!(message.content, "Review this PDF");
        assert_eq!(message.message_type, MessageType::Pdf(pdf_data));
    }

    #[test]
    fn test_chat_message_builder_image_url() {
        let image_url = "https://example.com/image.jpg";
        let message = ChatMessage::user()
            .content("See this image")
            .image_url(image_url)
            .build();

        assert_eq!(message.role, ChatRole::User);
        assert_eq!(message.content, "See this image");
        assert_eq!(
            message.message_type,
            MessageType::ImageURL(image_url.to_string())
        );
    }

    #[test]
    fn test_chat_message_builder_tool_use() {
        let tool_calls = vec![weather_call("{\"location\": \"New York\"}")];

        let message = ChatMessage::assistant()
            .content("Using weather tool")
            .tool_use(tool_calls.clone())
            .build();

        assert_eq!(message.role, ChatRole::Assistant);
        assert_eq!(message.content, "Using weather tool");
        assert_eq!(message.message_type, MessageType::ToolUse(tool_calls));
    }

    #[test]
    fn test_chat_message_builder_tool_result() {
        let tool_results = vec![weather_call(
            "{\"temperature\": 75, \"condition\": \"sunny\"}",
        )];

        let message = ChatMessage::user()
            .content("Weather result")
            .tool_result(tool_results.clone())
            .build();

        assert_eq!(message.role, ChatRole::User);
        assert_eq!(message.content, "Weather result");
        assert_eq!(message.message_type, MessageType::ToolResult(tool_results));
    }

    #[test]
    fn test_chat_message_builder_last_attachment_wins() {
        // Each attachment setter overwrites the message type; content stays empty.
        let message = ChatMessage::user()
            .image(ImageMime::GIF, vec![9])
            .pdf(vec![0x25])
            .build();

        assert_eq!(message.message_type, MessageType::Pdf(vec![0x25]));
        assert_eq!(message.content, "");
    }

    #[test]
    fn test_structured_output_format_serialization() {
        let format = StructuredOutputFormat {
            name: "Person".to_string(),
            description: Some("A person object".to_string()),
            schema: Some(json!({
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "integer"}
                },
                "required": ["name", "age"]
            })),
            strict: Some(true),
        };

        let serialized = serde_json::to_string(&format).unwrap();
        let deserialized: StructuredOutputFormat = serde_json::from_str(&serialized).unwrap();

        // StructuredOutputFormat derives PartialEq, so the round trip is checked in full.
        assert_eq!(deserialized, format);
    }

    #[test]
    fn test_structured_output_format_equality() {
        let format1 = StructuredOutputFormat {
            name: "Test".to_string(),
            description: None,
            schema: None,
            strict: None,
        };

        let format2 = StructuredOutputFormat {
            name: "Test".to_string(),
            description: None,
            schema: None,
            strict: None,
        };
        assert_eq!(format1, format2);

        let format3 = StructuredOutputFormat {
            name: "Other".to_string(),
            ..format1.clone()
        };
        assert_ne!(format1, format3);
    }

    #[test]
    fn test_tool_choice_serialization() {
        assert_eq!(serde_json::to_string(&ToolChoice::Auto).unwrap(), "\"auto\"");
        assert_eq!(serde_json::to_string(&ToolChoice::Any).unwrap(), "\"required\"");
        assert_eq!(serde_json::to_string(&ToolChoice::None).unwrap(), "\"none\"");

        let choice = ToolChoice::Tool("my_function".to_string());
        assert_eq!(
            to_json(&choice),
            json!({ "type": "function", "function": { "name": "my_function" } })
        );
    }

    #[test]
    fn test_tool_choice_default() {
        assert!(matches!(ToolChoice::default(), ToolChoice::Auto));
    }

    #[test]
    fn test_parameter_property_serialization() {
        let property = ParameterProperty {
            property_type: "string".to_string(),
            description: "A test parameter".to_string(),
            items: None,
            enum_list: Some(vec!["option1".to_string(), "option2".to_string()]),
        };

        // `items` is None and must be omitted entirely.
        assert_eq!(
            to_json(&property),
            json!({
                "type": "string",
                "description": "A test parameter",
                "enum": ["option1", "option2"]
            })
        );
    }

    #[test]
    fn test_parameter_property_with_items() {
        let item_property = ParameterProperty {
            property_type: "string".to_string(),
            description: "Array item".to_string(),
            items: None,
            enum_list: None,
        };

        let array_property = ParameterProperty {
            property_type: "array".to_string(),
            description: "An array parameter".to_string(),
            items: Some(Box::new(item_property)),
            enum_list: None,
        };

        assert_eq!(
            to_json(&array_property),
            json!({
                "type": "array",
                "description": "An array parameter",
                "items": { "type": "string", "description": "Array item" }
            })
        );
    }

    #[test]
    fn test_function_tool_serialization() {
        let mut properties = HashMap::new();
        properties.insert(
            "name".to_string(),
            ParameterProperty {
                property_type: "string".to_string(),
                description: "Name parameter".to_string(),
                items: None,
                enum_list: None,
            },
        );

        let schema = ParametersSchema {
            schema_type: "object".to_string(),
            properties,
            required: vec!["name".to_string()],
        };

        let function = FunctionTool {
            name: "test_function".to_string(),
            description: "A test function".to_string(),
            parameters: serde_json::to_value(schema).unwrap(),
        };

        assert_eq!(
            to_json(&function),
            json!({
                "name": "test_function",
                "description": "A test function",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string", "description": "Name parameter" }
                    },
                    "required": ["name"]
                }
            })
        );
    }

    #[test]
    fn test_tool_serialization() {
        let function = FunctionTool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            parameters: json!({"type": "object", "properties": {}}),
        };

        let tool = Tool {
            tool_type: "function".to_string(),
            function,
        };

        assert_eq!(
            to_json(&tool),
            json!({
                "type": "function",
                "function": {
                    "name": "test_tool",
                    "description": "A test tool",
                    "parameters": { "type": "object", "properties": {} }
                }
            })
        );
    }

    #[test]
    fn test_chat_message_serialization() {
        let message = ChatMessage {
            role: ChatRole::User,
            message_type: MessageType::Text,
            content: "Hello, world!".to_string(),
        };

        let serialized = serde_json::to_string(&message).unwrap();
        let deserialized: ChatMessage = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.role, ChatRole::User);
        assert_eq!(deserialized.message_type, MessageType::Text);
        assert_eq!(deserialized.content, "Hello, world!");
    }
}