matrixcode_core/providers/
mod.rs1pub mod anthropic;
2pub mod openai;
3
4#[cfg(test)]
5mod tests;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use tokio::sync::mpsc;
11
12use crate::tools::ToolDefinition;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Message {
16 pub role: Role,
17 pub content: MessageContent,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23 System,
24 User,
25 Assistant,
26 Tool,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30#[serde(untagged)]
31pub enum MessageContent {
32 Text(String),
33 Blocks(Vec<ContentBlock>),
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37#[serde(tag = "type")]
38pub enum ContentBlock {
39 #[serde(rename = "text")]
40 Text { text: String },
41 #[serde(rename = "tool_use")]
42 ToolUse {
43 id: String,
44 name: String,
45 input: serde_json::Value,
46 },
47 #[serde(rename = "tool_result")]
48 ToolResult {
49 tool_use_id: String,
50 content: String,
51 },
52 #[serde(rename = "thinking")]
55 Thinking {
56 thinking: String,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 signature: Option<String>,
59 },
60 #[serde(rename = "server_tool_use")]
63 ServerToolUse {
64 id: String,
65 name: String,
66 input: serde_json::Value,
67 },
68 #[serde(rename = "web_search_tool_result")]
70 WebSearchResult {
71 tool_use_id: String,
72 content: WebSearchContent,
73 },
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
78pub struct WebSearchContent {
79 pub results: Vec<WebSearchResultItem>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
83pub struct WebSearchResultItem {
84 pub title: Option<String>,
85 pub url: String,
86 pub encrypted_content: Option<String>,
87 pub snippet: Option<String>,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ServerTool {
94 #[serde(rename = "type")]
95 pub tool_type: String,
96 pub name: String,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub max_uses: Option<u32>,
99}
100
101impl ServerTool {
102 pub fn web_search(max_uses: Option<u32>) -> Self {
104 Self {
105 tool_type: "web_search_tool".to_string(),
106 name: "web_search".to_string(),
107 max_uses,
108 }
109 }
110}
111
112#[derive(Debug, Clone)]
113pub struct ChatRequest {
114 pub messages: Vec<Message>,
115 pub tools: Vec<ToolDefinition>,
116 pub system: Option<String>,
117 pub think: bool,
118 pub max_tokens: u32,
120 pub server_tools: Vec<ServerTool>,
122 pub enable_caching: bool,
124}
125
126#[derive(Debug, Clone)]
127pub struct ChatResponse {
128 pub content: Vec<ContentBlock>,
129 pub stop_reason: StopReason,
130 pub usage: Usage,
131}
132
/// Token counts reported for a request. Every counter defaults to zero.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Usage {
    /// Prompt tokens consumed.
    pub input_tokens: u32,
    /// Tokens generated.
    pub output_tokens: u32,
    /// Prompt tokens written to the provider's cache.
    pub cache_creation_input_tokens: u32,
    /// Prompt tokens served from the provider's cache.
    pub cache_read_input_tokens: u32,
}
143
/// Reason a model stopped generating.
#[derive(Debug, Clone, PartialEq)]
pub enum StopReason {
    /// The model finished its turn.
    EndTurn,
    /// The model stopped to request tool execution.
    ToolUse,
    /// Generation hit the `max_tokens` limit.
    MaxTokens,
}
150
151#[derive(Debug, Clone)]
153pub enum StreamEvent {
154 FirstByte,
156 ThinkingDelta(String),
158 TextDelta(String),
160 ToolUseStart { id: String, name: String },
162 ToolInputDelta { bytes_so_far: usize },
167 Usage { output_tokens: u32 },
169 Done(ChatResponse),
171 Error(String),
173}
174
175#[async_trait]
176pub trait Provider: Send + Sync {
177 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
178
179 fn context_size(&self) -> Option<u32> {
183 None
184 }
185
186 fn model_name(&self) -> &str {
188 "unknown"
189 }
190
191 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
194 let (tx, rx) = mpsc::channel(32);
195 let response = self.chat(request).await?;
196 let _ = tx.send(StreamEvent::FirstByte).await;
197 for block in &response.content {
198 if let ContentBlock::Text { text } = block {
199 let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
200 }
201 }
202 let _ = tx.send(StreamEvent::Done(response)).await;
203 Ok(rx)
204 }
205
206 fn clone_box(&self) -> Box<dyn Provider>;
208}
209
210impl Clone for Box<dyn Provider> {
211 fn clone(&self) -> Self {
212 self.clone_box()
213 }
214}
215
/// The built-in provider backends that [`create_provider`] can construct.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderType {
    /// Anthropic Messages API.
    Anthropic,
    /// OpenAI-style API.
    OpenAI,
}
226
227pub fn create_provider(
230 provider_type: ProviderType,
231 api_key: String,
232 model: String,
233 base_url: Option<String>,
234) -> Result<Box<dyn Provider>> {
235 create_provider_with_headers(provider_type, api_key, model, base_url, None)
236}
237
238pub fn create_provider_with_headers(
240 provider_type: ProviderType,
241 api_key: String,
242 model: String,
243 base_url: Option<String>,
244 extra_headers: Option<std::collections::HashMap<String, String>>,
245) -> Result<Box<dyn Provider>> {
246 match provider_type {
247 ProviderType::Anthropic => {
248 let provider = anthropic::AnthropicProvider::with_headers(
249 api_key,
250 model,
251 base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
252 extra_headers,
253 );
254 Ok(Box::new(provider))
255 }
256 ProviderType::OpenAI => {
257 let provider = openai::OpenAIProvider::with_headers(
258 api_key,
259 model,
260 base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
261 extra_headers,
262 );
263 Ok(Box::new(provider))
264 }
265 }
266}
267
268pub fn create_minimal_provider(model: &str) -> Box<dyn Provider> {
271 let _ = dotenvy::dotenv();
273
274 let api_key = std::env::var("API_KEY")
276 .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
277 .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
278 .or_else(|_| std::env::var("OPENAI_API_KEY"))
279 .unwrap_or_default();
280
281 let base_url = std::env::var("BASE_URL")
283 .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
284 .ok();
285
286 let provider_type = infer_provider_type(model);
288
289 create_provider_with_headers(provider_type, api_key, model.to_string(), base_url, None)
291 .unwrap_or_else(|_| {
292 panic!("Failed to create minimal provider for background task: no API key configured")
295 })
296}
297
298pub fn infer_provider_type(model: &str) -> ProviderType {
301 let lower = model.to_lowercase();
302 if lower.contains("claude")
303 || lower.contains("opus")
304 || lower.contains("sonnet")
305 || lower.contains("haiku")
306 {
307 ProviderType::Anthropic
308 } else if lower.contains("gpt")
309 || lower.contains("o1")
310 || lower.contains("o3")
311 || lower.contains("o4")
312 {
313 ProviderType::OpenAI
314 } else {
315 ProviderType::Anthropic
317 }
318}