1use std::collections::HashMap;
2use std::fmt;
3use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::stream::{Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::{ToolCall, error::LLMError};
11
/// Token accounting for a single request.
///
/// The primary field names follow the `prompt_tokens` / `completion_tokens` convention.
/// The serde aliases let deserialization also accept the `input_tokens` / `output_tokens`
/// (and `*_tokens_details`) spellings. Serialization always emits the primary names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt (`input_tokens` is accepted when deserializing).
    #[serde(alias = "input_tokens")]
    pub prompt_tokens: u32,
    /// Tokens produced by the completion (`output_tokens` is accepted when deserializing).
    #[serde(alias = "output_tokens")]
    pub completion_tokens: u32,
    /// Total token count; has no alias or default, so it must be present when deserializing.
    pub total_tokens: u32,
    /// Optional breakdown of completion tokens; left out of serialized output when `None`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "output_tokens_details"
    )]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    /// Optional breakdown of prompt tokens; left out of serialized output when `None`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "input_tokens_details"
    )]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
36
/// One structured chunk of a streamed chat completion, as yielded by
/// `ChatProvider::chat_stream_struct`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamResponse {
    /// Incremental choices carried by this chunk.
    pub choices: Vec<StreamChoice>,
    /// Token usage, when this chunk reports it; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
46
/// A single choice inside a [`StreamResponse`] chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// The incremental content added by this chunk.
    pub delta: StreamDelta,
}
53
/// Incremental content of one streamed choice. Every field is optional and is left out
/// of serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamDelta {
    /// Newly generated text, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Newly generated reasoning text, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    /// Tool-call data carried by this chunk, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
67
/// Events yielded by `ChatProvider::chat_stream_with_tools`.
///
/// NOTE(review): `index` appears in every `ToolUse*` variant and presumably correlates the
/// start / input-delta / complete events of one tool call; confirm against providers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StreamChunk {
    /// A fragment of response text.
    Text(String),
    /// A fragment of reasoning content.
    ReasoningContent(String),

    /// A tool call has started.
    ToolUseStart {
        /// Position of the tool call this event belongs to.
        index: usize,
        /// Identifier of the tool call.
        id: String,
        /// Name of the tool being called.
        name: String,
    },

    /// A fragment of a tool call's JSON input.
    ToolUseInputDelta {
        /// Position of the tool call this fragment belongs to.
        index: usize,
        /// Partial (possibly incomplete) JSON text.
        partial_json: String,
    },

    /// A tool call has been fully received.
    ToolUseComplete {
        /// Position of the completed tool call.
        index: usize,
        /// The assembled tool call.
        tool_call: ToolCall,
    },

    /// The stream has finished.
    Done {
        /// Stop reason reported by the provider.
        stop_reason: String,
    },
    /// Token usage reported for the stream.
    Usage(Usage),
}
113
/// Optional breakdown of completion-side token usage (see [`Usage`]).
/// `None` fields are left out of serialized output.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompletionTokensDetails {
    /// Tokens attributed to reasoning, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    /// Tokens attributed to audio output, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}
124
/// Optional breakdown of prompt-side token usage (see [`Usage`]).
/// `None` fields are left out of serialized output.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptTokensDetails {
    /// Prompt tokens served from a cache, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
    /// Tokens attributed to audio input, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}
