1use std::collections::HashMap;
2use std::fmt;
3use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::stream::{Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::{ToolCall, error::LLMError};
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub struct Usage {
15 #[serde(alias = "input_tokens")]
17 pub prompt_tokens: u32,
18 #[serde(alias = "output_tokens")]
20 pub completion_tokens: u32,
21 pub total_tokens: u32,
23 #[serde(
25 skip_serializing_if = "Option::is_none",
26 alias = "output_tokens_details"
27 )]
28 pub completion_tokens_details: Option<CompletionTokensDetails>,
29 #[serde(
31 skip_serializing_if = "Option::is_none",
32 alias = "input_tokens_details"
33 )]
34 pub prompt_tokens_details: Option<PromptTokensDetails>,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct StreamResponse {
40 pub choices: Vec<StreamChoice>,
42 #[serde(skip_serializing_if = "Option::is_none")]
44 pub usage: Option<Usage>,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct StreamChoice {
50 pub delta: StreamDelta,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct StreamDelta {
57 #[serde(skip_serializing_if = "Option::is_none")]
59 pub content: Option<String>,
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub tool_calls: Option<Vec<ToolCall>>,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
71pub enum StreamChunk {
72 Text(String),
74
75 ToolUseStart {
77 index: usize,
79 id: String,
81 name: String,
83 },
84
85 ToolUseInputDelta {
87 index: usize,
89 partial_json: String,
91 },
92
93 ToolUseComplete {
95 index: usize,
97 tool_call: ToolCall,
99 },
100
101 Done {
103 stop_reason: String,
105 },
106 Usage(Usage),
107}
108
109#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
111pub struct CompletionTokensDetails {
112 #[serde(skip_serializing_if = "Option::is_none")]
114 pub reasoning_tokens: Option<u32>,
115 #[serde(skip_serializing_if = "Option::is_none")]
117 pub audio_tokens: Option<u32>,
118}
119
120#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
122pub struct PromptTokensDetails {
123 #[serde(skip_serializing_if = "Option::is_none")]
125 pub cached_tokens: Option<u32>,
126 #[serde(skip_serializing_if = "Option::is_none")]
128 pub audio_tokens: Option<u32>,
129}
130
131#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
133pub enum ChatRole {
134 System,
136 User,
138 Assistant,
140 Tool,
142}
143
144impl fmt::Display for ChatRole {
145 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146 let value = match self {
147 ChatRole::System => "system",
148 ChatRole::User => "user",
149 ChatRole::Assistant => "assistant",
150 ChatRole::Tool => "tool",
151 };
152 f.write_str(value)
153 }
154}
155
156#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
158#[non_exhaustive]
159pub enum ImageMime {
160 JPEG,
162 PNG,
164 GIF,
166 WEBP,
168}
169
170impl ImageMime {
171 pub fn mime_type(&self) -> &'static str {
172 match self {
173 ImageMime::JPEG => "image/jpeg",
174 ImageMime::PNG => "image/png",
175 ImageMime::GIF => "image/gif",
176 ImageMime::WEBP => "image/webp",
177 }
178 }
179}
180
181#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
183pub enum MessageType {
184 #[default]
186 Text,
187 Image((ImageMime, Vec<u8>)),
189 Pdf(Vec<u8>),
191 ImageURL(String),
193 ToolUse(Vec<ToolCall>),
195 ToolResult(Vec<ToolCall>),
197}
198
/// Requested reasoning-effort level.
///
/// Its `Display` form is the lowercase variant name (`"low"`, `"medium"`, `"high"`).
/// The derives make the level printable with `{:?}`, comparable, hashable and
/// cheap to copy, like the other small enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReasoningEffort {
    /// Lowest of the three levels.
    Low,
    /// Middle level.
    Medium,
    /// Highest of the three levels.
    High,
}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct ChatMessage {
212 pub role: ChatRole,
214 pub message_type: MessageType,
216 pub content: String,
218}
219
220#[derive(Debug, Clone, Serialize)]
222pub struct ParameterProperty {
223 #[serde(rename = "type")]
225 pub property_type: String,
226 pub description: String,
228 #[serde(skip_serializing_if = "Option::is_none")]
230 pub items: Option<Box<ParameterProperty>>,
231 #[serde(skip_serializing_if = "Option::is_none", rename = "enum")]
233 pub enum_list: Option<Vec<String>>,
234}
235
236#[derive(Debug, Clone, Serialize)]
238pub struct ParametersSchema {
239 #[serde(rename = "type")]
241 pub schema_type: String,
242 pub properties: HashMap<String, ParameterProperty>,
244 pub required: Vec<String>,
246}
247
248#[derive(Debug, Clone, Serialize)]
258pub struct FunctionTool {
259 pub name: String,
261 pub description: String,
263 pub parameters: Value,
265}
266
267#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
304
305pub struct StructuredOutputFormat {
306 pub name: String,
308 pub description: Option<String>,
310 pub schema: Option<Value>,
312 pub strict: Option<bool>,
314}
315
316#[derive(Debug, Clone, Serialize)]
318pub struct Tool {
319 #[serde(rename = "type")]
321 pub tool_type: String,
322 pub function: FunctionTool,
324}
325
326#[derive(Debug, Clone, Default)]
329pub enum ToolChoice {
330 Any,
333
334 #[default]
337 Auto,
338
339 Tool(String),
343
344 None,
347}
348
349impl Serialize for ToolChoice {
350 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
351 where
352 S: serde::Serializer,
353 {
354 match self {
355 ToolChoice::Any => serializer.serialize_str("required"),
356 ToolChoice::Auto => serializer.serialize_str("auto"),
357 ToolChoice::None => serializer.serialize_str("none"),
358 ToolChoice::Tool(name) => {
359 use serde::ser::SerializeMap;
360
361 let mut map = serializer.serialize_map(Some(2))?;
363 map.serialize_entry("type", "function")?;
364
365 let mut function_obj = std::collections::HashMap::new();
367 function_obj.insert("name", name.as_str());
368
369 map.serialize_entry("function", &function_obj)?;
370 map.end()
371 }
372 }
373 }
374}
375
376pub trait ChatResponse: std::fmt::Debug + std::fmt::Display + Send + Sync {
377 fn text(&self) -> Option<String>;
378 fn tool_calls(&self) -> Option<Vec<ToolCall>>;
379 fn thinking(&self) -> Option<String> {
380 None
381 }
382 fn usage(&self) -> Option<Usage> {
383 None
384 }
385}
386
387#[async_trait]
389pub trait ChatProvider: Sync + Send {
390 async fn chat(
401 &self,
402 messages: &[ChatMessage],
403 json_schema: Option<StructuredOutputFormat>,
404 ) -> Result<Box<dyn ChatResponse>, LLMError> {
405 self.chat_with_tools(messages, None, json_schema).await
406 }
407
408 async fn chat_with_tools(
420 &self,
421 messages: &[ChatMessage],
422 tools: Option<&[Tool]>,
423 json_schema: Option<StructuredOutputFormat>,
424 ) -> Result<Box<dyn ChatResponse>, LLMError>;
425
426 async fn chat_with_web_search(
436 &self,
437 _input: String,
438 ) -> Result<Box<dyn ChatResponse>, LLMError> {
439 Err(LLMError::Generic(
440 "Web search not supported for this provider".to_string(),
441 ))
442 }
443
444 async fn chat_stream(
455 &self,
456 _messages: &[ChatMessage],
457 _json_schema: Option<StructuredOutputFormat>,
458 ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
459 {
460 Err(LLMError::Generic(
461 "Streaming not supported for this provider".to_string(),
462 ))
463 }
464
465 async fn chat_stream_struct(
483 &self,
484 _messages: &[ChatMessage],
485 _tools: Option<&[Tool]>,
486 _json_schema: Option<StructuredOutputFormat>,
487 ) -> Result<
488 std::pin::Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>,
489 LLMError,
490 > {
491 Err(LLMError::Generic(
492 "Structured streaming not supported for this provider".to_string(),
493 ))
494 }
495
496 async fn chat_stream_with_tools(
541 &self,
542 _messages: &[ChatMessage],
543 _tools: Option<&[Tool]>,
544 _json_schema: Option<StructuredOutputFormat>,
545 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, LLMError> {
546 Err(LLMError::Generic(
547 "Streaming with tools not supported for this provider".to_string(),
548 ))
549 }
550}
551
552impl fmt::Display for ReasoningEffort {
553 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
554 match self {
555 ReasoningEffort::Low => write!(f, "low"),
556 ReasoningEffort::Medium => write!(f, "medium"),
557 ReasoningEffort::High => write!(f, "high"),
558 }
559 }
560}
561
562impl ChatMessage {
563 pub fn user() -> ChatMessageBuilder {
565 ChatMessageBuilder::new(ChatRole::User)
566 }
567
568 pub fn assistant() -> ChatMessageBuilder {
570 ChatMessageBuilder::new(ChatRole::Assistant)
571 }
572}
573
574#[derive(Debug)]
576pub struct ChatMessageBuilder {
577 role: ChatRole,
578 message_type: MessageType,
579 content: String,
580}
581
582impl ChatMessageBuilder {
583 pub fn new(role: ChatRole) -> Self {
585 Self {
586 role,
587 message_type: MessageType::default(),
588 content: String::default(),
589 }
590 }
591
592 pub fn content<S: Into<String>>(mut self, content: S) -> Self {
594 self.content = content.into();
595 self
596 }
597
598 pub fn image(mut self, image_mime: ImageMime, raw_bytes: Vec<u8>) -> Self {
600 self.message_type = MessageType::Image((image_mime, raw_bytes));
601 self
602 }
603
604 pub fn pdf(mut self, raw_bytes: Vec<u8>) -> Self {
606 self.message_type = MessageType::Pdf(raw_bytes);
607 self
608 }
609
610 pub fn image_url(mut self, url: impl Into<String>) -> Self {
612 self.message_type = MessageType::ImageURL(url.into());
613 self
614 }
615
616 pub fn tool_use(mut self, tools: Vec<ToolCall>) -> Self {
618 self.message_type = MessageType::ToolUse(tools);
619 self
620 }
621
622 pub fn tool_result(mut self, tools: Vec<ToolCall>) -> Self {
624 self.message_type = MessageType::ToolResult(tools);
625 self
626 }
627
628 pub fn build(self) -> ChatMessage {
630 ChatMessage {
631 role: self.role,
632 message_type: self.message_type,
633 content: self.content,
634 }
635 }
636}
637
638#[allow(dead_code)]
649pub(crate) fn create_sse_stream<F>(
650 response: reqwest::Response,
651 parser: F,
652) -> std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
653where
654 F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
655{
656 let stream = response
657 .bytes_stream()
658 .scan(
659 (String::default(), Vec::default()),
660 move |(buffer, utf8_buffer), chunk| {
661 let result = match chunk {
662 Ok(bytes) => {
663 utf8_buffer.extend_from_slice(&bytes);
664
665 match String::from_utf8(utf8_buffer.clone()) {
666 Ok(text) => {
667 buffer.push_str(&text);
668 utf8_buffer.clear();
669 }
670 Err(e) => {
671 let valid_up_to = e.utf8_error().valid_up_to();
672 if valid_up_to > 0 {
673 let valid =
676 String::from_utf8_lossy(&utf8_buffer[..valid_up_to]);
677 buffer.push_str(&valid);
678 utf8_buffer.drain(..valid_up_to);
679 }
680 }
681 }
682
683 let mut results = Vec::default();
684
685 while let Some(pos) = buffer.find("\n\n") {
686 let event = buffer[..pos + 2].to_string();
687 buffer.drain(..pos + 2);
688
689 match parser(&event) {
690 Ok(Some(content)) => results.push(Ok(content)),
691 Ok(None) => {}
692 Err(e) => results.push(Err(e)),
693 }
694 }
695
696 Some(results)
697 }
698 Err(e) => Some(vec![Err(LLMError::HttpError(e.to_string()))]),
699 };
700
701 async move { result }
702 },
703 )
704 .flat_map(futures::stream::iter);
705
706 Box::pin(stream)
707}
708
709#[cfg(not(target_arch = "wasm32"))]
710pub mod utils {
711 use crate::error::LLMError;
712 use reqwest::Response;
713 pub async fn check_response_status(response: Response) -> Result<Response, LLMError> {
714 if !response.status().is_success() {
715 let status = response.status();
716 let error_text = response.text().await?;
717 return Err(LLMError::ResponseFormatError {
718 message: format!("API returned error status: {status}"),
719 raw_response: error_text,
720 });
721 }
722 Ok(response)
723 }
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures::stream::StreamExt;

    /// Builds a `ToolCall` fixture (id `"t1"`, function `"tool"`) with the given arguments.
    fn sample_tool_call(arguments: &str) -> crate::ToolCall {
        crate::ToolCall {
            id: "t1".to_string(),
            call_type: "function".to_string(),
            function: crate::FunctionCall {
                name: "tool".to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    /// SSE parser used by the stream tests: returns the trimmed payload of an event
    /// that starts with `data: `; empty payloads and other events yield `None`.
    fn parse_data_event(event: &str) -> Result<Option<String>, LLMError> {
        match event.strip_prefix("data: ").map(str::trim) {
            Some(content) if !content.is_empty() => Ok(Some(content.to_string())),
            _ => Ok(None),
        }
    }

    /// Feeds `chunks` through `create_sse_stream` with `parse_data_event` and
    /// collects every yielded item.
    async fn collect_sse_events(
        chunks: Vec<Result<Bytes, reqwest::Error>>,
    ) -> Vec<Result<String, LLMError>> {
        create_sse_stream(create_mock_response(chunks), parse_data_event)
            .collect()
            .await
    }

    #[test]
    fn test_chat_message_builder_user() {
        let msg = ChatMessage::user().content("hello").build();
        assert_eq!(msg.role, ChatRole::User);
        assert_eq!(msg.content, "hello");
        assert!(matches!(msg.message_type, MessageType::Text));
    }

    #[test]
    fn test_chat_message_builder_assistant() {
        let msg = ChatMessage::assistant().content("reply").build();
        assert_eq!(msg.role, ChatRole::Assistant);
        assert_eq!(msg.content, "reply");
    }

    #[test]
    fn test_chat_message_builder_image() {
        let msg = ChatMessage::user()
            .content("describe")
            .image(ImageMime::PNG, vec![1, 2, 3])
            .build();
        assert!(matches!(msg.message_type, MessageType::Image(_)));
    }

    #[test]
    fn test_chat_message_builder_pdf() {
        let msg = ChatMessage::user()
            .content("read")
            .pdf(vec![4, 5, 6])
            .build();
        assert!(matches!(msg.message_type, MessageType::Pdf(_)));
    }

    #[test]
    fn test_chat_message_builder_tool_use() {
        let msg = ChatMessage::assistant()
            .content("calling tool")
            .tool_use(vec![sample_tool_call("{}")])
            .build();
        assert!(matches!(msg.message_type, MessageType::ToolUse(_)));
    }

    #[test]
    fn test_chat_message_builder_tool_result() {
        let msg = ChatMessageBuilder::new(ChatRole::Tool)
            .tool_result(vec![sample_tool_call("result")])
            .build();
        assert!(matches!(msg.message_type, MessageType::ToolResult(_)));
        assert_eq!(msg.role, ChatRole::Tool);
    }

    #[test]
    fn test_chat_role_display() {
        assert_eq!(format!("{}", ChatRole::System), "system");
        assert_eq!(format!("{}", ChatRole::User), "user");
        assert_eq!(format!("{}", ChatRole::Assistant), "assistant");
        assert_eq!(format!("{}", ChatRole::Tool), "tool");
    }

    #[test]
    fn test_image_mime_mime_type() {
        assert_eq!(ImageMime::JPEG.mime_type(), "image/jpeg");
        assert_eq!(ImageMime::PNG.mime_type(), "image/png");
        assert_eq!(ImageMime::GIF.mime_type(), "image/gif");
        assert_eq!(ImageMime::WEBP.mime_type(), "image/webp");
    }

    #[test]
    fn test_reasoning_effort_display() {
        assert_eq!(format!("{}", ReasoningEffort::Low), "low");
        assert_eq!(format!("{}", ReasoningEffort::Medium), "medium");
        assert_eq!(format!("{}", ReasoningEffort::High), "high");
    }

    #[test]
    fn test_tool_choice_serialization() {
        let any_json = serde_json::to_value(&ToolChoice::Any).unwrap();
        assert_eq!(any_json, "required");

        let auto_json = serde_json::to_value(&ToolChoice::Auto).unwrap();
        assert_eq!(auto_json, "auto");

        let none_json = serde_json::to_value(&ToolChoice::None).unwrap();
        assert_eq!(none_json, "none");

        let tool_json = serde_json::to_value(ToolChoice::Tool("my_func".to_string())).unwrap();
        assert_eq!(tool_json["type"], "function");
        assert_eq!(tool_json["function"]["name"], "my_func");
    }

    #[test]
    fn test_structured_output_format_roundtrip() {
        let format = StructuredOutputFormat {
            name: "Test".to_string(),
            description: Some("A test".to_string()),
            schema: Some(serde_json::json!({"type": "object"})),
            strict: Some(true),
        };
        let json = serde_json::to_string(&format).unwrap();
        let parsed: StructuredOutputFormat = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, format);
    }

    #[test]
    fn test_structured_output_format_minimal() {
        let json_str = r#"{"name":"Minimal"}"#;
        let parsed: StructuredOutputFormat = serde_json::from_str(json_str).unwrap();
        assert_eq!(parsed.name, "Minimal");
        assert_eq!(parsed.description, None);
        assert_eq!(parsed.schema, None);
        assert_eq!(parsed.strict, None);
    }

    #[test]
    fn test_chat_message_builder_image_url() {
        let msg = ChatMessage::user()
            .image_url("https://example.com/img.png")
            .content("describe this")
            .build();
        assert!(matches!(msg.message_type, MessageType::ImageURL(_)));
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_split_utf8() {
        let test_data = "data: Positive reactions\n\n".as_bytes();

        let results = collect_sse_events(vec![
            Ok(Bytes::from(&test_data[..10])),
            Ok(Bytes::from(&test_data[10..])),
        ])
        .await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].as_ref().unwrap(), "Positive reactions");
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_split_sse_events() {
        let event1 = "data: First event\n\n";
        let event2 = "data: Second event\n\n";
        let combined = format!("{}{}", event1, event2);
        let test_data = combined.as_bytes().to_vec();

        // Split in the middle of the second event.
        let split_point = event1.len() + 5;
        let results = collect_sse_events(vec![
            Ok(Bytes::from(test_data[..split_point].to_vec())),
            Ok(Bytes::from(test_data[split_point..].to_vec())),
        ])
        .await;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].as_ref().unwrap(), "First event");
        assert_eq!(results[1].as_ref().unwrap(), "Second event");
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_multibyte_utf8_split() {
        let multibyte_char = "✨";
        let event = format!("data: Star {}\n\n", multibyte_char);
        let test_data = event.as_bytes().to_vec();

        // Split one byte into the multi-byte character.
        let emoji_start = event.find(multibyte_char).unwrap();
        let split_in_emoji = emoji_start + 1;

        let results = collect_sse_events(vec![
            Ok(Bytes::from(test_data[..split_in_emoji].to_vec())),
            Ok(Bytes::from(test_data[split_in_emoji..].to_vec())),
        ])
        .await;

        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].as_ref().unwrap(),
            &format!("Star {}", multibyte_char)
        );
    }

    /// Wraps `chunks` in a `reqwest::Response` with status 200 whose body yields
    /// each chunk as one data frame.
    fn create_mock_response(chunks: Vec<Result<Bytes, reqwest::Error>>) -> reqwest::Response {
        use http_body_util::StreamBody;
        use reqwest::Body;

        let frame_stream = futures::stream::iter(
            chunks
                .into_iter()
                .map(|chunk| chunk.map(hyper::body::Frame::data)),
        );

        let body = StreamBody::new(frame_stream);
        let body = Body::wrap(body);

        let http_response = http::Response::builder().status(200).body(body).unwrap();

        http_response.into()
    }
}