1use std::collections::HashMap;
2use std::fmt;
3
4use async_trait::async_trait;
5use futures::stream::{Stream, StreamExt};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9use crate::{error::LLMError, ToolCall};
10
/// Token usage counters reported alongside a response.
///
/// When deserializing, the `input_*` / `output_*` key spellings are accepted
/// as aliases for the `prompt_*` / `completion_*` fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
    /// Prompt token count; also read from the key `input_tokens`.
    #[serde(alias = "input_tokens")]
    pub prompt_tokens: u32,
    /// Completion token count; also read from the key `output_tokens`.
    #[serde(alias = "output_tokens")]
    pub completion_tokens: u32,
    /// Total token count as carried by the payload (it is not computed here).
    pub total_tokens: u32,
    /// Optional breakdown of the completion tokens. Left out of the
    /// serialized form when `None`; also read from `output_tokens_details`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "output_tokens_details"
    )]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    /// Optional breakdown of the prompt tokens. Left out of the serialized
    /// form when `None`; also read from `input_tokens_details`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        alias = "input_tokens_details"
    )]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
35
/// One item of a structured streaming response: a list of choices plus
/// optional usage counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamResponse {
    /// Choices carried by this item, each holding an incremental delta.
    pub choices: Vec<StreamChoice>,
    /// Usage counters, when present; left out of the serialized form when `None`.
    // NOTE(review): presumably only populated on the final item of a stream — confirm per provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
45
/// A single choice inside a [`StreamResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChoice {
    /// The incremental payload for this choice.
    pub delta: StreamDelta,
}
52
/// Incremental payload of a [`StreamChoice`]. Both fields are optional and
/// are left out of the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamDelta {
    /// Text fragment carried by this delta, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls carried by this delta, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
63
/// Optional breakdown of completion tokens, referenced from
/// [`Usage::completion_tokens_details`]. Each field is left out of the
/// serialized form when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompletionTokensDetails {
    /// Reasoning token count, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    /// Audio token count, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}
74
/// Optional breakdown of prompt tokens, referenced from
/// [`Usage::prompt_tokens_details`]. Each field is left out of the
/// serialized form when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptTokensDetails {
    /// Cached token count, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
    /// Audio token count, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio_tokens: Option<u32>,
}
85
/// Author of a [`ChatMessage`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChatRole {
    /// Message written by the user.
    User,
    /// Message written by the assistant.
    Assistant,
}
94
/// Image formats that can be attached to a message.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ImageMime {
    /// JPEG image (`image/jpeg`).
    JPEG,
    /// PNG image (`image/png`).
    PNG,
    /// GIF image (`image/gif`).
    GIF,
    /// WebP image (`image/webp`).
    WEBP,
}

