1use std::collections::HashMap;
2use std::fmt;
3use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::stream::{Stream, StreamExt};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::{ToolCall, error::LLMError};
11
/// Token usage reported for a request.
///
/// The primary field names use the `prompt_tokens` / `completion_tokens`
/// spelling. When deserializing, serde aliases also accept `input_tokens` /
/// `output_tokens` and `input_tokens_details` / `output_tokens_details`.
/// Aliases only affect deserialization, so serialization always emits the
/// primary names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens counted for the input; also read from `input_tokens`.
    #[serde(alias = "input_tokens")]
    pub prompt_tokens: u32,
    /// Tokens counted for the output; also read from `output_tokens`.
    #[serde(alias = "output_tokens")]
    pub completion_tokens: u32,
    /// Total token count (no alias is declared for this field).
    pub total_tokens: u32,
    /// Breakdown of output tokens; also read from `output_tokens_details`.
    /// Omitted from serialized output when `None`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "output_tokens_details"
    )]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    /// Breakdown of input tokens; also read from `input_tokens_details`.
    /// Omitted from serialized output when `None`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "input_tokens_details"
    )]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
36
/// One chunk of a structured streaming response, as yielded by
/// `ChatProvider::chat_stream_struct`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamResponse {
    /// Incremental choices carried by this chunk.
    pub choices: Vec<StreamChoice>,
    /// Token usage, when this chunk carries it; omitted from serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
46
/// A single choice inside a [`StreamResponse`] chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// The incremental update carried for this choice.
    pub delta: StreamDelta,
}
53
/// Incremental content of a [`StreamChoice`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamDelta {
    /// Text fragment carried by this delta, if any; omitted from serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls carried by this delta, if any; omitted from serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
64
/// An event yielded by `ChatProvider::chat_stream_with_tools`.
///
/// Text fragments, tool-call progress, usage and the final stop signal share
/// one stream. NOTE(review): the `index` fields presumably identify which tool
/// call a `ToolUse*` event belongs to; confirm against the providers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StreamChunk {
    /// A fragment of response text.
    Text(String),

    /// The start of a tool call.
    ToolUseStart {
        /// Position of the tool call (assumed to correlate the events of one call).
        index: usize,
        /// Identifier of the tool call.
        id: String,
        /// Name of the tool being called.
        name: String,
    },

    /// A fragment of a tool call's input.
    ToolUseInputDelta {
        /// Position of the tool call this fragment belongs to.
        index: usize,
        /// Partial JSON text; presumably concatenated across deltas to form the
        /// full input — confirm.
        partial_json: String,
    },

    /// A fully assembled tool call.
    ToolUseComplete {
        /// Position of the completed tool call.
        index: usize,
        /// The complete tool call.
        tool_call: ToolCall,
    },

    /// End of the stream, with the reported stop reason.
    Done {
        /// Stop reason as reported by the provider.
        stop_reason: String,
    },
    /// Token usage for the request.
    Usage(Usage),
}
108
/// Optional breakdown of [`Usage::completion_tokens`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompletionTokensDetails {
    /// Tokens attributed to reasoning; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    /// Tokens attributed to audio; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}
119
/// Optional breakdown of [`Usage::prompt_tokens`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptTokensDetails {
    /// Tokens reported as cached; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
    /// Tokens attributed to audio; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}
