Skip to main content

matrixcode_core/providers/
mod.rs

1pub mod anthropic;
2pub mod openai;
3
4#[cfg(test)]
5mod tests;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::mpsc;
12
13use crate::constants::{ANTHROPIC_DEFAULT_BASE_URL, OPENAI_DEFAULT_BASE_URL};
14use crate::tools::ToolDefinition;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Message {
18    pub role: Role,
19    pub content: MessageContent,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
23#[serde(rename_all = "lowercase")]
24pub enum Role {
25    System,
26    User,
27    Assistant,
28    Tool,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32#[serde(untagged)]
33pub enum MessageContent {
34    Text(String),
35    Blocks(Vec<ContentBlock>),
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39#[serde(tag = "type")]
40pub enum ContentBlock {
41    #[serde(rename = "text")]
42    Text { text: String },
43    #[serde(rename = "tool_use")]
44    ToolUse {
45        id: String,
46        name: String,
47        input: serde_json::Value,
48    },
49    #[serde(rename = "tool_result")]
50    ToolResult {
51        tool_use_id: String,
52        content: String,
53    },
54    /// Anthropic extended-thinking block. `signature` is required when sending
55    /// the block back to the API in a follow-up turn.
56    #[serde(rename = "thinking")]
57    Thinking {
58        thinking: String,
59        #[serde(skip_serializing_if = "Option::is_none")]
60        signature: Option<String>,
61    },
62    /// Server-side tool use (e.g., web_search_tool). The server executes
63    /// the tool and returns results directly without client intervention.
64    #[serde(rename = "server_tool_use")]
65    ServerToolUse {
66        id: String,
67        name: String,
68        input: serde_json::Value,
69    },
70    /// Result from a server-side web search tool.
71    #[serde(rename = "web_search_tool_result")]
72    WebSearchResult {
73        tool_use_id: String,
74        content: WebSearchContent,
75    },
76}
77
78/// Content returned by the server-side web search tool.
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80pub struct WebSearchContent {
81    pub results: Vec<WebSearchResultItem>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85pub struct WebSearchResultItem {
86    pub title: Option<String>,
87    pub url: String,
88    pub encrypted_content: Option<String>,
89    pub snippet: Option<String>,
90}
91
92/// Server-side tool definition. These tools are executed by the API provider
93/// rather than by the client. Currently only web_search_tool is supported.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ServerTool {
96    #[serde(rename = "type")]
97    pub tool_type: String,
98    pub name: String,
99    #[serde(skip_serializing_if = "Option::is_none")]
100    pub max_uses: Option<u32>,
101}
102
103impl ServerTool {
104    /// Create a new web search server tool.
105    pub fn web_search(max_uses: Option<u32>) -> Self {
106        Self {
107            tool_type: "web_search_tool".to_string(),
108            name: "web_search".to_string(),
109            max_uses,
110        }
111    }
112}
113
114#[derive(Debug, Clone)]
115pub struct ChatRequest {
116    pub messages: Vec<Message>,
117    pub tools: Vec<ToolDefinition>,
118    pub system: Option<String>,
119    pub think: bool,
120    /// Maximum output tokens for the response.
121    pub max_tokens: u32,
122    /// Server-side tools that are executed by the API provider.
123    pub server_tools: Vec<ServerTool>,
124    /// Enable prompt caching for Anthropic provider.
125    pub enable_caching: bool,
126}
127
128#[derive(Debug, Clone)]
129pub struct ChatResponse {
130    pub content: Vec<ContentBlock>,
131    pub stop_reason: StopReason,
132    pub usage: Usage,
133}
134
/// Token counts for a single provider turn.
///
/// `input_tokens` is the combined figure (cached and uncached together). The
/// two cache fields hold the breakdown for providers that report it, so
/// callers can show how effective caching was. All counters default to zero.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Usage {
    /// Total input tokens, cached and uncached combined.
    pub input_tokens: u32,
    /// Output tokens.
    pub output_tokens: u32,
    /// Cache-creation input tokens, when the provider reports them.
    pub cache_creation_input_tokens: u32,
    /// Cache-read input tokens, when the provider reports them.
    pub cache_read_input_tokens: u32,
}
145
/// Reason a chat turn ended.
#[derive(Debug, Clone, PartialEq)]
pub enum StopReason {
    /// The turn ended normally.
    EndTurn,
    /// The turn stopped for a tool call.
    ToolUse,
    /// The turn stopped at the token limit.
    MaxTokens,
}
152
153/// Incremental events emitted during a streaming chat turn.
154#[derive(Debug, Clone)]
155pub enum StreamEvent {
156    /// First byte received from the server — agent should stop any waiting spinner.
157    FirstByte,
158    /// Extended-thinking text delta (Anthropic thinking block).
159    ThinkingDelta(String),
160    /// Visible assistant text delta.
161    TextDelta(String),
162    /// A new tool_use block started.
163    ToolUseStart { id: String, name: String },
164    /// Incremental progress for the current tool_use block's JSON input.
165    /// `bytes_so_far` is the total accumulated size of the partial JSON
166    /// received for this block — useful for driving progress indicators
167    /// while the model streams large arguments (e.g. a full file body).
168    ToolInputDelta { bytes_so_far: usize },
169    /// Real-time usage update (output tokens so far).
170    Usage { output_tokens: u32 },
171    /// Final turn result — includes the full assembled content blocks.
172    Done(ChatResponse),
173    /// Fatal error during streaming.
174    Error(String),
175}
176
177#[async_trait]
178pub trait Provider: Send + Sync {
179    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
180
181    /// Best-effort context window size (in tokens) for the configured model.
182    /// `None` if the provider cannot infer it; callers should treat that as
183    /// "don't render a fullness bar".
184    fn context_size(&self) -> Option<u32> {
185        None
186    }
187
188    /// Get the model name for this provider.
189    fn model_name(&self) -> &str {
190        "unknown"
191    }
192
193    /// Stream a chat turn. Default impl wraps `chat` and emits one `Done` event,
194    /// so providers without native streaming still work (no incremental thinking).
195    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
196        let (tx, rx) = mpsc::channel(32);
197        let response = self.chat(request).await?;
198        let _ = tx.send(StreamEvent::FirstByte).await;
199        for block in &response.content {
200            match block {
201                ContentBlock::Thinking { thinking, .. } => {
202                    let _ = tx.send(StreamEvent::ThinkingDelta(thinking.clone())).await;
203                }
204                ContentBlock::Text { text } => {
205                    let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
206                }
207                _ => {}
208            }
209        }
210        let _ = tx.send(StreamEvent::Done(response)).await;
211        Ok(rx)
212    }
213
214    /// Clone the provider into a boxed type.
215    fn clone_box(&self) -> Box<dyn Provider>;
216
217    /// Clone the provider into an Arc type (preferred for tool system).
218    fn clone_arc(&self) -> Arc<dyn Provider>;
219}
220
221impl Clone for Box<dyn Provider> {
222    fn clone(&self) -> Self {
223        self.clone_box()
224    }
225}
226
227// ============================================================================
228// Provider Factory
229// ============================================================================
230
/// Which backend the provider factory should build.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderType {
    /// Build `anthropic::AnthropicProvider`.
    Anthropic,
    /// Build `openai::OpenAIProvider`.
    OpenAI,
}
237
238/// Create a provider instance based on type and configuration.
239/// This is the recommended way to obtain a Provider instance.
240pub fn create_provider(
241    provider_type: ProviderType,
242    api_key: String,
243    model: String,
244    base_url: Option<String>,
245) -> Result<Box<dyn Provider>> {
246    create_provider_with_headers(provider_type, api_key, model, base_url, None)
247}
248
249/// Create a provider with extra headers support.
250pub fn create_provider_with_headers(
251    provider_type: ProviderType,
252    api_key: String,
253    model: String,
254    base_url: Option<String>,
255    extra_headers: Option<std::collections::HashMap<String, String>>,
256) -> Result<Box<dyn Provider>> {
257    match provider_type {
258        ProviderType::Anthropic => {
259            let provider = anthropic::AnthropicProvider::with_headers(
260                api_key,
261                model,
262                base_url.unwrap_or_else(|| ANTHROPIC_DEFAULT_BASE_URL.to_string()),
263                extra_headers,
264            );
265            Ok(Box::new(provider))
266        }
267        ProviderType::OpenAI => {
268            let provider = openai::OpenAIProvider::with_headers(
269                api_key,
270                model,
271                base_url.unwrap_or_else(|| OPENAI_DEFAULT_BASE_URL.to_string()),
272                extra_headers,
273            );
274            Ok(Box::new(provider))
275        }
276    }
277}
278
279/// Create a minimal provider from environment variables (for background tasks).
280/// Uses global config for API key and base URL, suitable for non-blocking background operations.
281pub fn create_minimal_provider(model: &str) -> Box<dyn Provider> {
282    // Try to load .env first (background tasks may not have .env loaded)
283    let _ = dotenvy::dotenv();
284
285    // Get API key from env (try multiple env vars)
286    let api_key = std::env::var("API_KEY")
287        .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
288        .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
289        .or_else(|_| std::env::var("OPENAI_API_KEY"))
290        .unwrap_or_default();
291
292    // Get base URL from env
293    let base_url = std::env::var("BASE_URL")
294        .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
295        .ok();
296
297    // Infer provider type from model name
298    let provider_type = infer_provider_type(model);
299
300    // Create provider (ignore errors, return a default)
301    create_provider_with_headers(provider_type, api_key, model.to_string(), base_url, None)
302        .unwrap_or_else(|_| {
303            // Fallback: create a dummy provider that returns empty
304            // This won't actually work, but prevents crashes
305            panic!("Failed to create minimal provider for background task: no API key configured")
306        })
307}
308
309/// Infer provider type from model name.
310/// Returns Anthropic for Claude models, OpenAI for GPT models.
311pub fn infer_provider_type(model: &str) -> ProviderType {
312    let lower = model.to_lowercase();
313    if lower.contains("claude")
314        || lower.contains("opus")
315        || lower.contains("sonnet")
316        || lower.contains("haiku")
317    {
318        ProviderType::Anthropic
319    } else if lower.contains("gpt")
320        || lower.contains("o1")
321        || lower.contains("o3")
322        || lower.contains("o4")
323    {
324        ProviderType::OpenAI
325    } else {
326        // Default to Anthropic for unknown models
327        ProviderType::Anthropic
328    }
329}