impl ImageMime {
    /// Returns the MIME type string for this image format.
    pub fn mime_type(&self) -> &'static str {
        match *self {
            Self::JPEG => "image/jpeg",
            Self::PNG => "image/png",
            Self::GIF => "image/gif",
            Self::WEBP => "image/webp",
        }
    }
}
119
/// Kind of payload a [`ChatMessage`] carries in addition to its text
/// `content`. Defaults to [`MessageType::Text`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum MessageType {
    /// Plain text message (the default).
    #[default]
    Text,
    /// Image given as its format together with the raw bytes.
    Image((ImageMime, Vec<u8>)),
    /// PDF document given as raw bytes.
    Pdf(Vec<u8>),
    /// Image referenced by URL.
    ImageURL(String),
    /// Tool calls attached to the message.
    // NOTE(review): presumably calls requested by the assistant — confirm against providers.
    ToolUse(Vec<ToolCall>),
    /// Tool calls attached to the message as results.
    // NOTE(review): presumably the outcomes of earlier `ToolUse` calls — confirm against providers.
    ToolResult(Vec<ToolCall>),
}
137
/// Reasoning effort level.
///
/// Its `Display` implementation renders the variants as the lowercase
/// strings `low`, `medium` and `high`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReasoningEffort {
    /// Low effort; displayed as `low`.
    Low,
    /// Medium effort; displayed as `medium`.
    Medium,
    /// High effort; displayed as `high`.
    High,
}
147
/// A single message in a chat conversation.
///
/// Instances are usually created through [`ChatMessage::user`] or
/// [`ChatMessage::assistant`], which return a [`ChatMessageBuilder`].
#[derive(Debug, Clone)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: ChatRole,
    /// Kind of payload attached to the message.
    pub message_type: MessageType,
    /// Text content of the message.
    pub content: String,
}
158
/// Description of a single parameter inside a [`ParametersSchema`].
// NOTE(review): the serialized keys (`type`, `description`, `items`, `enum`)
// look like a JSON Schema property — confirm.
#[derive(Debug, Clone, Serialize)]
pub struct ParameterProperty {
    /// Type name of the parameter; serialized under the key `type`.
    #[serde(rename = "type")]
    pub property_type: String,
    /// Human-readable description of the parameter.
    pub description: String,
    /// Nested property description; left out of the serialized form when `None`.
    // NOTE(review): presumably the element schema for array-typed parameters — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items: Option<Box<ParameterProperty>>,
    /// List of string values serialized under the key `enum`; left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "enum")]
    pub enum_list: Option<Vec<String>>,
}
174
/// Schema describing a set of named parameters.
#[derive(Debug, Clone, Serialize)]
pub struct ParametersSchema {
    /// Type name of the schema; serialized under the key `type`.
    #[serde(rename = "type")]
    pub schema_type: String,
    /// Parameter descriptions keyed by parameter name.
    pub properties: HashMap<String, ParameterProperty>,
    /// Names of the required parameters.
    pub required: Vec<String>,
}
186
/// Function definition wrapped by a [`Tool`].
#[derive(Debug, Clone, Serialize)]
pub struct FunctionTool {
    /// Name of the function.
    pub name: String,
    /// Human-readable description of the function.
    pub description: String,
    /// Parameter description as an arbitrary JSON value.
    // NOTE(review): presumably a JSON Schema object such as a serialized
    // `ParametersSchema` — confirm.
    pub parameters: Value,
}
205
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
/// Named output format with an optional description, schema and strictness
/// flag.
// NOTE(review): presumably sent to providers to request schema-constrained
// output — confirm where it is consumed.
pub struct StructuredOutputFormat {
    /// Name of the format.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional schema as an arbitrary JSON value.
    pub schema: Option<Value>,
    /// Optional strictness flag.
    pub strict: Option<bool>,
}
254
/// A tool definition, passed to [`ChatProvider::chat_with_tools`].
#[derive(Debug, Clone, Serialize)]
pub struct Tool {
    /// Kind of tool; serialized under the key `type`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function this tool exposes.
    pub function: FunctionTool,
}
264
/// Tool selection setting. Defaults to [`ToolChoice::Auto`].
///
/// The serialized form of each variant is given by the `Serialize`
/// implementation in this file and is noted per variant below.
// NOTE(review): the variant meanings (force any tool / let the model decide /
// force one tool / forbid tools) are inferred from the serialized keywords —
// confirm against the provider APIs.
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// Serialized as the string `"required"`.
    Any,

    /// Serialized as the string `"auto"`; this is the default variant.
    #[default]
    Auto,

    /// Names one function; serialized as
    /// `{"type": "function", "function": {"name": <name>}}`.
    Tool(String),

    /// Serialized as the string `"none"`.
    None,
}
287
288impl Serialize for ToolChoice {
289 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
290 where
291 S: serde::Serializer,
292 {
293 match self {
294 ToolChoice::Any => serializer.serialize_str("required"),
295 ToolChoice::Auto => serializer.serialize_str("auto"),
296 ToolChoice::None => serializer.serialize_str("none"),
297 ToolChoice::Tool(name) => {
298 use serde::ser::SerializeMap;
299
300 let mut map = serializer.serialize_map(Some(2))?;
302 map.serialize_entry("type", "function")?;
303
304 let mut function_obj = std::collections::HashMap::new();
306 function_obj.insert("name", name.as_str());
307
308 map.serialize_entry("function", &function_obj)?;
309 map.end()
310 }
311 }
312 }
313}
314
/// Provider-independent view of a chat response.
///
/// Implementors must also be `Debug`, `Display`, `Send` and `Sync`.
pub trait ChatResponse: std::fmt::Debug + std::fmt::Display + Send + Sync {
    /// Returns the text of the response, if it has any.
    fn text(&self) -> Option<String>;
    /// Returns the tool calls carried by the response, if any.
    fn tool_calls(&self) -> Option<Vec<ToolCall>>;
    /// Returns the thinking content of the response, if any.
    ///
    /// The default implementation returns `None`.
    fn thinking(&self) -> Option<String> {
        None
    }
    /// Returns the usage counters of the response, if any.
    ///
    /// The default implementation returns `None`.
    fn usage(&self) -> Option<Usage> {
        None
    }
}
325
326#[async_trait]
328pub trait ChatProvider: Sync + Send {
329 async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
339 self.chat_with_tools(messages, None).await
340 }
341
342 async fn chat_with_tools(
353 &self,
354 messages: &[ChatMessage],
355 tools: Option<&[Tool]>,
356 ) -> Result<Box<dyn ChatResponse>, LLMError>;
357
358 async fn chat_with_web_search(
368 &self,
369 _input: String,
370 ) -> Result<Box<dyn ChatResponse>, LLMError> {
371 Err(LLMError::Generic(
372 "Web search not supported for this provider".to_string(),
373 ))
374 }
375
376 async fn chat_stream(
386 &self,
387 _messages: &[ChatMessage],
388 ) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
389 {
390 Err(LLMError::Generic(
391 "Streaming not supported for this provider".to_string(),
392 ))
393 }
394
395 async fn chat_stream_struct(
411 &self,
412 _messages: &[ChatMessage],
413 ) -> Result<
414 std::pin::Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>,
415 LLMError,
416 > {
417 Err(LLMError::Generic(
418 "Structured streaming not supported for this provider".to_string(),
419 ))
420 }
421
422 async fn memory_contents(&self) -> Option<Vec<ChatMessage>> {
424 None
425 }
426
427 async fn summarize_history(&self, msgs: &[ChatMessage]) -> Result<String, LLMError> {
435 let prompt = format!(
436 "Summarize in 2-3 sentences:\n{}",
437 msgs.iter()
438 .map(|m| format!("{:?}: {}", m.role, m.content))
439 .collect::<Vec<_>>()
440 .join("\n"),
441 );
442 let req = [ChatMessage::user().content(prompt).build()];
443 self.chat(&req)
444 .await?
445 .text()
446 .ok_or(LLMError::Generic("no text in summary response".into()))
447 }
448}
449
450impl fmt::Display for ReasoningEffort {
451 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
452 match self {
453 ReasoningEffort::Low => write!(f, "low"),
454 ReasoningEffort::Medium => write!(f, "medium"),
455 ReasoningEffort::High => write!(f, "high"),
456 }
457 }
458}
459
impl ChatMessage {
    /// Starts building a message with the [`ChatRole::User`] role.
    pub fn user() -> ChatMessageBuilder {
        ChatMessageBuilder::new(ChatRole::User)
    }

