Skip to main content

matrixcode_core/providers/
mod.rs

1pub mod anthropic;
2pub mod openai;
3
4#[cfg(test)]
5mod tests;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::mpsc;
12
13use crate::constants::{ANTHROPIC_DEFAULT_BASE_URL, OPENAI_DEFAULT_BASE_URL};
14use crate::tools::ToolDefinition;
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17pub struct Message {
18    pub role: Role,
19    pub content: MessageContent,
20}
21
22#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
23#[serde(rename_all = "lowercase")]
24pub enum Role {
25    System,
26    User,
27    Assistant,
28    Tool,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32#[serde(untagged)]
33pub enum MessageContent {
34    Text(String),
35    Blocks(Vec<ContentBlock>),
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(tag = "type")]
40pub enum ContentBlock {
41    #[serde(rename = "text")]
42    Text { text: String },
43    #[serde(rename = "tool_use")]
44    ToolUse {
45        id: String,
46        name: String,
47        input: serde_json::Value,
48    },
49    #[serde(rename = "tool_result")]
50    ToolResult {
51        tool_use_id: String,
52        content: String,
53    },
54    /// Anthropic extended-thinking block. `signature` is required when sending
55    /// the block back to the API in a follow-up turn.
56    #[serde(rename = "thinking")]
57    Thinking {
58        thinking: String,
59        #[serde(skip_serializing_if = "Option::is_none")]
60        signature: Option<String>,
61    },
62    /// Server-side tool use (e.g., web_search_tool). The server executes
63    /// the tool and returns results directly without client intervention.
64    #[serde(rename = "server_tool_use")]
65    ServerToolUse {
66        id: String,
67        name: String,
68        input: serde_json::Value,
69    },
70    #[serde(rename = "server_tool_result")]
71    ServerToolResult {
72        tool_use_id: String,
73        content: String,
74    },
75    /// Result from a server-side web search tool.
76    #[serde(rename = "web_search_tool_result")]
77    WebSearchResult {
78        tool_use_id: String,
79        content: WebSearchContent,
80    },
81}
82
83// Manual PartialEq implementation for ContentBlock (serde_json::Value doesn't derive PartialEq)
84impl PartialEq for ContentBlock {
85    fn eq(&self, other: &Self) -> bool {
86        match (self, other) {
87            (ContentBlock::Text { text: a }, ContentBlock::Text { text: b }) => a == b,
88            (ContentBlock::ToolUse { id: a, name: b, .. }, ContentBlock::ToolUse { id: c, name: d, .. }) => {
89                a == c && b == d
90            },
91            (ContentBlock::ToolResult { tool_use_id: a, content: b }, ContentBlock::ToolResult { tool_use_id: c, content: d }) => {
92                a == c && b == d
93            },
94            (ContentBlock::Thinking { thinking: a, signature: b }, ContentBlock::Thinking { thinking: c, signature: d }) => {
95                a == c && b == d
96            },
97            (ContentBlock::ServerToolUse { id: a, name: b, .. }, ContentBlock::ServerToolUse { id: c, name: d, .. }) => {
98                a == c && b == d
99            },
100            (ContentBlock::ServerToolResult { tool_use_id: a, content: b }, ContentBlock::ServerToolResult { tool_use_id: c, content: d }) => {
101                a == c && b == d
102            },
103            (ContentBlock::WebSearchResult { tool_use_id: a, .. }, ContentBlock::WebSearchResult { tool_use_id: b, .. }) => {
104                a == b  // Compare by tool_use_id only for web search results
105            },
106            _ => false,
107        }
108    }
109}
110
111/// Content returned by the server-side web search tool.
112#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
113pub struct WebSearchContent {
114    pub results: Vec<WebSearchResultItem>,
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
118pub struct WebSearchResultItem {
119    pub title: Option<String>,
120    pub url: String,
121    pub encrypted_content: Option<String>,
122    pub snippet: Option<String>,
123}
124
125/// Server-side tool definition. These tools are executed by the API provider
126/// rather than by the client. Currently only web_search_tool is supported.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ServerTool {
129    #[serde(rename = "type")]
130    pub tool_type: String,
131    pub name: String,
132    #[serde(skip_serializing_if = "Option::is_none")]
133    pub max_uses: Option<u32>,
134}
135
136impl ServerTool {
137    /// Create a new web search server tool.
138    pub fn web_search(max_uses: Option<u32>) -> Self {
139        Self {
140            tool_type: "web_search_tool".to_string(),
141            name: "web_search".to_string(),
142            max_uses,
143        }
144    }
145}
146
147#[derive(Debug, Clone)]
148pub struct ChatRequest {
149    pub messages: Vec<Message>,
150    pub tools: Vec<ToolDefinition>,
151    pub system: Option<String>,
152    pub think: bool,
153    /// Maximum output tokens for the response.
154    pub max_tokens: u32,
155    /// Server-side tools that are executed by the API provider.
156    pub server_tools: Vec<ServerTool>,
157    /// Enable prompt caching for Anthropic provider.
158    pub enable_caching: bool,
159}
160
161#[derive(Debug, Clone)]
162pub struct ChatResponse {
163    pub content: Vec<ContentBlock>,
164    pub stop_reason: StopReason,
165    pub usage: Usage,
166}
167
/// Token counts for a single provider turn.
///
/// `input_tokens` is meant to be the combined cached + uncached input. When a
/// provider reports cache activity separately, it is recorded in the two
/// `cache_*` fields so callers can report cache effectiveness.
///
/// NOTE(review): Anthropic's raw `usage.input_tokens` excludes cache reads and
/// cache writes. Confirm that the Anthropic adapter sums them before filling
/// `input_tokens`; otherwise the invariant above does not hold.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Usage {
    /// Input (prompt) tokens for the turn.
    pub input_tokens: u32,
    /// Tokens generated by the model.
    pub output_tokens: u32,
    /// Input tokens written to the prompt cache (defaults to 0).
    pub cache_creation_input_tokens: u32,
    /// Input tokens served from the prompt cache (defaults to 0).
    pub cache_read_input_tokens: u32,
}
178
/// Why the model stopped generating in a turn.
#[derive(Clone, Debug, PartialEq)]
pub enum StopReason {
    /// The model ended its turn.
    EndTurn,
    /// The model stopped to request tool calls.
    ToolUse,
    /// Output hit the `max_tokens` limit.
    MaxTokens,
}
185
186/// Incremental events emitted during a streaming chat turn.
187#[derive(Debug, Clone)]
188pub enum StreamEvent {
189    /// First byte received from the server — agent should stop any waiting spinner.
190    FirstByte,
191    /// Extended-thinking text delta (Anthropic thinking block).
192    ThinkingDelta(String),
193    /// Visible assistant text delta.
194    TextDelta(String),
195    /// A new tool_use block started.
196    ToolUseStart { id: String, name: String },
197    /// Incremental progress for the current tool_use block's JSON input.
198    /// `bytes_so_far` is the total accumulated size of the partial JSON
199    /// received for this block — useful for driving progress indicators
200    /// while the model streams large arguments (e.g. a full file body).
201    ToolInputDelta { bytes_so_far: usize },
202    /// Complete tool input assembled before the final Done event.
203    ToolInputComplete {
204        id: String,
205        name: String,
206        input: serde_json::Value,
207    },
208    /// Real-time usage update (output tokens so far).
209    Usage { output_tokens: u32 },
210    /// Final turn result — includes the full assembled content blocks.
211    Done(ChatResponse),
212    /// Fatal error during streaming.
213    Error(String),
214}
215
216#[async_trait]
217pub trait Provider: Send + Sync {
218    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
219
220    /// Best-effort context window size (in tokens) for the configured model.
221    /// `None` if the provider cannot infer it; callers should treat that as
222    /// "don't render a fullness bar".
223    fn context_size(&self) -> Option<u32> {
224        None
225    }
226
227    /// Get the model name for this provider.
228    fn model_name(&self) -> &str {
229        "unknown"
230    }
231
232    /// Stream a chat turn. Default impl wraps `chat` and emits one `Done` event,
233    /// so providers without native streaming still work (no incremental thinking).
234    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
235        let (tx, rx) = mpsc::channel(32);
236        let response = self.chat(request).await?;
237        let _ = tx.send(StreamEvent::FirstByte).await;
238        for block in &response.content {
239            match block {
240                ContentBlock::Thinking { thinking, .. } => {
241                    let _ = tx.send(StreamEvent::ThinkingDelta(thinking.clone())).await;
242                }
243                ContentBlock::Text { text } => {
244                    let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
245                }
246                ContentBlock::ToolUse { id, name, input } => {
247                    let _ = tx
248                        .send(StreamEvent::ToolUseStart {
249                            id: id.clone(),
250                            name: name.clone(),
251                        })
252                        .await;
253                    let _ = tx
254                        .send(StreamEvent::ToolInputComplete {
255                            id: id.clone(),
256                            name: name.clone(),
257                            input: input.clone(),
258                        })
259                        .await;
260                }
261                _ => {}
262            }
263        }
264        let _ = tx.send(StreamEvent::Done(response)).await;
265        Ok(rx)
266    }
267
268    /// Clone the provider into a boxed type.
269    fn clone_box(&self) -> Box<dyn Provider>;
270
271    /// Clone the provider into an Arc type (preferred for tool system).
272    fn clone_arc(&self) -> Arc<dyn Provider>;
273}
274
275impl Clone for Box<dyn Provider> {
276    fn clone(&self) -> Self {
277        self.clone_box()
278    }
279}
280
281// ============================================================================
282// Provider Factory
283// ============================================================================
284
/// Selects which backend the factory functions instantiate.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProviderType {
    /// Backed by `anthropic::AnthropicProvider`.
    Anthropic,
    /// Backed by `openai::OpenAIProvider`.
    OpenAI,
}
291
292/// Create a provider instance based on type and configuration.
293/// This is the recommended way to obtain a Provider instance.
294pub fn create_provider(
295    provider_type: ProviderType,
296    api_key: String,
297    model: String,
298    base_url: Option<String>,
299) -> Result<Box<dyn Provider>> {
300    create_provider_with_headers(provider_type, api_key, model, base_url, None)
301}
302
303/// Create a provider with extra headers support.
304pub fn create_provider_with_headers(
305    provider_type: ProviderType,
306    api_key: String,
307    model: String,
308    base_url: Option<String>,
309    extra_headers: Option<std::collections::HashMap<String, String>>,
310) -> Result<Box<dyn Provider>> {
311    match provider_type {
312        ProviderType::Anthropic => {
313            let provider = anthropic::AnthropicProvider::with_headers(
314                api_key,
315                model,
316                base_url.unwrap_or_else(|| ANTHROPIC_DEFAULT_BASE_URL.to_string()),
317                extra_headers,
318            );
319            Ok(Box::new(provider))
320        }
321        ProviderType::OpenAI => {
322            let provider = openai::OpenAIProvider::with_headers(
323                api_key,
324                model,
325                base_url.unwrap_or_else(|| OPENAI_DEFAULT_BASE_URL.to_string()),
326                extra_headers,
327            );
328            Ok(Box::new(provider))
329        }
330    }
331}
332
333/// Create a minimal provider from environment variables (for background tasks).
334/// Uses global config for API key and base URL, suitable for non-blocking background operations.
335pub fn create_minimal_provider(model: &str) -> Box<dyn Provider> {
336    // Try to load .env first (background tasks may not have .env loaded)
337    let _ = dotenvy::dotenv();
338
339    // Get API key from env (try multiple env vars)
340    let api_key = std::env::var("API_KEY")
341        .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
342        .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
343        .or_else(|_| std::env::var("OPENAI_API_KEY"))
344        .unwrap_or_default();
345
346    // Get base URL from env
347    let base_url = std::env::var("BASE_URL")
348        .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
349        .ok();
350
351    // Infer provider type from model name
352    let provider_type = infer_provider_type(model);
353
354    // Create provider (ignore errors, return a default)
355    create_provider_with_headers(provider_type, api_key, model.to_string(), base_url, None)
356        .unwrap_or_else(|_| {
357            // Fallback: create a dummy provider that returns empty
358            // This won't actually work, but prevents crashes
359            panic!("Failed to create minimal provider for background task: no API key configured")
360        })
361}
362
363/// Infer provider type from model name.
364/// Returns Anthropic for Claude models, OpenAI for GPT models.
365pub fn infer_provider_type(model: &str) -> ProviderType {
366    let lower = model.to_lowercase();
367    if lower.contains("claude")
368        || lower.contains("opus")
369        || lower.contains("sonnet")
370        || lower.contains("haiku")
371    {
372        ProviderType::Anthropic
373    } else if lower.contains("gpt")
374        || lower.contains("o1")
375        || lower.contains("o3")
376        || lower.contains("o4")
377    {
378        ProviderType::OpenAI
379    } else {
380        // Default to Anthropic for unknown models
381        ProviderType::Anthropic
382    }
383}