1pub mod anthropic;
2pub mod openai;
3
4#[cfg(test)]
5mod tests;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::mpsc;
12
13use crate::constants::{ANTHROPIC_DEFAULT_BASE_URL, OPENAI_DEFAULT_BASE_URL};
14use crate::tools::ToolDefinition;
15
/// A single conversation turn: who produced it and what it contains.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Message {
    /// Author of this turn.
    pub role: Role,
    /// Either plain text or a list of typed content blocks.
    pub content: MessageContent,
}
21
/// Author of a [`Message`].
///
/// Serialized in lowercase (`"system"`, `"user"`, `"assistant"`, `"tool"`)
/// because of `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    // Presumably used for messages carrying tool results — confirm with the provider adapters.
    Tool,
}
30
/// Body of a [`Message`].
///
/// `#[serde(untagged)]`: `Text` (de)serializes as a bare JSON string and
/// `Blocks` as a JSON array of [`ContentBlock`]s. On deserialization serde
/// tries the variants in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Text(String),
    Blocks(Vec<ContentBlock>),
}
37
/// One typed piece of message content.
///
/// Internally tagged on the JSON `"type"` field (`#[serde(tag = "type")]`); the
/// tag value for each variant is given by its `rename`. `PartialEq` is
/// implemented by hand rather than derived.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentBlock {
    /// Plain text (`"type": "text"`).
    #[serde(rename = "text")]
    Text { text: String },
    /// A model request to call a client-side tool; `input` holds the arguments as arbitrary JSON.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Output of a tool call; `tool_use_id` presumably matches the `id` of the originating `ToolUse`.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    /// Model reasoning text. `signature` is omitted from serialized output when `None`.
    #[serde(rename = "thinking")]
    Thinking {
        thinking: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// A tool invocation executed server-side (presumably by the API provider itself); shaped like `ToolUse`.
    #[serde(rename = "server_tool_use")]
    ServerToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Result of a server-side tool invocation, as plain text.
    #[serde(rename = "server_tool_result")]
    ServerToolResult {
        tool_use_id: String,
        content: String,
    },
    /// Result of a server-side web search, tagged `"web_search_tool_result"` on the wire.
    #[serde(rename = "web_search_tool_result")]
    WebSearchResult {
        tool_use_id: String,
        content: WebSearchContent,
    },
}
82
83impl PartialEq for ContentBlock {
85 fn eq(&self, other: &Self) -> bool {
86 match (self, other) {
87 (ContentBlock::Text { text: a }, ContentBlock::Text { text: b }) => a == b,
88 (ContentBlock::ToolUse { id: a, name: b, .. }, ContentBlock::ToolUse { id: c, name: d, .. }) => {
89 a == c && b == d
90 },
91 (ContentBlock::ToolResult { tool_use_id: a, content: b }, ContentBlock::ToolResult { tool_use_id: c, content: d }) => {
92 a == c && b == d
93 },
94 (ContentBlock::Thinking { thinking: a, signature: b }, ContentBlock::Thinking { thinking: c, signature: d }) => {
95 a == c && b == d
96 },
97 (ContentBlock::ServerToolUse { id: a, name: b, .. }, ContentBlock::ServerToolUse { id: c, name: d, .. }) => {
98 a == c && b == d
99 },
100 (ContentBlock::ServerToolResult { tool_use_id: a, content: b }, ContentBlock::ServerToolResult { tool_use_id: c, content: d }) => {
101 a == c && b == d
102 },
103 (ContentBlock::WebSearchResult { tool_use_id: a, .. }, ContentBlock::WebSearchResult { tool_use_id: b, .. }) => {
104 a == b },
106 _ => false,
107 }
108 }
109}
110
/// Payload of a [`ContentBlock::WebSearchResult`]: the list of search hits.
///
/// NOTE(review): this type expects an object of the form `{"results": [...]}`.
/// Anthropic's Messages API documents `web_search_tool_result.content` as a bare
/// array of results; confirm that the provider adapter converts between the two.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WebSearchContent {
    pub results: Vec<WebSearchResultItem>,
}
116
/// A single web-search hit. Only `url` is required; the other fields may be absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WebSearchResultItem {
    pub title: Option<String>,
    pub url: String,
    // Presumably an opaque, provider-encrypted page excerpt — confirm against the provider.
    pub encrypted_content: Option<String>,
    pub snippet: Option<String>,
}
124
/// Definition of a tool executed by the API server rather than the client.
///
/// Serializes as `{"type": ..., "name": ..., "max_uses": ...}`; `max_uses` is
/// omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerTool {
    /// Tool type identifier, sent as the JSON `"type"` field.
    #[serde(rename = "type")]
    pub tool_type: String,
    pub name: String,
    // Presumably an upper bound on invocations per request — confirm against the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_uses: Option<u32>,
}
135
136impl ServerTool {
137 pub fn web_search(max_uses: Option<u32>) -> Self {
139 Self {
140 tool_type: "web_search_tool".to_string(),
141 name: "web_search".to_string(),
142 max_uses,
143 }
144 }
145}
146
/// Everything a [`Provider`] needs to issue one chat completion.
#[derive(Debug, Clone)]
pub struct ChatRequest {
    /// Conversation history, oldest first (presumably — confirm with callers).
    pub messages: Vec<Message>,
    /// Client-side tools the model may call.
    pub tools: Vec<ToolDefinition>,
    /// Optional system prompt.
    pub system: Option<String>,
    // Presumably toggles extended "thinking" output — confirm per provider.
    pub think: bool,
    // Presumably the maximum number of output tokens — confirm per provider.
    pub max_tokens: u32,
    /// Tools executed by the API server itself (e.g. web search).
    pub server_tools: Vec<ServerTool>,
    // Presumably requests provider-side prompt caching — confirm per provider.
    pub enable_caching: bool,
}
160
/// A complete (non-streamed) model response.
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// Content blocks in the order the model produced them.
    pub content: Vec<ContentBlock>,
    /// Why generation ended.
    pub stop_reason: StopReason,
    /// Token accounting for the request.
    pub usage: Usage,
}
167
/// Token counts reported for one request. `Default` yields all zeros.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Usage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    // Presumably input tokens written into the prompt cache — confirm per provider.
    pub cache_creation_input_tokens: u32,
    // Presumably input tokens served from the prompt cache — confirm per provider.
    pub cache_read_input_tokens: u32,
}
178
/// Why the model stopped generating.
#[derive(Debug, Clone, PartialEq)]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// The model stopped to request one or more tool calls.
    ToolUse,
    /// Output was cut off by the token limit (presumably `ChatRequest::max_tokens`).
    MaxTokens,
}
185
/// Incremental event emitted on the channel returned by `Provider::chat_stream`.
///
/// The default `chat_stream` emits `FirstByte`, then the thinking/text/tool events
/// for each content block in order, and finally `Done` with the full response.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// First data has arrived from the provider.
    FirstByte,
    /// A chunk of thinking text.
    ThinkingDelta(String),
    /// A chunk of visible response text.
    TextDelta(String),
    /// A tool call has begun; its input follows.
    ToolUseStart { id: String, name: String },
    // Progress while tool input streams in; presumably a running byte count of raw input — confirm.
    ToolInputDelta { bytes_so_far: usize },
    /// A tool call's input has been fully received and parsed as JSON.
    ToolInputComplete {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    // Presumably a running output-token count reported mid-stream — confirm per provider.
    Usage { output_tokens: u32 },
    /// The stream finished; carries the fully assembled response.
    Done(ChatResponse),
    // Presumably a provider/transport error message; not emitted by the default chat_stream.
    Error(String),
}
215
216#[async_trait]
217pub trait Provider: Send + Sync {
218 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
219
220 fn context_size(&self) -> Option<u32> {
224 None
225 }
226
227 fn model_name(&self) -> &str {
229 "unknown"
230 }
231
232 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
235 let (tx, rx) = mpsc::channel(32);
236 let response = self.chat(request).await?;
237 let _ = tx.send(StreamEvent::FirstByte).await;
238 for block in &response.content {
239 match block {
240 ContentBlock::Thinking { thinking, .. } => {
241 let _ = tx.send(StreamEvent::ThinkingDelta(thinking.clone())).await;
242 }
243 ContentBlock::Text { text } => {
244 let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
245 }
246 ContentBlock::ToolUse { id, name, input } => {
247 let _ = tx
248 .send(StreamEvent::ToolUseStart {
249 id: id.clone(),
250 name: name.clone(),
251 })
252 .await;
253 let _ = tx
254 .send(StreamEvent::ToolInputComplete {
255 id: id.clone(),
256 name: name.clone(),
257 input: input.clone(),
258 })
259 .await;
260 }
261 _ => {}
262 }
263 }
264 let _ = tx.send(StreamEvent::Done(response)).await;
265 Ok(rx)
266 }
267
268 fn clone_box(&self) -> Box<dyn Provider>;
270
271 fn clone_arc(&self) -> Arc<dyn Provider>;
273}
274
275impl Clone for Box<dyn Provider> {
276 fn clone(&self) -> Self {
277 self.clone_box()
278 }
279}
280
/// Backend API family used to choose a [`Provider`] implementation.
///
/// `create_provider_with_headers` maps `Anthropic` to `anthropic::AnthropicProvider`
/// and `OpenAI` to `openai::OpenAIProvider`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderType {
    Anthropic,
    OpenAI,
}
291
292pub fn create_provider(
295 provider_type: ProviderType,
296 api_key: String,
297 model: String,
298 base_url: Option<String>,
299) -> Result<Box<dyn Provider>> {
300 create_provider_with_headers(provider_type, api_key, model, base_url, None)
301}
302
303pub fn create_provider_with_headers(
305 provider_type: ProviderType,
306 api_key: String,
307 model: String,
308 base_url: Option<String>,
309 extra_headers: Option<std::collections::HashMap<String, String>>,
310) -> Result<Box<dyn Provider>> {
311 match provider_type {
312 ProviderType::Anthropic => {
313 let provider = anthropic::AnthropicProvider::with_headers(
314 api_key,
315 model,
316 base_url.unwrap_or_else(|| ANTHROPIC_DEFAULT_BASE_URL.to_string()),
317 extra_headers,
318 );
319 Ok(Box::new(provider))
320 }
321 ProviderType::OpenAI => {
322 let provider = openai::OpenAIProvider::with_headers(
323 api_key,
324 model,
325 base_url.unwrap_or_else(|| OPENAI_DEFAULT_BASE_URL.to_string()),
326 extra_headers,
327 );
328 Ok(Box::new(provider))
329 }
330 }
331}
332
333pub fn create_minimal_provider(model: &str) -> Box<dyn Provider> {
337 let config = crate::config::MatrixConfig::load();
339
340 let api_key = config
342 .resolve_api_key()
343 .unwrap_or_else(|| {
344 std::env::var("API_KEY")
346 .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
347 .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
348 .or_else(|_| std::env::var("OPENAI_API_KEY"))
349 .unwrap_or_default()
350 });
351
352 if api_key.is_empty() {
353 panic!("Failed to create minimal provider: no API key configured. \
354 Please set API_KEY env var or configure ~/.matrix/config.json")
355 }
356
357 let base_url = config
359 .resolve_base_url()
360 .or_else(|| {
361 std::env::var("BASE_URL")
362 .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
363 .ok()
364 });
365
366 let provider_type = config.resolve_provider_type(model);
368
369 create_provider_with_headers(
371 provider_type,
372 api_key,
373 model.to_string(),
374 base_url,
375 config.extra_headers.clone(),
376 )
377 .unwrap_or_else(|e| {
378 panic!("Failed to create minimal provider for model '{}': {}", model, e)
379 })
380}
381
382pub fn infer_provider_type(model: &str) -> ProviderType {
385 let lower = model.to_lowercase();
386 if lower.contains("claude")
387 || lower.contains("opus")
388 || lower.contains("sonnet")
389 || lower.contains("haiku")
390 {
391 ProviderType::Anthropic
392 } else if lower.contains("gpt")
393 || lower.contains("o1")
394 || lower.contains("o3")
395 || lower.contains("o4")
396 {
397 ProviderType::OpenAI
398 } else {
399 ProviderType::Anthropic
401 }
402}