130
/// The author role of a [`ChatMessage`].
///
/// Its `Display` impl renders the role in lowercase (`system`, `user`,
/// `assistant`, `tool`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ChatRole {
    /// System role.
    System,
    /// User role.
    User,
    /// Assistant role.
    Assistant,
    /// Tool role.
    Tool,
}
143
144impl fmt::Display for ChatRole {
145 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146 let value = match self {
147 ChatRole::System => "system",
148 ChatRole::User => "user",
149 ChatRole::Assistant => "assistant",
150 ChatRole::Tool => "tool",
151 };
152 f.write_str(value)
153 }
154}
155
/// Image formats that can be attached inline to a message.
///
/// `#[non_exhaustive]`: code outside this crate must include a wildcard arm
/// when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ImageMime {
    /// JPEG (`image/jpeg`).
    JPEG,
    /// PNG (`image/png`).
    PNG,
    /// GIF (`image/gif`).
    GIF,
    /// WebP (`image/webp`).
    WEBP,
}
169
170impl ImageMime {
171 pub fn mime_type(&self) -> &'static str {
172 match self {
173 ImageMime::JPEG => "image/jpeg",
174 ImageMime::PNG => "image/png",
175 ImageMime::GIF => "image/gif",
176 ImageMime::WEBP => "image/webp",
177 }
178 }
179}
180
/// The kind of payload a [`ChatMessage`] carries alongside its text content.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum MessageType {
    /// Text only; this is the default.
    #[default]
    Text,
    /// An inline image: its format and raw bytes.
    Image((ImageMime, Vec<u8>)),
    /// An inline PDF document as raw bytes.
    Pdf(Vec<u8>),
    /// An image referenced by URL.
    ImageURL(String),
    /// A tool-use payload: a list of tool calls.
    ToolUse(Vec<ToolCall>),
    /// A tool-result payload, carried as a list of tool calls.
    ToolResult(Vec<ToolCall>),
}
198
/// How much reasoning effort is requested from a model.
///
/// Its `Display` impl renders the level as `low`, `medium` or `high`.
///
/// Derives the standard value-type traits so the level can be copied,
/// compared, hashed and debug-printed like any other small enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReasoningEffort {
    /// Low reasoning effort.
    Low,
    /// Medium reasoning effort.
    Medium,
    /// High reasoning effort.
    High,
}
208
/// A single message in a chat conversation.
///
/// Usually constructed through [`ChatMessage::user`] or
/// [`ChatMessage::assistant`], which return a [`ChatMessageBuilder`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// The role that authored the message.
    pub role: ChatRole,
    /// The kind of payload attached to the message.
    pub message_type: MessageType,
    /// The message's text content.
    pub content: String,
}
219
/// Schema for a single function parameter.
#[derive(Debug, Clone, Serialize)]
pub struct ParameterProperty {
    /// The parameter's type; serialized under the key `type`.
    #[serde(rename = "type")]
    pub property_type: String,
    /// Human-readable description of the parameter.
    pub description: String,
    /// Nested schema serialized under `items` (presumably the element schema
    /// for array-typed parameters); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<ParameterProperty>>,
    /// Allowed values, serialized under the key `enum`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "enum")]
    pub enum_list: Option<Vec<String>>,
}
235
/// Schema describing the full parameter list of a function.
#[derive(Debug, Clone, Serialize)]
pub struct ParametersSchema {
    /// Schema type; serialized under the key `type`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Parameter name mapped to its schema.
    pub properties: HashMap<String, ParameterProperty>,
    /// Names of the required parameters.
    pub required: Vec<String>,
}
247
/// A callable function exposed to the model as a tool.
#[derive(Debug, Clone, Serialize)]
pub struct FunctionTool {
    /// The function's name.
    pub name: String,
    /// What the function does.
    pub description: String,
    /// Parameter schema, kept as a raw JSON value.
    pub parameters: Value,
}
266
/// A requested output schema for structured responses.
///
/// Passed as the `json_schema` argument of the `ChatProvider` methods.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct StructuredOutputFormat {
    /// Name of the output format.
    pub name: String,
    /// Optional description of the format.
    pub description: Option<String>,
    /// Optional JSON schema the output should follow.
    pub schema: Option<Value>,
    /// Optional strictness flag (how `None` is treated is up to each provider;
    /// confirm per provider).
    pub strict: Option<bool>,
}
315
/// A tool definition offered to the model, wrapping a [`FunctionTool`].
#[derive(Debug, Clone, Serialize)]
pub struct Tool {
    /// Tool kind; serialized under the key `type`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function definition.
    pub function: FunctionTool,
}
325
/// Controls whether, and which, tools the model may call.
///
/// Serialized by a hand-written `Serialize` impl: `Any` as `"required"`,
/// `Auto` as `"auto"`, `None` as `"none"`, and `Tool(name)` as
/// `{"type": "function", "function": {"name": name}}`.
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Serialized as `"required"`.
    Any,

    /// The default; serialized as `"auto"`.
    #[default]
    Auto,

    /// A specific tool, by name.
    Tool(String),

    /// Serialized as `"none"`.
    None,
}
348
349impl Serialize for ToolChoice {
350 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
351 where
352 S: serde::Serializer,
353 {
354 match self {
355 ToolChoice::Any => serializer.serialize_str("required"),
356 ToolChoice::Auto => serializer.serialize_str("auto"),
357 ToolChoice::None => serializer.serialize_str("none"),
358 ToolChoice::Tool(name) => {
359 use serde::ser::SerializeMap;
360
361 let mut map = serializer.serialize_map(Some(2))?;
363 map.serialize_entry("type", "function")?;
364
365 let mut function_obj = std::collections::HashMap::new();
367 function_obj.insert("name", name.as_str());
368
369 map.serialize_entry("function", &function_obj)?;
370 map.end()
371 }
372 }
373 }
374}
375
/// A provider's response to a chat request.
///
/// Implementors must also be `Debug`, `Display`, `Send` and `Sync`; providers
/// return it as `Box<dyn ChatResponse>`.
pub trait ChatResponse: std::fmt::Debug + std::fmt::Display + Send + Sync {
    /// The response text, if any.
    fn text(&self) -> Option<String>;
    /// Tool calls contained in the response, if any.
    fn tool_calls(&self) -> Option<Vec<ToolCall>>;
    /// Reasoning ("thinking") text, if the implementor exposes it; `None` by default.
    fn thinking(&self) -> Option<String> {
        None
    }
    /// Token usage for the request, if known; `None` by default.
    fn usage(&self) -> Option<Usage> {
        None
    }
}
386
387#[async_trait]
389pub trait ChatProvider: Sync + Send {
390 async fn chat(
401 &self,
402 messages: &[ChatMessage],
403 json_schema: Option<StructuredOutputFormat>,
404 ) -> Result<Box<dyn ChatResponse>, LLMError> {
405 self.chat_with_tools(messages, None, json_schema).await
406 }
407
408 async fn chat_with_tools(
420 &self,
421 messages: &[ChatMessage],
422 tools: Option<&[Tool]>,
423 json_schema: Option<StructuredOutputFormat>,
424 ) -> Result<Box<dyn ChatResponse>, LLMError>;
425
426 async fn chat_with_web_search(
436 &self,
437 _input: String,
438 ) -> Result<Box<dyn ChatResponse>, LLMError> {
439 Err(LLMError::Generic(
440 "Web search not supported for this provider".to_string(),
441 ))
442 }
443
444 async fn chat_stream(
455 &self,
456 _messages: &[ChatMessage],
457 _json_schema: Option<StructuredOutputFormat>,
458 ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
459 {
460 Err(LLMError::Generic(
461 "Streaming not supported for this provider".to_string(),
462 ))
463 }
464
465 async fn chat_stream_struct(
483 &self,
484 _messages: &[ChatMessage],
485 _tools: Option<&[Tool]>,
486 _json_schema: Option<StructuredOutputFormat>,
487 ) -> Result<
488 std::pin::Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>,
489 LLMError,
490 > {
491 Err(LLMError::Generic(
492 "Structured streaming not supported for this provider".to_string(),
493 ))
494 }
495
496 async fn chat_stream_with_tools(
541 &self,
542 _messages: &[ChatMessage],
543 _tools: Option<&[Tool]>,
544 _json_schema: Option<StructuredOutputFormat>,
545 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, LLMError> {
546 Err(LLMError::Generic(
547 "Streaming with tools not supported for this provider".to_string(),
548 ))
549 }
550}
551
552impl fmt::Display for ReasoningEffort {
553 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
554 match self {
555 ReasoningEffort::Low => write!(f, "low"),
556 ReasoningEffort::Medium => write!(f, "medium"),
557 ReasoningEffort::High => write!(f, "high"),
558 }
559 }
560}
561
562impl ChatMessage {
563 pub fn user() -> ChatMessageBuilder {
565 ChatMessageBuilder::new(ChatRole::User)
566 }
567
568 pub fn assistant() -> ChatMessageBuilder {
570 ChatMessageBuilder::new(ChatRole::Assistant)
571 }
572}
573
/// Builder for [`ChatMessage`].
///
/// Obtain one from [`ChatMessage::user`], [`ChatMessage::assistant`] or
/// [`ChatMessageBuilder::new`], then finish with `build`.
#[derive(Debug)]
pub struct ChatMessageBuilder {
    role: ChatRole,
    message_type: MessageType,
    content: String,
}
581
582impl ChatMessageBuilder {
583 pub fn new(role: ChatRole) -> Self {
585 Self {
586 role,
587 message_type: MessageType::default(),
588 content: String::new(),
589 }
590 }
591
592 pub fn content<S: Into<String>>(mut self, content: S) -> Self {
594 self.content = content.into();
595 self
596 }
597
598 pub fn image(mut self, image_mime: ImageMime, raw_bytes: Vec<u8>) -> Self {
600 self.message_type = MessageType::Image((image_mime, raw_bytes));
601 self
602 }
603
604 pub fn pdf(mut self, raw_bytes: Vec<u8>) -> Self {
606 self.message_type = MessageType::Pdf(raw_bytes);
607 self
608 }
609
610 pub fn image_url(mut self, url: impl Into<String>) -> Self {
612 self.message_type = MessageType::ImageURL(url.into());
613 self
614 }
615
616 pub fn tool_use(mut self, tools: Vec<ToolCall>) -> Self {
618 self.message_type = MessageType::ToolUse(tools);
619 self
620 }
621
622 pub fn tool_result(mut self, tools: Vec<ToolCall>) -> Self {
624 self.message_type = MessageType::ToolResult(tools);
625 self
626 }
627
628 pub fn build(self) -> ChatMessage {
630 ChatMessage {
631 role: self.role,
632 message_type: self.message_type,
633 content: self.content,
634 }
635 }
636}
637
638#[allow(dead_code)]
649pub(crate) fn create_sse_stream<F>(
650 response: reqwest::Response,
651 parser: F,
652) -> std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
653where
654 F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
655{
656 let stream = response
657 .bytes_stream()
658 .scan(
659 (String::new(), Vec::new()),
660 move |(buffer, utf8_buffer), chunk| {
661 let result = match chunk {
662 Ok(bytes) => {
663 utf8_buffer.extend_from_slice(&bytes);
664
665 match String::from_utf8(utf8_buffer.clone()) {
666 Ok(text) => {
667 buffer.push_str(&text);
668 utf8_buffer.clear();
669 }
670 Err(e) => {
671 let valid_up_to = e.utf8_error().valid_up_to();
672 if valid_up_to > 0 {
673 let valid =
676 String::from_utf8_lossy(&utf8_buffer[..valid_up_to]);
677 buffer.push_str(&valid);
678 utf8_buffer.drain(..valid_up_to);
679 }
680 }
681 }
682
683 let mut results = Vec::new();
684
685 while let Some(pos) = buffer.find("\n\n") {
686 let event = buffer[..pos + 2].to_string();
687 buffer.drain(..pos + 2);
688
689 match parser(&event) {
690 Ok(Some(content)) => results.push(Ok(content)),
691 Ok(None) => {}
692 Err(e) => results.push(Err(e)),
693 }
694 }
695
696 Some(results)
697 }
698 Err(e) => Some(vec![Err(LLMError::HttpError(e.to_string()))]),
699 };
700
701 async move { result }
702 },
703 )
704 .flat_map(futures::stream::iter);
705
706 Box::pin(stream)
707}
708
709#[cfg(not(target_arch = "wasm32"))]
710pub mod utils {
711 use crate::error::LLMError;
712 use reqwest::Response;
713 pub async fn check_response_status(response: Response) -> Result<Response, LLMError> {
714 if !response.status().is_success() {
715 let status = response.status();
716 let error_text = response.text().await?;
717 return Err(LLMError::ResponseFormatError {
718 message: format!("API returned error status: {status}"),
719 raw_response: error_text,
720 });
721 }
722 Ok(response)
723 }
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures::stream::StreamExt;

    /// Test parser: yields the trimmed payload of a `data: ` event and skips
    /// anything else, including events whose payload is empty.
    fn data_parser(event: &str) -> Result<Option<String>, LLMError> {
        Ok(event
            .strip_prefix("data: ")
            .map(str::trim)
            .filter(|payload| !payload.is_empty())
            .map(String::from))
    }

    /// Splits `data` at byte offset `at` into two owned chunks, simulating a
    /// payload that arrives over two network reads.
    fn two_chunks(data: &[u8], at: usize) -> Vec<Result<Bytes, reqwest::Error>> {
        let (head, tail) = data.split_at(at);
        vec![
            Ok(Bytes::copy_from_slice(head)),
            Ok(Bytes::copy_from_slice(tail)),
        ]
    }

    /// Runs `chunks` through `create_sse_stream` with `data_parser` and
    /// returns every item the stream yields, in order.
    async fn collect_events(
        chunks: Vec<Result<Bytes, reqwest::Error>>,
    ) -> Vec<Result<String, LLMError>> {
        create_sse_stream(create_mock_response(chunks), data_parser)
            .collect()
            .await
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_split_utf8() {
        let raw = b"data: Positive reactions\n\n";

        let events = collect_events(two_chunks(raw, 10)).await;

        assert_eq!(events.len(), 1);
        assert_eq!(events[0].as_ref().unwrap(), "Positive reactions");
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_split_sse_events() {
        let first = "data: First event\n\n";
        let second = "data: Second event\n\n";
        let raw = [first, second].concat();

        // Cut a few bytes into the second event so it straddles both chunks.
        let events = collect_events(two_chunks(raw.as_bytes(), first.len() + 5)).await;

        assert_eq!(events.len(), 2);
        assert_eq!(events[0].as_ref().unwrap(), "First event");
        assert_eq!(events[1].as_ref().unwrap(), "Second event");
    }

    #[tokio::test]
    async fn test_create_sse_stream_handles_multibyte_utf8_split() {
        let sparkle = "✨";
        let raw = format!("data: Star {sparkle}\n\n");

        // Cut one byte into the multi-byte character.
        let cut = raw.find(sparkle).unwrap() + 1;
        let events = collect_events(two_chunks(raw.as_bytes(), cut)).await;

        assert_eq!(events.len(), 1);
        assert_eq!(events[0].as_ref().unwrap(), &format!("Star {sparkle}"));
    }

    /// Builds a 200 OK `reqwest::Response` whose body yields `chunks` in order.
    fn create_mock_response(chunks: Vec<Result<Bytes, reqwest::Error>>) -> reqwest::Response {
        use http_body_util::StreamBody;
        use reqwest::Body;

        let frame_stream = futures::stream::iter(
            chunks
                .into_iter()
                .map(|chunk| chunk.map(hyper::body::Frame::data)),
        );

        let body = Body::wrap(StreamBody::new(frame_stream));
        let http_response = http::Response::builder().status(200).body(body).unwrap();

        http_response.into()
    }
}