1pub mod router;
2pub mod streaming;
3pub mod types;
4
5pub use router::{ModelRouter, ModelTier, TaskComplexity};
6pub use streaming::{StreamAccumulator, StreamBox, StreamDelta};
7pub use types::*;
8
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12
13#[async_trait]
14pub trait LlmProvider: Send + Sync {
15 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;
17
18 fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
25 Box::pin(async_stream::stream! {
26 match self.chat(request).await {
27 Ok(outcome) => match outcome {
28 ChatOutcome::Success(response) => {
29 for (idx, block) in response.content.iter().enumerate() {
31 match block {
32 ContentBlock::Text { text } => {
33 yield Ok(StreamDelta::TextDelta {
34 delta: text.clone(),
35 block_index: idx,
36 });
37 }
38 ContentBlock::ToolUse { id, name, input, .. } => {
39 yield Ok(StreamDelta::ToolUseStart {
40 id: id.clone(),
41 name: name.clone(),
42 block_index: idx,
43 });
44 yield Ok(StreamDelta::ToolInputDelta {
45 id: id.clone(),
46 delta: serde_json::to_string(input).unwrap_or_default(),
47 block_index: idx,
48 });
49 }
50 ContentBlock::ToolResult { .. } => {
51 }
53 }
54 }
55 yield Ok(StreamDelta::Usage(response.usage));
56 yield Ok(StreamDelta::Done {
57 stop_reason: response.stop_reason,
58 });
59 }
60 ChatOutcome::RateLimited => {
61 yield Ok(StreamDelta::Error {
62 message: "Rate limited".to_string(),
63 recoverable: true,
64 });
65 }
66 ChatOutcome::InvalidRequest(msg) => {
67 yield Ok(StreamDelta::Error {
68 message: msg,
69 recoverable: false,
70 });
71 }
72 ChatOutcome::ServerError(msg) => {
73 yield Ok(StreamDelta::Error {
74 message: msg,
75 recoverable: true,
76 });
77 }
78 },
79 Err(e) => yield Err(e),
80 }
81 })
82 }
83
84 fn model(&self) -> &str;
85 fn provider(&self) -> &'static str;
86}
87
88pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
97 let mut accumulator = StreamAccumulator::new();
98 let mut last_error: Option<(String, bool)> = None;
99
100 while let Some(result) = stream.next().await {
101 match result {
102 Ok(delta) => {
103 if let StreamDelta::Error {
104 message,
105 recoverable,
106 } = &delta
107 {
108 last_error = Some((message.clone(), *recoverable));
109 }
110 accumulator.apply(&delta);
111 }
112 Err(e) => return Err(e),
113 }
114 }
115
116 if let Some((message, recoverable)) = last_error {
118 if !recoverable {
119 return Ok(ChatOutcome::InvalidRequest(message));
120 }
121 if message.contains("Rate limited") || message.contains("rate limit") {
123 return Ok(ChatOutcome::RateLimited);
124 }
125 return Ok(ChatOutcome::ServerError(message));
126 }
127
128 let usage = accumulator.take_usage().unwrap_or(Usage {
130 input_tokens: 0,
131 output_tokens: 0,
132 });
133 let stop_reason = accumulator.take_stop_reason();
134 let content = accumulator.into_content_blocks();
135
136 Ok(ChatOutcome::Success(ChatResponse {
137 id: String::new(),
138 content,
139 model,
140 stop_reason,
141 usage,
142 }))
143}