matrixcode_core/providers/
mod.rs1pub mod anthropic;
2pub mod openai;
3
4#[cfg(test)]
5mod tests;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::mpsc;
12
13use crate::constants::{ANTHROPIC_DEFAULT_BASE_URL, OPENAI_DEFAULT_BASE_URL};
14use crate::tools::ToolDefinition;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Message {
18 pub role: Role,
19 pub content: MessageContent,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
23#[serde(rename_all = "lowercase")]
24pub enum Role {
25 System,
26 User,
27 Assistant,
28 Tool,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32#[serde(untagged)]
33pub enum MessageContent {
34 Text(String),
35 Blocks(Vec<ContentBlock>),
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39#[serde(tag = "type")]
40pub enum ContentBlock {
41 #[serde(rename = "text")]
42 Text { text: String },
43 #[serde(rename = "tool_use")]
44 ToolUse {
45 id: String,
46 name: String,
47 input: serde_json::Value,
48 },
49 #[serde(rename = "tool_result")]
50 ToolResult {
51 tool_use_id: String,
52 content: String,
53 },
54 #[serde(rename = "thinking")]
57 Thinking {
58 thinking: String,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 signature: Option<String>,
61 },
62 #[serde(rename = "server_tool_use")]
65 ServerToolUse {
66 id: String,
67 name: String,
68 input: serde_json::Value,
69 },
70 #[serde(rename = "web_search_tool_result")]
72 WebSearchResult {
73 tool_use_id: String,
74 content: WebSearchContent,
75 },
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80pub struct WebSearchContent {
81 pub results: Vec<WebSearchResultItem>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85pub struct WebSearchResultItem {
86 pub title: Option<String>,
87 pub url: String,
88 pub encrypted_content: Option<String>,
89 pub snippet: Option<String>,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ServerTool {
96 #[serde(rename = "type")]
97 pub tool_type: String,
98 pub name: String,
99 #[serde(skip_serializing_if = "Option::is_none")]
100 pub max_uses: Option<u32>,
101}
102
103impl ServerTool {
104 pub fn web_search(max_uses: Option<u32>) -> Self {
106 Self {
107 tool_type: "web_search_tool".to_string(),
108 name: "web_search".to_string(),
109 max_uses,
110 }
111 }
112}
113
114#[derive(Debug, Clone)]
115pub struct ChatRequest {
116 pub messages: Vec<Message>,
117 pub tools: Vec<ToolDefinition>,
118 pub system: Option<String>,
119 pub think: bool,
120 pub max_tokens: u32,
122 pub server_tools: Vec<ServerTool>,
124 pub enable_caching: bool,
126}
127
128#[derive(Debug, Clone)]
129pub struct ChatResponse {
130 pub content: Vec<ContentBlock>,
131 pub stop_reason: StopReason,
132 pub usage: Usage,
133}
134
/// Token counters for one exchange. `Default` yields all zeros.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Usage {
    /// Tokens consumed by the request.
    pub input_tokens: u32,
    /// Tokens produced in the response.
    pub output_tokens: u32,
    /// Input tokens counted towards cache creation.
    pub cache_creation_input_tokens: u32,
    /// Input tokens counted as cache reads.
    pub cache_read_input_tokens: u32,
}
145
/// Why the model stopped generating.
#[derive(Debug, Clone, PartialEq)]
pub enum StopReason {
    /// The model finished its turn.
    EndTurn,
    /// The model stopped in order to call a tool.
    ToolUse,
    /// The token limit was reached.
    MaxTokens,
}
152
153#[derive(Debug, Clone)]
155pub enum StreamEvent {
156 FirstByte,
158 ThinkingDelta(String),
160 TextDelta(String),
162 ToolUseStart { id: String, name: String },
164 ToolInputDelta { bytes_so_far: usize },
169 ToolInputComplete {
171 id: String,
172 name: String,
173 input: serde_json::Value,
174 },
175 Usage { output_tokens: u32 },
177 Done(ChatResponse),
179 Error(String),
181}
182
183#[async_trait]
184pub trait Provider: Send + Sync {
185 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
186
187 fn context_size(&self) -> Option<u32> {
191 None
192 }
193
194 fn model_name(&self) -> &str {
196 "unknown"
197 }
198
199 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
202 let (tx, rx) = mpsc::channel(32);
203 let response = self.chat(request).await?;
204 let _ = tx.send(StreamEvent::FirstByte).await;
205 for block in &response.content {
206 match block {
207 ContentBlock::Thinking { thinking, .. } => {
208 let _ = tx.send(StreamEvent::ThinkingDelta(thinking.clone())).await;
209 }
210 ContentBlock::Text { text } => {
211 let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
212 }
213 ContentBlock::ToolUse { id, name, input } => {
214 let _ = tx
215 .send(StreamEvent::ToolUseStart {
216 id: id.clone(),
217 name: name.clone(),
218 })
219 .await;
220 let _ = tx
221 .send(StreamEvent::ToolInputComplete {
222 id: id.clone(),
223 name: name.clone(),
224 input: input.clone(),
225 })
226 .await;
227 }
228 _ => {}
229 }
230 }
231 let _ = tx.send(StreamEvent::Done(response)).await;
232 Ok(rx)
233 }
234
235 fn clone_box(&self) -> Box<dyn Provider>;
237
238 fn clone_arc(&self) -> Arc<dyn Provider>;
240}
241
242impl Clone for Box<dyn Provider> {
243 fn clone(&self) -> Self {
244 self.clone_box()
245 }
246}
247
/// Which backend implementation to construct.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderType {
    /// Use the `anthropic` provider module.
    Anthropic,
    /// Use the `openai` provider module.
    OpenAI,
}
258
259pub fn create_provider(
262 provider_type: ProviderType,
263 api_key: String,
264 model: String,
265 base_url: Option<String>,
266) -> Result<Box<dyn Provider>> {
267 create_provider_with_headers(provider_type, api_key, model, base_url, None)
268}
269
270pub fn create_provider_with_headers(
272 provider_type: ProviderType,
273 api_key: String,
274 model: String,
275 base_url: Option<String>,
276 extra_headers: Option<std::collections::HashMap<String, String>>,
277) -> Result<Box<dyn Provider>> {
278 match provider_type {
279 ProviderType::Anthropic => {
280 let provider = anthropic::AnthropicProvider::with_headers(
281 api_key,
282 model,
283 base_url.unwrap_or_else(|| ANTHROPIC_DEFAULT_BASE_URL.to_string()),
284 extra_headers,
285 );
286 Ok(Box::new(provider))
287 }
288 ProviderType::OpenAI => {
289 let provider = openai::OpenAIProvider::with_headers(
290 api_key,
291 model,
292 base_url.unwrap_or_else(|| OPENAI_DEFAULT_BASE_URL.to_string()),
293 extra_headers,
294 );
295 Ok(Box::new(provider))
296 }
297 }
298}
299
300pub fn create_minimal_provider(model: &str) -> Box<dyn Provider> {
303 let _ = dotenvy::dotenv();
305
306 let api_key = std::env::var("API_KEY")
308 .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
309 .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
310 .or_else(|_| std::env::var("OPENAI_API_KEY"))
311 .unwrap_or_default();
312
313 let base_url = std::env::var("BASE_URL")
315 .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
316 .ok();
317
318 let provider_type = infer_provider_type(model);
320
321 create_provider_with_headers(provider_type, api_key, model.to_string(), base_url, None)
323 .unwrap_or_else(|_| {
324 panic!("Failed to create minimal provider for background task: no API key configured")
327 })
328}
329
330pub fn infer_provider_type(model: &str) -> ProviderType {
333 let lower = model.to_lowercase();
334 if lower.contains("claude")
335 || lower.contains("opus")
336 || lower.contains("sonnet")
337 || lower.contains("haiku")
338 {
339 ProviderType::Anthropic
340 } else if lower.contains("gpt")
341 || lower.contains("o1")
342 || lower.contains("o3")
343 || lower.contains("o4")
344 {
345 ProviderType::OpenAI
346 } else {
347 ProviderType::Anthropic
349 }
350}