    /// Starts building a message with the [`ChatRole::Assistant`] role.
    pub fn assistant() -> ChatMessageBuilder {
        ChatMessageBuilder::new(ChatRole::Assistant)
    }
}
471
/// Builder for [`ChatMessage`].
///
/// Holds the same three fields as the message it produces; they are copied
/// over unchanged by `build`.
#[derive(Debug)]
pub struct ChatMessageBuilder {
    /// Role the built message will have.
    role: ChatRole,
    /// Payload kind the built message will have.
    message_type: MessageType,
    /// Text content the built message will have.
    content: String,
}
479
480impl ChatMessageBuilder {
481 pub fn new(role: ChatRole) -> Self {
483 Self {
484 role,
485 message_type: MessageType::default(),
486 content: String::new(),
487 }
488 }
489
490 pub fn content<S: Into<String>>(mut self, content: S) -> Self {
492 self.content = content.into();
493 self
494 }
495
496 pub fn image(mut self, image_mime: ImageMime, raw_bytes: Vec<u8>) -> Self {
498 self.message_type = MessageType::Image((image_mime, raw_bytes));
499 self
500 }
501
502 pub fn pdf(mut self, raw_bytes: Vec<u8>) -> Self {
504 self.message_type = MessageType::Pdf(raw_bytes);
505 self
506 }
507
508 pub fn image_url(mut self, url: impl Into<String>) -> Self {
510 self.message_type = MessageType::ImageURL(url.into());
511 self
512 }
513
514 pub fn tool_use(mut self, tools: Vec<ToolCall>) -> Self {
516 self.message_type = MessageType::ToolUse(tools);
517 self
518 }
519
520 pub fn tool_result(mut self, tools: Vec<ToolCall>) -> Self {
522 self.message_type = MessageType::ToolResult(tools);
523 self
524 }
525
526 pub fn build(self) -> ChatMessage {
528 ChatMessage {
529 role: self.role,
530 message_type: self.message_type,
531 content: self.content,
532 }
533 }
534}
535
536pub(crate) fn create_sse_stream<F>(
547 response: reqwest::Response,
548 parser: F,
549) -> std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>
550where
551 F: Fn(&str) -> Result<Option<String>, LLMError> + Send + 'static,
552{
553 let stream = response
554 .bytes_stream()
555 .scan(
556 (String::new(), Vec::new()),
557 move |(buffer, utf8_buffer), chunk| {
558 let result = match chunk {
559 Ok(bytes) => {
560 utf8_buffer.extend_from_slice(&bytes);
561
562 match String::from_utf8(utf8_buffer.clone()) {
563 Ok(text) => {
564 buffer.push_str(&text);
565 utf8_buffer.clear();
566 }
567 Err(e) => {
568 let valid_up_to = e.utf8_error().valid_up_to();
569 if valid_up_to > 0 {
570 let valid =
573 String::from_utf8_lossy(&utf8_buffer[..valid_up_to]);
574 buffer.push_str(&valid);
575 utf8_buffer.drain(..valid_up_to);
576 }
577 }
578 }
579
580 let mut results = Vec::new();
581
582 while let Some(pos) = buffer.find("\n\n") {
583 let event = buffer[..pos + 2].to_string();
584 buffer.drain(..pos + 2);
585
586 match parser(&event) {
587 Ok(Some(content)) => results.push(Ok(content)),
588 Ok(None) => {}
589 Err(e) => results.push(Err(e)),
590 }
591 }
592
593 Some(results)
594 }
595 Err(e) => Some(vec![Err(LLMError::HttpError(e.to_string()))]),
596 };
597
598 async move { result }
599 },
600 )
601 .flat_map(futures::stream::iter);
602
603 Box::pin(stream)
604}
605
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures::stream::StreamExt;

    /// Parser shared by all tests: strips the `data: ` prefix, trims the rest
    /// and yields it; events without the prefix or with an empty payload are
    /// skipped.
    fn parse_data_event(event: &str) -> Result<Option<String>, LLMError> {
        match event.strip_prefix("data: ").map(str::trim) {
            Some(content) if !content.is_empty() => Ok(Some(content.to_string())),
            _ => Ok(None),
        }
    }

    /// Feeds `chunks` through `create_sse_stream` as a mocked HTTP body and
    /// collects every item the stream yields.
    async fn run_stream(
        chunks: Vec<Result<Bytes, reqwest::Error>>,
    ) -> Vec<Result<String, LLMError>> {
        let mut stream = create_sse_stream(create_mock_response(chunks), parse_data_event);

        let mut results = Vec::new();
        while let Some(result) = stream.next().await {
            results.push(result);
        }
        results
    }

    /// A single event delivered in two chunks is reassembled into one item.
    /// The payload here is ASCII, so the split falls on a character boundary.
    #[tokio::test]
    async fn test_create_sse_stream_handles_split_utf8() {
        let test_data = "data: Positive reactions\n\n".as_bytes();

        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(&test_data[..10])),
            Ok(Bytes::from(&test_data[10..])),
        ];

        let results = run_stream(chunks).await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].as_ref().unwrap(), "Positive reactions");
    }

    /// Two events whose chunk boundary falls inside the second event are
    /// still delivered as two separate, ordered items.
    #[tokio::test]
    async fn test_create_sse_stream_handles_split_sse_events() {
        let event1 = "data: First event\n\n";
        let event2 = "data: Second event\n\n";
        let combined = format!("{}{}", event1, event2);
        let test_data = combined.as_bytes().to_vec();

        // Split five bytes into the second event.
        let split_point = event1.len() + 5;
        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(test_data[..split_point].to_vec())),
            Ok(Bytes::from(test_data[split_point..].to_vec())),
        ];

        let results = run_stream(chunks).await;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].as_ref().unwrap(), "First event");
        assert_eq!(results[1].as_ref().unwrap(), "Second event");
    }

    /// A multi-byte character split across two chunks is decoded intact.
    #[tokio::test]
    async fn test_create_sse_stream_handles_multibyte_utf8_split() {
        let multibyte_char = "✨";
        let event = format!("data: Star {}\n\n", multibyte_char);
        let test_data = event.as_bytes().to_vec();

        // Cut after the first byte of the multi-byte character.
        let emoji_start = event.find(multibyte_char).unwrap();
        let split_in_emoji = emoji_start + 1;

        let chunks: Vec<Result<Bytes, reqwest::Error>> = vec![
            Ok(Bytes::from(test_data[..split_in_emoji].to_vec())),
            Ok(Bytes::from(test_data[split_in_emoji..].to_vec())),
        ];

        let results = run_stream(chunks).await;

        assert_eq!(results.len(), 1);
        assert_eq!(
            results[0].as_ref().unwrap(),
            &format!("Star {}", multibyte_char)
        );
    }

    /// Builds a `reqwest::Response` with status 200 whose body yields
    /// `chunks` one frame at a time.
    fn create_mock_response(chunks: Vec<Result<Bytes, reqwest::Error>>) -> reqwest::Response {
        use http_body_util::StreamBody;
        use reqwest::Body;

        let frame_stream = futures::stream::iter(
            chunks
                .into_iter()
                .map(|chunk| chunk.map(hyper::body::Frame::data)),
        );

        let body = StreamBody::new(frame_stream);
        let body = Body::wrap(body);

        let http_response = http::Response::builder().status(200).body(body).unwrap();

        http_response.into()
    }
}