matrixcode_core/providers/
mod.rs1pub mod anthropic;
2pub mod openai;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use tokio::sync::mpsc;
8
9use crate::tools::ToolDefinition;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Message {
13 pub role: Role,
14 pub content: MessageContent,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18#[serde(rename_all = "lowercase")]
19pub enum Role {
20 System,
21 User,
22 Assistant,
23 Tool,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27#[serde(untagged)]
28pub enum MessageContent {
29 Text(String),
30 Blocks(Vec<ContentBlock>),
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
34#[serde(tag = "type")]
35pub enum ContentBlock {
36 #[serde(rename = "text")]
37 Text { text: String },
38 #[serde(rename = "tool_use")]
39 ToolUse {
40 id: String,
41 name: String,
42 input: serde_json::Value,
43 },
44 #[serde(rename = "tool_result")]
45 ToolResult {
46 tool_use_id: String,
47 content: String,
48 },
49 #[serde(rename = "thinking")]
52 Thinking {
53 thinking: String,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 signature: Option<String>,
56 },
57 #[serde(rename = "server_tool_use")]
60 ServerToolUse {
61 id: String,
62 name: String,
63 input: serde_json::Value,
64 },
65 #[serde(rename = "web_search_tool_result")]
67 WebSearchResult {
68 tool_use_id: String,
69 content: WebSearchContent,
70 },
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
75pub struct WebSearchContent {
76 pub results: Vec<WebSearchResultItem>,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80pub struct WebSearchResultItem {
81 pub title: Option<String>,
82 pub url: String,
83 pub encrypted_content: Option<String>,
84 pub snippet: Option<String>,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct ServerTool {
91 #[serde(rename = "type")]
92 pub tool_type: String,
93 pub name: String,
94 #[serde(skip_serializing_if = "Option::is_none")]
95 pub max_uses: Option<u32>,
96}
97
98impl ServerTool {
99 pub fn web_search(max_uses: Option<u32>) -> Self {
101 Self {
102 tool_type: "web_search_tool".to_string(),
103 name: "web_search".to_string(),
104 max_uses,
105 }
106 }
107}
108
109#[derive(Debug, Clone)]
110pub struct ChatRequest {
111 pub messages: Vec<Message>,
112 pub tools: Vec<ToolDefinition>,
113 pub system: Option<String>,
114 pub think: bool,
115 pub max_tokens: u32,
117 pub server_tools: Vec<ServerTool>,
119 pub enable_caching: bool,
121}
122
123#[derive(Debug, Clone)]
124pub struct ChatResponse {
125 pub content: Vec<ContentBlock>,
126 pub stop_reason: StopReason,
127 pub usage: Usage,
128}
129
/// Token counters for a single request; every counter defaults to zero.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Usage {
    /// Input (prompt) tokens.
    pub input_tokens: u32,
    /// Output (generated) tokens.
    pub output_tokens: u32,
    /// Input tokens written to the prompt cache (as reported by the provider).
    pub cache_creation_input_tokens: u32,
    /// Input tokens read from the prompt cache (as reported by the provider).
    pub cache_read_input_tokens: u32,
}
140
/// Why the model stopped generating.
#[derive(Clone, Debug, PartialEq)]
pub enum StopReason {
    /// The model finished its turn.
    EndTurn,
    /// The model stopped to request tool use.
    ToolUse,
    /// Generation hit the token limit.
    MaxTokens,
}
147
148#[derive(Debug, Clone)]
150pub enum StreamEvent {
151 FirstByte,
153 ThinkingDelta(String),
155 TextDelta(String),
157 ToolUseStart { id: String, name: String },
159 ToolInputDelta { bytes_so_far: usize },
164 Usage { output_tokens: u32 },
166 Done(ChatResponse),
168 Error(String),
170}
171
172#[async_trait]
173pub trait Provider: Send + Sync {
174 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
175
176 fn context_size(&self) -> Option<u32> {
180 None
181 }
182
183 async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
186 let (tx, rx) = mpsc::channel(32);
187 let response = self.chat(request).await?;
188 let _ = tx.send(StreamEvent::FirstByte).await;
189 for block in &response.content {
190 if let ContentBlock::Text { text } = block {
191 let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
192 }
193 }
194 let _ = tx.send(StreamEvent::Done(response)).await;
195 Ok(rx)
196 }
197
198 fn clone_box(&self) -> Box<dyn Provider>;
200}
201
202impl Clone for Box<dyn Provider> {
203 fn clone(&self) -> Self {
204 self.clone_box()
205 }
206}
207
/// Which backend [`create_provider`] should construct.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ProviderType {
    /// Anthropic Messages API.
    Anthropic,
    /// OpenAI API.
    OpenAI,
}
218
219pub fn create_provider(
222 provider_type: ProviderType,
223 api_key: String,
224 model: String,
225 base_url: Option<String>,
226) -> Result<Box<dyn Provider>> {
227 match provider_type {
228 ProviderType::Anthropic => {
229 let provider = anthropic::AnthropicProvider::new(
230 api_key,
231 model,
232 base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
233 );
234 Ok(Box::new(provider))
235 }
236 ProviderType::OpenAI => {
237 let provider = openai::OpenAIProvider::new(
238 api_key,
239 model,
240 base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
241 );
242 Ok(Box::new(provider))
243 }
244 }
245}
246
247pub fn infer_provider_type(model: &str) -> ProviderType {
250 let lower = model.to_lowercase();
251 if lower.contains("claude")
252 || lower.contains("opus")
253 || lower.contains("sonnet")
254 || lower.contains("haiku")
255 {
256 ProviderType::Anthropic
257 } else if lower.contains("gpt")
258 || lower.contains("o1")
259 || lower.contains("o3")
260 || lower.contains("o4")
261 {
262 ProviderType::OpenAI
263 } else {
264 ProviderType::Anthropic
266 }
267}