matrixcode_core/providers/
mod.rs1pub mod anthropic;
2pub mod openai;
3
4#[cfg(test)]
5mod tests;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use tokio::sync::mpsc;
11
12use crate::tools::ToolDefinition;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Message {
16 pub role: Role,
17 pub content: MessageContent,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23 System,
24 User,
25 Assistant,
26 Tool,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30#[serde(untagged)]
31pub enum MessageContent {
32 Text(String),
33 Blocks(Vec<ContentBlock>),
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37#[serde(tag = "type")]
38pub enum ContentBlock {
39 #[serde(rename = "text")]
40 Text { text: String },
41 #[serde(rename = "tool_use")]
42 ToolUse {
43 id: String,
44 name: String,
45 input: serde_json::Value,
46 },
47 #[serde(rename = "tool_result")]
48 ToolResult {
49 tool_use_id: String,
50 content: String,
51 },
52 #[serde(rename = "thinking")]
55 Thinking {
56 thinking: String,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 signature: Option<String>,
59 },
60 #[serde(rename = "server_tool_use")]
63 ServerToolUse {
64 id: String,
65 name: String,
66 input: serde_json::Value,
67 },
68 #[serde(rename = "web_search_tool_result")]
70 WebSearchResult {
71 tool_use_id: String,
72 content: WebSearchContent,
73 },
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
78pub struct WebSearchContent {
79 pub results: Vec<WebSearchResultItem>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
83pub struct WebSearchResultItem {
84 pub title: Option<String>,
85 pub url: String,
86 pub encrypted_content: Option<String>,
87 pub snippet: Option<String>,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ServerTool {
94 #[serde(rename = "type")]
95 pub tool_type: String,
96 pub name: String,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub max_uses: Option<u32>,
99}
100
101impl ServerTool {
102 pub fn web_search(max_uses: Option<u32>) -> Self {
104 Self {
105 tool_type: "web_search_tool".to_string(),
106 name: "web_search".to_string(),
107 max_uses,
108 }
109 }
110}
111
112#[derive(Debug, Clone)]
113pub struct ChatRequest {
114 pub messages: Vec<Message>,
115 pub tools: Vec<ToolDefinition>,
116 pub system: Option<String>,
117 pub think: bool,
118 pub max_tokens: u32,
120 pub server_tools: Vec<ServerTool>,
122 pub enable_caching: bool,
124}
125
126#[derive(Debug, Clone)]
127pub struct ChatResponse {
128 pub content: Vec<ContentBlock>,
129 pub stop_reason: StopReason,
130 pub usage: Usage,
131}
132
/// Token counts reported by the provider for a request. All counters default to zero.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Usage {
    /// Input tokens, as reported by the provider.
    pub input_tokens: u32,
    /// Output tokens, as reported by the provider.
    pub output_tokens: u32,
    /// Input tokens reported as written to the prompt cache.
    pub cache_creation_input_tokens: u32,
    /// Input tokens reported as read from the prompt cache.
    pub cache_read_input_tokens: u32,
}
143
/// Why the model stopped generating.
#[derive(Clone, Debug, PartialEq)]
pub enum StopReason {
    /// The model finished its turn.
    EndTurn,
    /// The model stopped to call a tool.
    ToolUse,
    /// Generation hit the token limit.
    MaxTokens,
}
150
151#[derive(Debug, Clone)]
153pub enum StreamEvent {
154 FirstByte,
156 ThinkingDelta(String),
158 TextDelta(String),
160 ToolUseStart { id: String, name: String },
162 ToolInputDelta { bytes_so_far: usize },
167 Usage { output_tokens: u32 },
169 Done(ChatResponse),
171 Error(String),
173}
174
175#[async_trait]
176pub trait Provider: Send + Sync {
177 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
178
179 fn context_size(&self) -> Option<u32> {
183 None
184 }
185
186 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
189 let (tx, rx) = mpsc::channel(32);
190 let response = self.chat(request).await?;
191 let _ = tx.send(StreamEvent::FirstByte).await;
192 for block in &response.content {
193 if let ContentBlock::Text { text } = block {
194 let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
195 }
196 }
197 let _ = tx.send(StreamEvent::Done(response)).await;
198 Ok(rx)
199 }
200
201 fn clone_box(&self) -> Box<dyn Provider>;
203}
204
205impl Clone for Box<dyn Provider> {
206 fn clone(&self) -> Self {
207 self.clone_box()
208 }
209}
210
/// The built-in provider families that [`create_provider`] can construct.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderType {
    /// Backed by `anthropic::AnthropicProvider`.
    Anthropic,
    /// Backed by `openai::OpenAIProvider`.
    OpenAI,
}
221
222pub fn create_provider(
225 provider_type: ProviderType,
226 api_key: String,
227 model: String,
228 base_url: Option<String>,
229) -> Result<Box<dyn Provider>> {
230 create_provider_with_headers(provider_type, api_key, model, base_url, None)
231}
232
233pub fn create_provider_with_headers(
235 provider_type: ProviderType,
236 api_key: String,
237 model: String,
238 base_url: Option<String>,
239 extra_headers: Option<std::collections::HashMap<String, String>>,
240) -> Result<Box<dyn Provider>> {
241 match provider_type {
242 ProviderType::Anthropic => {
243 let provider = anthropic::AnthropicProvider::with_headers(
244 api_key,
245 model,
246 base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
247 extra_headers,
248 );
249 Ok(Box::new(provider))
250 }
251 ProviderType::OpenAI => {
252 let provider = openai::OpenAIProvider::with_headers(
253 api_key,
254 model,
255 base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
256 extra_headers,
257 );
258 Ok(Box::new(provider))
259 }
260 }
261}
262
263pub fn infer_provider_type(model: &str) -> ProviderType {
266 let lower = model.to_lowercase();
267 if lower.contains("claude")
268 || lower.contains("opus")
269 || lower.contains("sonnet")
270 || lower.contains("haiku")
271 {
272 ProviderType::Anthropic
273 } else if lower.contains("gpt")
274 || lower.contains("o1")
275 || lower.contains("o3")
276 || lower.contains("o4")
277 {
278 ProviderType::OpenAI
279 } else {
280 ProviderType::Anthropic
282 }
283}