matrixcode_core/providers/
mod.rs1pub mod anthropic;
2pub mod openai;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use tokio::sync::mpsc;
8
9use crate::tools::ToolDefinition;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Message {
13 pub role: Role,
14 pub content: MessageContent,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18#[serde(rename_all = "lowercase")]
19pub enum Role {
20 System,
21 User,
22 Assistant,
23 Tool,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27#[serde(untagged)]
28pub enum MessageContent {
29 Text(String),
30 Blocks(Vec<ContentBlock>),
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
34#[serde(tag = "type")]
35pub enum ContentBlock {
36 #[serde(rename = "text")]
37 Text { text: String },
38 #[serde(rename = "tool_use")]
39 ToolUse {
40 id: String,
41 name: String,
42 input: serde_json::Value,
43 },
44 #[serde(rename = "tool_result")]
45 ToolResult {
46 tool_use_id: String,
47 content: String,
48 },
49 #[serde(rename = "thinking")]
52 Thinking {
53 thinking: String,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 signature: Option<String>,
56 },
57 #[serde(rename = "server_tool_use")]
60 ServerToolUse {
61 id: String,
62 name: String,
63 input: serde_json::Value,
64 },
65 #[serde(rename = "web_search_tool_result")]
67 WebSearchResult {
68 tool_use_id: String,
69 content: WebSearchContent,
70 },
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
75pub struct WebSearchContent {
76 pub results: Vec<WebSearchResultItem>,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80pub struct WebSearchResultItem {
81 pub title: Option<String>,
82 pub url: String,
83 pub encrypted_content: Option<String>,
84 pub snippet: Option<String>,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct ServerTool {
91 #[serde(rename = "type")]
92 pub tool_type: String,
93 pub name: String,
94 #[serde(skip_serializing_if = "Option::is_none")]
95 pub max_uses: Option<u32>,
96}
97
98impl ServerTool {
99 pub fn web_search(max_uses: Option<u32>) -> Self {
101 Self {
102 tool_type: "web_search_tool".to_string(),
103 name: "web_search".to_string(),
104 max_uses,
105 }
106 }
107}
108
109#[derive(Debug, Clone)]
110pub struct ChatRequest {
111 pub messages: Vec<Message>,
112 pub tools: Vec<ToolDefinition>,
113 pub system: Option<String>,
114 pub think: bool,
115 pub max_tokens: u32,
117 pub server_tools: Vec<ServerTool>,
119 pub enable_caching: bool,
121}
122
123#[derive(Debug, Clone)]
124pub struct ChatResponse {
125 pub content: Vec<ContentBlock>,
126 pub stop_reason: StopReason,
127 pub usage: Usage,
128}
129
/// Token counts for a chat call; `Usage::default()` has every counter at zero.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Usage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    pub cache_creation_input_tokens: u32,
    pub cache_read_input_tokens: u32,
}
140
/// Reason a response ended, as normalized by the backend.
#[derive(Clone, Debug, PartialEq)]
pub enum StopReason {
    /// The model finished its turn.
    EndTurn,
    /// The model stopped to request tool use.
    ToolUse,
    /// Generation reached the token limit.
    MaxTokens,
}
147
148#[derive(Debug, Clone)]
150pub enum StreamEvent {
151 FirstByte,
153 ThinkingDelta(String),
155 TextDelta(String),
157 ToolUseStart { id: String, name: String },
159 ToolInputDelta { bytes_so_far: usize },
164 Usage { output_tokens: u32 },
166 Done(ChatResponse),
168 Error(String),
170}
171
172#[async_trait]
173pub trait Provider: Send + Sync {
174 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
175
176 fn context_size(&self) -> Option<u32> {
180 None
181 }
182
183 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
186 let (tx, rx) = mpsc::channel(32);
187 let response = self.chat(request).await?;
188 let _ = tx.send(StreamEvent::FirstByte).await;
189 for block in &response.content {
190 if let ContentBlock::Text { text } = block {
191 let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
192 }
193 }
194 let _ = tx.send(StreamEvent::Done(response)).await;
195 Ok(rx)
196 }
197
198 fn clone_box(&self) -> Box<dyn Provider>;
200}
201
202impl Clone for Box<dyn Provider> {
203 fn clone(&self) -> Self {
204 self.clone_box()
205 }
206}
207
/// Selects which backend `create_provider` / `create_provider_with_headers` constructs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderType {
    /// Backend implemented in the `anthropic` module.
    Anthropic,
    /// Backend implemented in the `openai` module.
    OpenAI,
}
218
219pub fn create_provider(
222 provider_type: ProviderType,
223 api_key: String,
224 model: String,
225 base_url: Option<String>,
226) -> Result<Box<dyn Provider>> {
227 create_provider_with_headers(provider_type, api_key, model, base_url, None)
228}
229
230pub fn create_provider_with_headers(
232 provider_type: ProviderType,
233 api_key: String,
234 model: String,
235 base_url: Option<String>,
236 extra_headers: Option<std::collections::HashMap<String, String>>,
237) -> Result<Box<dyn Provider>> {
238 match provider_type {
239 ProviderType::Anthropic => {
240 let provider = anthropic::AnthropicProvider::with_headers(
241 api_key,
242 model,
243 base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
244 extra_headers,
245 );
246 Ok(Box::new(provider))
247 }
248 ProviderType::OpenAI => {
249 let provider = openai::OpenAIProvider::with_headers(
250 api_key,
251 model,
252 base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
253 extra_headers,
254 );
255 Ok(Box::new(provider))
256 }
257 }
258}
259
260pub fn infer_provider_type(model: &str) -> ProviderType {
263 let lower = model.to_lowercase();
264 if lower.contains("claude")
265 || lower.contains("opus")
266 || lower.contains("sonnet")
267 || lower.contains("haiku")
268 {
269 ProviderType::Anthropic
270 } else if lower.contains("gpt")
271 || lower.contains("o1")
272 || lower.contains("o3")
273 || lower.contains("o4")
274 {
275 ProviderType::OpenAI
276 } else {
277 ProviderType::Anthropic
279 }
280}