135
/// Author of a [`ChatMessage`].
///
/// `Display` renders the lowercase name (`system`, `user`, `assistant`, `tool`). The derived
/// serde impls use the variant names exactly as written (e.g. `"User"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ChatRole {
    /// System instructions.
    System,
    /// The end user.
    User,
    /// The model.
    Assistant,
    /// Output of a tool invocation.
    Tool,
}
148
149impl fmt::Display for ChatRole {
150 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151 let value = match self {
152 ChatRole::System => "system",
153 ChatRole::User => "user",
154 ChatRole::Assistant => "assistant",
155 ChatRole::Tool => "tool",
156 };
157 f.write_str(value)
158 }
159}
160
/// Image formats that can be attached inline to a message (see [`ImageMime::mime_type`]).
///
/// Marked `#[non_exhaustive]`, so code outside this crate must keep a wildcard arm when
/// matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ImageMime {
    /// JPEG image (`image/jpeg`).
    JPEG,
    /// PNG image (`image/png`).
    PNG,
    /// GIF image (`image/gif`).
    GIF,
    /// WebP image (`image/webp`).
    WEBP,
}
174
175impl ImageMime {
176 pub fn mime_type(&self) -> &'static str {
177 match self {
178 ImageMime::JPEG => "image/jpeg",
179 ImageMime::PNG => "image/png",
180 ImageMime::GIF => "image/gif",
181 ImageMime::WEBP => "image/webp",
182 }
183 }
184}
185
/// Kind of payload a [`ChatMessage`] carries in addition to its text `content`.
/// Defaults to [`MessageType::Text`].
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum MessageType {
    /// Plain text only.
    #[default]
    Text,
    /// Inline image: its format plus the raw encoded bytes.
    Image((ImageMime, Vec<u8>)),
    /// Raw PDF bytes.
    Pdf(Vec<u8>),
    /// Image referenced by URL.
    ImageURL(String),
    /// Tool calls requested by the assistant.
    ToolUse(Vec<ToolCall>),
    /// Tool results, carried as `ToolCall` values.
    /// NOTE(review): the result payload appears to travel in `function.arguments`; confirm
    /// against the provider implementations.
    ToolResult(Vec<ToolCall>),
}
203
/// Requested level of reasoning effort; `Display` renders it as `low`, `medium` or `high`.
///
/// Unlike the other public enums in this module, it previously had no derives at all, so
/// callers could not copy, compare, hash or debug-print it. Those standard derives are
/// provided here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReasoningEffort {
    /// Lowest effort level.
    Low,
    /// Intermediate effort level.
    Medium,
    /// Highest effort level.
    High,
}
213
/// A single message in a conversation. Usually built through [`ChatMessage::user`],
/// [`ChatMessage::assistant`] or [`ChatMessageBuilder::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: ChatRole,
    /// Extra payload (image, PDF, tool data, ...) carried alongside the text.
    pub message_type: MessageType,
    /// Text content of the message.
    pub content: String,
}
224
/// Description of a single tool parameter, serialized with JSON-Schema-style keys
/// (`type`, `description`, `items`, `enum`).
#[derive(Debug, Clone, Serialize)]
pub struct ParameterProperty {
    /// Parameter type name; serialized as `type`.
    #[serde(rename = "type")]
    pub property_type: String,
    /// Human-readable description of the parameter.
    pub description: String,
    /// Nested schema serialized as `items` (the JSON Schema keyword for array elements);
    /// omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<ParameterProperty>>,
    /// Allowed values, serialized as `enum`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "enum")]
    pub enum_list: Option<Vec<String>>,
}
240
/// Object-level schema describing a tool's parameters.
#[derive(Debug, Clone, Serialize)]
pub struct ParametersSchema {
    /// Schema type name; serialized as `type`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Parameter name to property description.
    pub properties: HashMap<String, ParameterProperty>,
    /// Names of parameters that must be supplied.
    pub required: Vec<String>,
}
252
/// A callable function exposed to the model as a tool.
#[derive(Debug, Clone, Serialize)]
pub struct FunctionTool {
    /// Function name the model uses to invoke the tool.
    pub name: String,
    /// What the function does.
    pub description: String,
    /// Arbitrary JSON describing the parameters (presumably a JSON Schema object; confirm
    /// with callers).
    pub parameters: Value,
}
271
/// Structured-output request passed as `json_schema` to the [`ChatProvider`] methods.
///
/// Only `name` is required when deserializing; the `Option` fields default to `None`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct StructuredOutputFormat {
    /// Name of the output format.
    pub name: String,
    /// Optional description of the expected output.
    pub description: Option<String>,
    /// Optional JSON schema the output should follow.
    pub schema: Option<Value>,
    /// Optional strictness flag forwarded to the provider.
    pub strict: Option<bool>,
}
320
/// A tool definition sent to the model.
#[derive(Debug, Clone, Serialize)]
pub struct Tool {
    /// Tool kind; serialized as `type`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function this tool exposes.
    pub function: FunctionTool,
}
330
/// How the model may use the supplied tools. Serialized by a hand-written `Serialize` impl.
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// The model must call some tool; serialized as `"required"`.
    Any,

    /// The model decides whether to call a tool (the default); serialized as `"auto"`.
    #[default]
    Auto,

    /// The model must call the named function; serialized as
    /// `{"type": "function", "function": {"name": <name>}}`.
    Tool(String),

    /// The model must not call tools; serialized as `"none"`.
    None,
}
353
354impl Serialize for ToolChoice {
355 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
356 where
357 S: serde::Serializer,
358 {
359 match self {
360 ToolChoice::Any => serializer.serialize_str("required"),
361 ToolChoice::Auto => serializer.serialize_str("auto"),
362 ToolChoice::None => serializer.serialize_str("none"),
363 ToolChoice::Tool(name) => {
364 use serde::ser::SerializeMap;
365
366 let mut map = serializer.serialize_map(Some(2))?;
368 map.serialize_entry("type", "function")?;
369
370 let mut function_obj = std::collections::HashMap::new();
372 function_obj.insert("name", name.as_str());
373
374 map.serialize_entry("function", &function_obj)?;
375 map.end()
376 }
377 }
378 }
379}
380
381pub trait ChatResponse: std::fmt::Debug + std::fmt::Display + Send + Sync {
382 fn text(&self) -> Option<String>;
383 fn tool_calls(&self) -> Option<Vec<ToolCall>>;
384 fn thinking(&self) -> Option<String> {
385 None
386 }
387 fn usage(&self) -> Option<Usage> {
388 None
389 }
390}
391
392#[async_trait]
394pub trait ChatProvider: Sync + Send {
395 async fn chat(
406 &self,
407 messages: &[ChatMessage],
408 json_schema: Option<StructuredOutputFormat>,
409 ) -> Result<Box<dyn ChatResponse>, LLMError> {
410 self.chat_with_tools(messages, None, json_schema).await
411 }
412
413 async fn chat_with_tools(
425 &self,
426 messages: &[ChatMessage],
427 tools: Option<&[Tool]>,
428 json_schema: Option<StructuredOutputFormat>,
429 ) -> Result<Box<dyn ChatResponse>, LLMError>;
430
431 async fn chat_with_web_search(
441 &self,
442 _input: String,
443 ) -> Result<Box<dyn ChatResponse>, LLMError> {
444 Err(LLMError::Generic(
445 "Web search not supported for this provider".to_string(),
446 ))
447 }
448
449 async fn chat_stream(
460 &self,
461 _messages: &[ChatMessage],
462 _json_schema: Option<StructuredOutputFormat>,
463 ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
464 {
465 Err(LLMError::Generic(
466 "Streaming not supported for this provider".to_string(),
467 ))
468 }
469
470 async fn chat_stream_struct(
488 &self,
489 _messages: &[ChatMessage],
490 _tools: Option<&[Tool]>,
491 _json_schema: Option<StructuredOutputFormat>,
492 ) -> Result<
493 std::pin::Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>,
494 LLMError,
495 > {
496 Err(LLMError::Generic(
497 "Structured streaming not supported for this provider".to_string(),
498 ))
499 }
500
501 async fn chat_stream_with_tools(
546 &self,
547 _messages: &[ChatMessage],
548 _tools: Option<&[Tool]>,
549 _json_schema: Option<StructuredOutputFormat>,
550 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, LLMError> {
551 Err(LLMError::Generic(
552 "Streaming with tools not supported for this provider".to_string(),
553 ))
554 }
555}
556
557impl fmt::Display for ReasoningEffort {
558 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
559 match self {
560 ReasoningEffort::Low => write!(f, "low"),
561 ReasoningEffort::Medium => write!(f, "medium"),
562 ReasoningEffort::High => write!(f, "high"),
563 }
564 }
565}
566
567impl ChatMessage {
568 pub fn user() -> ChatMessageBuilder {
570 ChatMessageBuilder::new(ChatRole::User)
571 }
572
573 pub fn assistant() -> ChatMessageBuilder {
575 ChatMessageBuilder::new(ChatRole::Assistant)
576 }
577}
578
/// Incrementally assembles a [`ChatMessage`].
///
/// Created by [`ChatMessageBuilder::new`], [`ChatMessage::user`] or
/// [`ChatMessage::assistant`], and finished with [`ChatMessageBuilder::build`].
#[derive(Debug)]
pub struct ChatMessageBuilder {
    // Author of the message being built.
    role: ChatRole,
    // Payload kind; starts as `MessageType::Text`, and the last setter called wins.
    message_type: MessageType,
    // Text content; starts empty.
    content: String,
}
586
587impl ChatMessageBuilder {
588 pub fn new(role: ChatRole) -> Self {
590 Self {
591 role,
592 message_type: MessageType::default(),
593 content: String::default(),
594 }
595 }
596
597 pub fn content<S: Into<String>>(mut self, content: S) -> Self {
599 self.content = content.into();
600 self
601 }
602
603 pub fn image(mut self, image_mime: ImageMime, raw_bytes: Vec<u8>) -> Self {
605 self.message_type = MessageType::Image((image_mime, raw_bytes));
606 self
607 }
608
609 pub fn pdf(mut self, raw_bytes: Vec<u8>) -> Self {
611 self.message_type = MessageType::Pdf(raw_bytes);
612 self
613 }
614
615 pub fn image_url(mut self, url: impl Into<String>) -> Self {
617 self.message_type = MessageType::ImageURL(url.into());
618 self
619 }
620
621 pub fn tool_use(mut self, tools: Vec<ToolCall>) -> Self {
623 self.message_type = MessageType::ToolUse(tools);
624 self
625 }
626
627 pub fn tool_result(mut self, tools: Vec<ToolCall>) -> Self {
629 self.message_type = MessageType::ToolResult(tools);
630 self
631 }
632
633 pub fn build(self) -> ChatMessage {
635 ChatMessage {
636 role: self.role,
637 message_type: self.message_type,
638 content: self.content,
639 }
640 }
641}
642
643#[allow(dead_code)]
654pub(crate) fn create_sse_stream<F>(
655 response: reqwest::Response,
656 parser: F,
657) -> std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
658where
659 F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
660{
661 let stream = response
662 .bytes_stream()
663 .scan(
664 (String::default(), Vec::default()),
665 move |(buffer, utf8_buffer), chunk| {
666 let result = match chunk {
667 Ok(bytes) => {
668 utf8_buffer.extend_from_slice(&bytes);
669
670 match String::from_utf8(utf8_buffer.clone()) {
671 Ok(text) => {
672 buffer.push_str(&text);
673 utf8_buffer.clear();
674 }
675 Err(e) => {
676 let valid_up_to = e.utf8_error().valid_up_to();
677 if valid_up_to > 0 {
678 let valid =
681 String::from_utf8_lossy(&utf8_buffer[..valid_up_to]);
682 buffer.push_str(&valid);
683 utf8_buffer.drain(..valid_up_to);
684 }
685 }
686 }
687
688 let mut results = Vec::default();
689
690 while let Some(pos) = buffer.find("\n\n") {
691 let event = buffer[..pos + 2].to_string();
692 buffer.drain(..pos + 2);
693
694 match parser(&event) {
695 Ok(Some(content)) => results.push(Ok(content)),
696 Ok(None) => {}
697 Err(e) => results.push(Err(e)),
698 }
699 }
700
701 Some(results)
702 }
703 Err(e) => Some(vec![Err(LLMError::HttpError(e.to_string()))]),
704 };
705
706 async move { result }
707 },
708 )
709 .flat_map(futures::stream::iter);
710
711 Box::pin(stream)
712}
713
714#[cfg(not(target_arch = "wasm32"))]
715pub mod utils {
716 use crate::error::LLMError;
717 use reqwest::Response;
718 pub async fn check_response_status(response: Response) -> Result<Response, LLMError> {
719 if !response.status().is_success() {
720 let status = response.status();
721 let error_text = response.text().await?;
722 return Err(LLMError::ResponseFormatError {
723 message: format!("API returned error status: {status}"),
724 raw_response: error_text,
725 });
726 }
727 Ok(response)
728 }
729}
730
// Unit tests for the message builder, the Display/serde impls, and the SSE stream decoder.
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures::stream::StreamExt;

    // `ChatMessage::user()` with only content set yields a `Text` message.
    #[test]
    fn test_chat_message_builder_user() {
        let msg = ChatMessage::user().content("hello").build();
        assert_eq!(msg.role, ChatRole::User);
        assert_eq!(msg.content, "hello");
        assert!(matches!(msg.message_type, MessageType::Text));
    }

    #[test]
    fn test_chat_message_builder_assistant() {
        let msg = ChatMessage::assistant().content("reply").build();
        assert_eq!(msg.role, ChatRole::Assistant);
        assert_eq!(msg.content, "reply");
    }

    #[test]
    fn test_chat_message_builder_image() {
        let msg = ChatMessage::user()
            .content("describe")
            .image(ImageMime::PNG, vec![1, 2, 3])
            .build();
        assert!(matches!(msg.message_type, MessageType::Image(_)));
    }

    #[test]
    fn test_chat_message_builder_pdf() {
        let msg = ChatMessage::user()
            .content("read")
            .pdf(vec![4, 5, 6])
            .build();
        assert!(matches!(msg.message_type, MessageType::Pdf(_)));
    }

    #[test]
    fn test_chat_message_builder_tool_use() {
        let tc = crate::ToolCall {
            id: "t1".to_string(),
            call_type: "function".to_string(),
            function: crate::FunctionCall {
                name: "tool".to_string(),
                arguments: "{}".to_string(),
            },
        };
        let msg = ChatMessage::assistant()
            .content("calling tool")
            .tool_use(vec![tc])
            .build();
        assert!(matches!(msg.message_type, MessageType::ToolUse(_)));
    }

    // `ChatMessage` has no shortcut constructor for the tool role, so the builder is created
    // directly with `ChatRole::Tool`.
    #[test]
    fn test_chat_message_builder_tool_result() {
        let tc = crate::ToolCall {
            id: "t1".to_string(),
            call_type: "function".to_string(),
            function: crate::FunctionCall {
                name: "tool".to_string(),
                arguments: "result".to_string(),
            },
        };
        let msg = ChatMessageBuilder::new(ChatRole::Tool)
            .tool_result(vec![tc])
            .build();
        assert!(matches!(msg.message_type, MessageType::ToolResult(_)));
        assert_eq!(msg.role, ChatRole::Tool);
    }

    #[test]
    fn test_chat_role_display() {
        assert_eq!(format!("{}", ChatRole::System), "system");
        assert_eq!(format!("{}", ChatRole::User), "user");
        assert_eq!(format!("{}", ChatRole::Assistant), "assistant");
        assert_eq!(format!("{}", ChatRole::Tool), "tool");
    }

    #[test]
    fn test_image_mime_mime_type() {
        assert_eq!(ImageMime::JPEG.mime_type(), "image/jpeg");
        assert_eq!(ImageMime::PNG.mime_type(), "image/png");
        assert_eq!(ImageMime::GIF.mime_type(), "image/gif");
        assert_eq!(ImageMime::WEBP.mime_type(), "image/webp");
    }

    #[test]
    fn test_reasoning_effort_display() {
        assert_eq!(format!("{}", ReasoningEffort::Low), "low");
        assert_eq!(format!("{}", ReasoningEffort::Medium), "medium");
        assert_eq!(format!("{}", ReasoningEffort::High), "high");
    }

    // Unit variants serialize as bare strings (`Any` -> "required"); `Tool(name)` serializes
    // as a `{"type": "function", "function": {"name": ...}}` object.
    #[test]
    fn test_tool_choice_serialization() {
        let any_json = serde_json::to_value(&ToolChoice::Any).unwrap();
        assert_eq!(any_json, "required");

        let auto_json = serde_json::to_value(&ToolChoice::Auto).unwrap();
        assert_eq!(auto_json, "auto");

        let none_json = serde_json::to_value(&ToolChoice::None).unwrap();
        assert_eq!(none_json, "none");

        let tool_json = serde_json::to_value(ToolChoice::Tool("my_func".to_string())).unwrap();
        assert_eq!(tool_json["type"], "function");
        assert_eq!(tool_json["function"]["name"], "my_func");
    }

    #[test]
    fn test_structured_output_format_roundtrip() {
        let format = StructuredOutputFormat {
            name: "Test".to_string(),
            description: Some("A test".to_string()),
            schema: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };
        let json = serde_json::to_string(&format).unwrap();
        let parsed: StructuredOutputFormat = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, format);
    }

    // Only `name` is required; the optional fields default to `None` when absent.
    #[test]
    fn test_structured_output_format_minimal() {
        let json_str = r#"{"name":"Minimal"}"#;
        let parsed: StructuredOutputFormat = serde_json::from_str(json_str).unwrap();
        assert_eq!(parsed.name, "Minimal");
        assert_eq!(parsed.description, None);
        assert_eq!(parsed.schema, None);
        assert_eq!(parsed.strict, None);
    }

    #[test]
    fn test_chat_message_builder_image_url() {
        let msg = ChatMessage::user()
            .image_url("https://example.com/img.png")
            .content("describe this")
            .build();
        assert!(matches!(msg.message_type, MessageType::ImageURL(_)));
    }

    // Despite its name, this test splits pure ASCII data ("data: Posi" | "tive reactions\n\n").
    // It therefore checks that an event split across two chunks is reassembled; the split
    // multi-byte character case is covered by the multibyte test further down.
    #[tokio::test]
    async fn test_create_sse_stream_handles_split_utf8() {
        let test_data = "data: Positive reactions\n\n".as_bytes();

        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(&test_data[..10])),
            Ok(Bytes::from(&test_data[10..])),
        ];

        let mock_response = create_mock_response(chunks);

        // Minimal parser: strips the "data: " prefix, trims, and skips empty payloads.
        let parser = |event: &str| -> Result<Option<String>, LLMError> {
            if let Some(content) = event.strip_prefix("data: ") {
                let content = content.trim();
                if content.is_empty() {
                    return Ok(None);
                }
                Ok(Some(content.to_string()))
            } else {
                Ok(None)
            }
        };

        let mut stream = create_sse_stream(mock_response, parser);

        let mut results = Vec::new();
        while let Some(result) = stream.next().await {
            results.push(result);
        }

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].as_ref().unwrap(), "Positive reactions");
    }

    // The second chunk begins five bytes into the second event, so that event's "data:"
    // prefix is split across chunks.
    #[tokio::test]
    async fn test_create_sse_stream_handles_split_sse_events() {
        let event1 = "data: First event\n\n";
        let event2 = "data: Second event\n\n";
        let combined = format!("{}{}", event1, event2);
        let test_data = combined.as_bytes().to_vec();

        let split_point = event1.len() + 5;
        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(test_data[..split_point].to_vec())),
            Ok(Bytes::from(test_data[split_point..].to_vec())),
        ];

        let mock_response = create_mock_response(chunks);

        let parser = |event: &str| -> Result<Option<String>, LLMError> {
            if let Some(content) = event.strip_prefix("data: ") {
                let content = content.trim();
                if content.is_empty() {
                    return Ok(None);
                }
                Ok(Some(content.to_string()))
            } else {
                Ok(None)
            }
        };

        let mut stream = create_sse_stream(mock_response, parser);

        let mut results = Vec::new();
        while let Some(result) = stream.next().await {
            results.push(result);
        }

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].as_ref().unwrap(), "First event");
        assert_eq!(results[1].as_ref().unwrap(), "Second event");
    }

    // The chunk boundary falls after the first byte of the multi-byte "✨" character, so the
    // decoder must carry the incomplete sequence over to the next chunk.
    #[tokio::test]
    async fn test_create_sse_stream_handles_multibyte_utf8_split() {
        let multibyte_char = "✨";
        let event = format!("data: Star {}\n\n", multibyte_char);
        let test_data = event.as_bytes().to_vec();

        let emoji_start = event.find(multibyte_char).unwrap();
        let split_in_emoji = emoji_start + 1;

        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(test_data[..split_in_emoji].to_vec())),
            Ok(Bytes::from(test_data[split_in_emoji..].to_vec())),
        ];

        let mock_response = create_mock_response(chunks);

        let parser = |event: &str| -> Result<Option<String>, LLMError> {
            if let Some(content) = event.strip_prefix("data: ") {
                let content = content.trim();
                if content.is_empty() {
                    return Ok(None);
                }
                Ok(Some(content.to_string()))
            } else {
                Ok(None)
            }
        };

        let mut stream = create_sse_stream(mock_response, parser);

        let mut results = Vec::new();
        while let Some(result) = stream.next().await {
            results.push(result);
        }

        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].as_ref().unwrap(),
            &format!("Star {}", multibyte_char)
        );
    }

    // Builds a 200 OK `reqwest::Response` whose body yields `chunks` in order. Each chunk is
    // wrapped as a hyper data frame inside an `http_body_util::StreamBody`, then converted
    // from an `http::Response`.
    fn create_mock_response(chunks: Vec<Result<Bytes, reqwest::Error>>) -> reqwest::Response {
        use http_body_util::StreamBody;
        use reqwest::Body;

        let frame_stream = futures::stream::iter(
            chunks
                .into_iter()
                .map(|chunk| chunk.map(hyper::body::Frame::data)),
        );

        let body = StreamBody::new(frame_stream);
        let body = Body::wrap(body);

        let http_response = http::Response::builder().status(200).body(body).unwrap();

        http_response.into()
    }
}