matrixcode_core/providers/
mod.rs1pub mod anthropic;
2pub mod openai;
3
4#[cfg(test)]
5mod tests;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::mpsc;
12
13use crate::constants::{ANTHROPIC_DEFAULT_BASE_URL, OPENAI_DEFAULT_BASE_URL};
14use crate::tools::ToolDefinition;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Message {
18 pub role: Role,
19 pub content: MessageContent,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
23#[serde(rename_all = "lowercase")]
24pub enum Role {
25 System,
26 User,
27 Assistant,
28 Tool,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32#[serde(untagged)]
33pub enum MessageContent {
34 Text(String),
35 Blocks(Vec<ContentBlock>),
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39#[serde(tag = "type")]
40pub enum ContentBlock {
41 #[serde(rename = "text")]
42 Text { text: String },
43 #[serde(rename = "tool_use")]
44 ToolUse {
45 id: String,
46 name: String,
47 input: serde_json::Value,
48 },
49 #[serde(rename = "tool_result")]
50 ToolResult {
51 tool_use_id: String,
52 content: String,
53 },
54 #[serde(rename = "thinking")]
57 Thinking {
58 thinking: String,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 signature: Option<String>,
61 },
62 #[serde(rename = "server_tool_use")]
65 ServerToolUse {
66 id: String,
67 name: String,
68 input: serde_json::Value,
69 },
70 #[serde(rename = "web_search_tool_result")]
72 WebSearchResult {
73 tool_use_id: String,
74 content: WebSearchContent,
75 },
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80pub struct WebSearchContent {
81 pub results: Vec<WebSearchResultItem>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85pub struct WebSearchResultItem {
86 pub title: Option<String>,
87 pub url: String,
88 pub encrypted_content: Option<String>,
89 pub snippet: Option<String>,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ServerTool {
96 #[serde(rename = "type")]
97 pub tool_type: String,
98 pub name: String,
99 #[serde(skip_serializing_if = "Option::is_none")]
100 pub max_uses: Option<u32>,
101}
102
103impl ServerTool {
104 pub fn web_search(max_uses: Option<u32>) -> Self {
106 Self {
107 tool_type: "web_search_tool".to_string(),
108 name: "web_search".to_string(),
109 max_uses,
110 }
111 }
112}
113
114#[derive(Debug, Clone)]
115pub struct ChatRequest {
116 pub messages: Vec<Message>,
117 pub tools: Vec<ToolDefinition>,
118 pub system: Option<String>,
119 pub think: bool,
120 pub max_tokens: u32,
122 pub server_tools: Vec<ServerTool>,
124 pub enable_caching: bool,
126}
127
128#[derive(Debug, Clone)]
129pub struct ChatResponse {
130 pub content: Vec<ContentBlock>,
131 pub stop_reason: StopReason,
132 pub usage: Usage,
133}
134
/// Token counters attached to a [`ChatResponse`].
///
/// `Usage::default()` has every counter set to zero.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Usage {
    /// Input token count.
    pub input_tokens: u32,
    /// Output token count.
    pub output_tokens: u32,
    /// Input tokens attributed to cache creation.
    pub cache_creation_input_tokens: u32,
    /// Input tokens attributed to cache reads.
    pub cache_read_input_tokens: u32,
}
145
/// Why a [`ChatResponse`] stopped.
#[derive(Debug, Clone, PartialEq)]
pub enum StopReason {
    /// The turn ended.
    EndTurn,
    /// The reply stopped for tool use.
    ToolUse,
    /// The reply stopped at the token limit.
    MaxTokens,
}
152
153#[derive(Debug, Clone)]
155pub enum StreamEvent {
156 FirstByte,
158 ThinkingDelta(String),
160 TextDelta(String),
162 ToolUseStart { id: String, name: String },
164 ToolInputDelta { bytes_so_far: usize },
169 Usage { output_tokens: u32 },
171 Done(ChatResponse),
173 Error(String),
175}
176
177#[async_trait]
178pub trait Provider: Send + Sync {
179 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
180
181 fn context_size(&self) -> Option<u32> {
185 None
186 }
187
188 fn model_name(&self) -> &str {
190 "unknown"
191 }
192
193 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
196 let (tx, rx) = mpsc::channel(32);
197 let response = self.chat(request).await?;
198 let _ = tx.send(StreamEvent::FirstByte).await;
199 for block in &response.content {
200 if let ContentBlock::Text { text } = block {
201 let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
202 }
203 }
204 let _ = tx.send(StreamEvent::Done(response)).await;
205 Ok(rx)
206 }
207
208 fn clone_box(&self) -> Box<dyn Provider>;
210
211 fn clone_arc(&self) -> Arc<dyn Provider>;
213}
214
215impl Clone for Box<dyn Provider> {
216 fn clone(&self) -> Self {
217 self.clone_box()
218 }
219}
220
/// Selects which backend [`create_provider`] constructs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderType {
    /// Backed by `anthropic::AnthropicProvider`.
    Anthropic,
    /// Backed by `openai::OpenAIProvider`.
    OpenAI,
}
231
232pub fn create_provider(
235 provider_type: ProviderType,
236 api_key: String,
237 model: String,
238 base_url: Option<String>,
239) -> Result<Box<dyn Provider>> {
240 create_provider_with_headers(provider_type, api_key, model, base_url, None)
241}
242
243pub fn create_provider_with_headers(
245 provider_type: ProviderType,
246 api_key: String,
247 model: String,
248 base_url: Option<String>,
249 extra_headers: Option<std::collections::HashMap<String, String>>,
250) -> Result<Box<dyn Provider>> {
251 match provider_type {
252 ProviderType::Anthropic => {
253 let provider = anthropic::AnthropicProvider::with_headers(
254 api_key,
255 model,
256 base_url.unwrap_or_else(|| ANTHROPIC_DEFAULT_BASE_URL.to_string()),
257 extra_headers,
258 );
259 Ok(Box::new(provider))
260 }
261 ProviderType::OpenAI => {
262 let provider = openai::OpenAIProvider::with_headers(
263 api_key,
264 model,
265 base_url.unwrap_or_else(|| OPENAI_DEFAULT_BASE_URL.to_string()),
266 extra_headers,
267 );
268 Ok(Box::new(provider))
269 }
270 }
271}
272
273pub fn create_minimal_provider(model: &str) -> Box<dyn Provider> {
276 let _ = dotenvy::dotenv();
278
279 let api_key = std::env::var("API_KEY")
281 .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
282 .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
283 .or_else(|_| std::env::var("OPENAI_API_KEY"))
284 .unwrap_or_default();
285
286 let base_url = std::env::var("BASE_URL")
288 .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
289 .ok();
290
291 let provider_type = infer_provider_type(model);
293
294 create_provider_with_headers(provider_type, api_key, model.to_string(), base_url, None)
296 .unwrap_or_else(|_| {
297 panic!("Failed to create minimal provider for background task: no API key configured")
300 })
301}
302
303pub fn infer_provider_type(model: &str) -> ProviderType {
306 let lower = model.to_lowercase();
307 if lower.contains("claude")
308 || lower.contains("opus")
309 || lower.contains("sonnet")
310 || lower.contains("haiku")
311 {
312 ProviderType::Anthropic
313 } else if lower.contains("gpt")
314 || lower.contains("o1")
315 || lower.contains("o3")
316 || lower.contains("o4")
317 {
318 ProviderType::OpenAI
319 } else {
320 ProviderType::Anthropic
322 }
323}