Skip to main content

matrixcode_core/providers/
mod.rs

1pub mod anthropic;
2pub mod openai;
3
4#[cfg(test)]
5mod tests;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::mpsc;
12
13use crate::constants::{ANTHROPIC_DEFAULT_BASE_URL, OPENAI_DEFAULT_BASE_URL};
14use crate::tools::ToolDefinition;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Message {
18    pub role: Role,
19    pub content: MessageContent,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
23#[serde(rename_all = "lowercase")]
24pub enum Role {
25    System,
26    User,
27    Assistant,
28    Tool,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32#[serde(untagged)]
33pub enum MessageContent {
34    Text(String),
35    Blocks(Vec<ContentBlock>),
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39#[serde(tag = "type")]
40pub enum ContentBlock {
41    #[serde(rename = "text")]
42    Text { text: String },
43    #[serde(rename = "tool_use")]
44    ToolUse {
45        id: String,
46        name: String,
47        input: serde_json::Value,
48    },
49    #[serde(rename = "tool_result")]
50    ToolResult {
51        tool_use_id: String,
52        content: String,
53    },
54    /// Anthropic extended-thinking block. `signature` is required when sending
55    /// the block back to the API in a follow-up turn.
56    #[serde(rename = "thinking")]
57    Thinking {
58        thinking: String,
59        #[serde(skip_serializing_if = "Option::is_none")]
60        signature: Option<String>,
61    },
62    /// Server-side tool use (e.g., web_search_tool). The server executes
63    /// the tool and returns results directly without client intervention.
64    #[serde(rename = "server_tool_use")]
65    ServerToolUse {
66        id: String,
67        name: String,
68        input: serde_json::Value,
69    },
70    /// Result from a server-side web search tool.
71    #[serde(rename = "web_search_tool_result")]
72    WebSearchResult {
73        tool_use_id: String,
74        content: WebSearchContent,
75    },
76}
77
78/// Content returned by the server-side web search tool.
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80pub struct WebSearchContent {
81    pub results: Vec<WebSearchResultItem>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85pub struct WebSearchResultItem {
86    pub title: Option<String>,
87    pub url: String,
88    pub encrypted_content: Option<String>,
89    pub snippet: Option<String>,
90}
91
92/// Server-side tool definition. These tools are executed by the API provider
93/// rather than by the client. Currently only web_search_tool is supported.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ServerTool {
96    #[serde(rename = "type")]
97    pub tool_type: String,
98    pub name: String,
99    #[serde(skip_serializing_if = "Option::is_none")]
100    pub max_uses: Option<u32>,
101}
102
103impl ServerTool {
104    /// Create a new web search server tool.
105    pub fn web_search(max_uses: Option<u32>) -> Self {
106        Self {
107            tool_type: "web_search_tool".to_string(),
108            name: "web_search".to_string(),
109            max_uses,
110        }
111    }
112}
113
114#[derive(Debug, Clone)]
115pub struct ChatRequest {
116    pub messages: Vec<Message>,
117    pub tools: Vec<ToolDefinition>,
118    pub system: Option<String>,
119    pub think: bool,
120    /// Maximum output tokens for the response.
121    pub max_tokens: u32,
122    /// Server-side tools that are executed by the API provider.
123    pub server_tools: Vec<ServerTool>,
124    /// Enable prompt caching for Anthropic provider.
125    pub enable_caching: bool,
126}
127
128#[derive(Debug, Clone)]
129pub struct ChatResponse {
130    pub content: Vec<ContentBlock>,
131    pub stop_reason: StopReason,
132    pub usage: Usage,
133}
134
/// Token accounting for one provider turn.
///
/// `input_tokens` already includes the cached and uncached portions combined. When a
/// provider exposes cache details separately, they are captured in the two `cache_*`
/// fields so callers can report cache effectiveness.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Usage {
    /// Total input tokens for the turn (cached + uncached).
    pub input_tokens: u32,
    /// Output tokens for the turn.
    pub output_tokens: u32,
    /// Input tokens the provider reports as written to its prompt cache.
    pub cache_creation_input_tokens: u32,
    /// Input tokens the provider reports as read from its prompt cache.
    pub cache_read_input_tokens: u32,
}
145
/// Why the model stopped producing output for a turn.
#[derive(Clone, Debug, PartialEq)]
pub enum StopReason {
    /// Normal end of turn.
    EndTurn,
    /// Stopped on a tool use.
    ToolUse,
    /// Stopped at the output token limit.
    MaxTokens,
}
152
153/// Incremental events emitted during a streaming chat turn.
154#[derive(Debug, Clone)]
155pub enum StreamEvent {
156    /// First byte received from the server — agent should stop any waiting spinner.
157    FirstByte,
158    /// Extended-thinking text delta (Anthropic thinking block).
159    ThinkingDelta(String),
160    /// Visible assistant text delta.
161    TextDelta(String),
162    /// A new tool_use block started.
163    ToolUseStart { id: String, name: String },
164    /// Incremental progress for the current tool_use block's JSON input.
165    /// `bytes_so_far` is the total accumulated size of the partial JSON
166    /// received for this block — useful for driving progress indicators
167    /// while the model streams large arguments (e.g. a full file body).
168    ToolInputDelta { bytes_so_far: usize },
169    /// Complete tool input assembled before the final Done event.
170    ToolInputComplete {
171        id: String,
172        name: String,
173        input: serde_json::Value,
174    },
175    /// Real-time usage update (output tokens so far).
176    Usage { output_tokens: u32 },
177    /// Final turn result — includes the full assembled content blocks.
178    Done(ChatResponse),
179    /// Fatal error during streaming.
180    Error(String),
181}
182
183#[async_trait]
184pub trait Provider: Send + Sync {
185    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
186
187    /// Best-effort context window size (in tokens) for the configured model.
188    /// `None` if the provider cannot infer it; callers should treat that as
189    /// "don't render a fullness bar".
190    fn context_size(&self) -> Option<u32> {
191        None
192    }
193
194    /// Get the model name for this provider.
195    fn model_name(&self) -> &str {
196        "unknown"
197    }
198
199    /// Stream a chat turn. Default impl wraps `chat` and emits one `Done` event,
200    /// so providers without native streaming still work (no incremental thinking).
201    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
202        let (tx, rx) = mpsc::channel(32);
203        let response = self.chat(request).await?;
204        let _ = tx.send(StreamEvent::FirstByte).await;
205        for block in &response.content {
206            match block {
207                ContentBlock::Thinking { thinking, .. } => {
208                    let _ = tx.send(StreamEvent::ThinkingDelta(thinking.clone())).await;
209                }
210                ContentBlock::Text { text } => {
211                    let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
212                }
213                ContentBlock::ToolUse { id, name, input } => {
214                    let _ = tx
215                        .send(StreamEvent::ToolUseStart {
216                            id: id.clone(),
217                            name: name.clone(),
218                        })
219                        .await;
220                    let _ = tx
221                        .send(StreamEvent::ToolInputComplete {
222                            id: id.clone(),
223                            name: name.clone(),
224                            input: input.clone(),
225                        })
226                        .await;
227                }
228                _ => {}
229            }
230        }
231        let _ = tx.send(StreamEvent::Done(response)).await;
232        Ok(rx)
233    }
234
235    /// Clone the provider into a boxed type.
236    fn clone_box(&self) -> Box<dyn Provider>;
237
238    /// Clone the provider into an Arc type (preferred for tool system).
239    fn clone_arc(&self) -> Arc<dyn Provider>;
240}
241
242impl Clone for Box<dyn Provider> {
243    fn clone(&self) -> Self {
244        self.clone_box()
245    }
246}
247
248// ============================================================================
249// Provider Factory
250// ============================================================================
251
/// Which backend implementation [`create_provider`] / [`create_provider_with_headers`]
/// should construct.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderType {
    /// Built as `anthropic::AnthropicProvider`.
    Anthropic,
    /// Built as `openai::OpenAIProvider`.
    OpenAI,
}
258
259/// Create a provider instance based on type and configuration.
260/// This is the recommended way to obtain a Provider instance.
261pub fn create_provider(
262    provider_type: ProviderType,
263    api_key: String,
264    model: String,
265    base_url: Option<String>,
266) -> Result<Box<dyn Provider>> {
267    create_provider_with_headers(provider_type, api_key, model, base_url, None)
268}
269
270/// Create a provider with extra headers support.
271pub fn create_provider_with_headers(
272    provider_type: ProviderType,
273    api_key: String,
274    model: String,
275    base_url: Option<String>,
276    extra_headers: Option<std::collections::HashMap<String, String>>,
277) -> Result<Box<dyn Provider>> {
278    match provider_type {
279        ProviderType::Anthropic => {
280            let provider = anthropic::AnthropicProvider::with_headers(
281                api_key,
282                model,
283                base_url.unwrap_or_else(|| ANTHROPIC_DEFAULT_BASE_URL.to_string()),
284                extra_headers,
285            );
286            Ok(Box::new(provider))
287        }
288        ProviderType::OpenAI => {
289            let provider = openai::OpenAIProvider::with_headers(
290                api_key,
291                model,
292                base_url.unwrap_or_else(|| OPENAI_DEFAULT_BASE_URL.to_string()),
293                extra_headers,
294            );
295            Ok(Box::new(provider))
296        }
297    }
298}
299
300/// Create a minimal provider from environment variables (for background tasks).
301/// Uses global config for API key and base URL, suitable for non-blocking background operations.
302pub fn create_minimal_provider(model: &str) -> Box<dyn Provider> {
303    // Try to load .env first (background tasks may not have .env loaded)
304    let _ = dotenvy::dotenv();
305
306    // Get API key from env (try multiple env vars)
307    let api_key = std::env::var("API_KEY")
308        .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
309        .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
310        .or_else(|_| std::env::var("OPENAI_API_KEY"))
311        .unwrap_or_default();
312
313    // Get base URL from env
314    let base_url = std::env::var("BASE_URL")
315        .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
316        .ok();
317
318    // Infer provider type from model name
319    let provider_type = infer_provider_type(model);
320
321    // Create provider (ignore errors, return a default)
322    create_provider_with_headers(provider_type, api_key, model.to_string(), base_url, None)
323        .unwrap_or_else(|_| {
324            // Fallback: create a dummy provider that returns empty
325            // This won't actually work, but prevents crashes
326            panic!("Failed to create minimal provider for background task: no API key configured")
327        })
328}
329
330/// Infer provider type from model name.
331/// Returns Anthropic for Claude models, OpenAI for GPT models.
332pub fn infer_provider_type(model: &str) -> ProviderType {
333    let lower = model.to_lowercase();
334    if lower.contains("claude")
335        || lower.contains("opus")
336        || lower.contains("sonnet")
337        || lower.contains("haiku")
338    {
339        ProviderType::Anthropic
340    } else if lower.contains("gpt")
341        || lower.contains("o1")
342        || lower.contains("o3")
343        || lower.contains("o4")
344    {
345        ProviderType::OpenAI
346    } else {
347        // Default to Anthropic for unknown models
348        ProviderType::Anthropic
349    }
350}