1pub mod router;
2pub mod streaming;
3pub mod types;
4
5pub use router::{ModelRouter, ModelTier, TaskComplexity};
6pub use streaming::{StreamAccumulator, StreamBox, StreamDelta};
7pub use types::*;
8
9use anyhow::Result;
10use async_trait::async_trait;
11use futures::StreamExt;
12
13#[async_trait]
14pub trait LlmProvider: Send + Sync {
15 async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;
17
18 fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
25 Box::pin(async_stream::stream! {
26 match self.chat(request).await {
27 Ok(outcome) => match outcome {
28 ChatOutcome::Success(response) => {
29 for (idx, block) in response.content.iter().enumerate() {
31 match block {
32 ContentBlock::Text { text } => {
33 yield Ok(StreamDelta::TextDelta {
34 delta: text.clone(),
35 block_index: idx,
36 });
37 }
38 ContentBlock::Thinking { thinking } => {
39 yield Ok(StreamDelta::ThinkingDelta {
40 delta: thinking.clone(),
41 block_index: idx,
42 });
43 }
44 ContentBlock::ToolUse { id, name, input, .. } => {
45 yield Ok(StreamDelta::ToolUseStart {
46 id: id.clone(),
47 name: name.clone(),
48 block_index: idx,
49 });
50 yield Ok(StreamDelta::ToolInputDelta {
51 id: id.clone(),
52 delta: serde_json::to_string(input).unwrap_or_default(),
53 block_index: idx,
54 });
55 }
56 ContentBlock::ToolResult { .. } => {
57 }
59 }
60 }
61 yield Ok(StreamDelta::Usage(response.usage));
62 yield Ok(StreamDelta::Done {
63 stop_reason: response.stop_reason,
64 });
65 }
66 ChatOutcome::RateLimited => {
67 yield Ok(StreamDelta::Error {
68 message: "Rate limited".to_string(),
69 recoverable: true,
70 });
71 }
72 ChatOutcome::InvalidRequest(msg) => {
73 yield Ok(StreamDelta::Error {
74 message: msg,
75 recoverable: false,
76 });
77 }
78 ChatOutcome::ServerError(msg) => {
79 yield Ok(StreamDelta::Error {
80 message: msg,
81 recoverable: true,
82 });
83 }
84 },
85 Err(e) => yield Err(e),
86 }
87 })
88 }
89
90 fn model(&self) -> &str;
91 fn provider(&self) -> &'static str;
92}
93
94pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
103 let mut accumulator = StreamAccumulator::new();
104 let mut last_error: Option<(String, bool)> = None;
105
106 while let Some(result) = stream.next().await {
107 match result {
108 Ok(delta) => {
109 if let StreamDelta::Error {
110 message,
111 recoverable,
112 } = &delta
113 {
114 last_error = Some((message.clone(), *recoverable));
115 }
116 accumulator.apply(&delta);
117 }
118 Err(e) => return Err(e),
119 }
120 }
121
122 if let Some((message, recoverable)) = last_error {
124 if !recoverable {
125 return Ok(ChatOutcome::InvalidRequest(message));
126 }
127 if message.contains("Rate limited") || message.contains("rate limit") {
129 return Ok(ChatOutcome::RateLimited);
130 }
131 return Ok(ChatOutcome::ServerError(message));
132 }
133
134 let usage = accumulator.take_usage().unwrap_or(Usage {
136 input_tokens: 0,
137 output_tokens: 0,
138 });
139 let stop_reason = accumulator.take_stop_reason();
140 let content = accumulator.into_content_blocks();
141
142 Ok(ChatOutcome::Success(ChatResponse {
143 id: String::new(),
144 content,
145 model,
146 stop_reason,
147 usage,
148 }))
149}