Skip to main content

matrixcode_core/providers/
mod.rs

1pub mod anthropic;
2pub mod openai;
3
4#[cfg(test)]
5mod tests;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use tokio::sync::mpsc;
12
13use crate::constants::{ANTHROPIC_DEFAULT_BASE_URL, OPENAI_DEFAULT_BASE_URL};
14use crate::tools::ToolDefinition;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Message {
18    pub role: Role,
19    pub content: MessageContent,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
23#[serde(rename_all = "lowercase")]
24pub enum Role {
25    System,
26    User,
27    Assistant,
28    Tool,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
32#[serde(untagged)]
33pub enum MessageContent {
34    Text(String),
35    Blocks(Vec<ContentBlock>),
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
39#[serde(tag = "type")]
40pub enum ContentBlock {
41    #[serde(rename = "text")]
42    Text { text: String },
43    #[serde(rename = "tool_use")]
44    ToolUse {
45        id: String,
46        name: String,
47        input: serde_json::Value,
48    },
49    #[serde(rename = "tool_result")]
50    ToolResult {
51        tool_use_id: String,
52        content: String,
53    },
54    /// Anthropic extended-thinking block. `signature` is required when sending
55    /// the block back to the API in a follow-up turn.
56    #[serde(rename = "thinking")]
57    Thinking {
58        thinking: String,
59        #[serde(skip_serializing_if = "Option::is_none")]
60        signature: Option<String>,
61    },
62    /// Server-side tool use (e.g., web_search_tool). The server executes
63    /// the tool and returns results directly without client intervention.
64    #[serde(rename = "server_tool_use")]
65    ServerToolUse {
66        id: String,
67        name: String,
68        input: serde_json::Value,
69    },
70    /// Result from a server-side web search tool.
71    #[serde(rename = "web_search_tool_result")]
72    WebSearchResult {
73        tool_use_id: String,
74        content: WebSearchContent,
75    },
76}
77
78/// Content returned by the server-side web search tool.
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80pub struct WebSearchContent {
81    pub results: Vec<WebSearchResultItem>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85pub struct WebSearchResultItem {
86    pub title: Option<String>,
87    pub url: String,
88    pub encrypted_content: Option<String>,
89    pub snippet: Option<String>,
90}
91
92/// Server-side tool definition. These tools are executed by the API provider
93/// rather than by the client. Currently only web_search_tool is supported.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ServerTool {
96    #[serde(rename = "type")]
97    pub tool_type: String,
98    pub name: String,
99    #[serde(skip_serializing_if = "Option::is_none")]
100    pub max_uses: Option<u32>,
101}
102
103impl ServerTool {
104    /// Create a new web search server tool.
105    pub fn web_search(max_uses: Option<u32>) -> Self {
106        Self {
107            tool_type: "web_search_tool".to_string(),
108            name: "web_search".to_string(),
109            max_uses,
110        }
111    }
112}
113
114#[derive(Debug, Clone)]
115pub struct ChatRequest {
116    pub messages: Vec<Message>,
117    pub tools: Vec<ToolDefinition>,
118    pub system: Option<String>,
119    pub think: bool,
120    /// Maximum output tokens for the response.
121    pub max_tokens: u32,
122    /// Server-side tools that are executed by the API provider.
123    pub server_tools: Vec<ServerTool>,
124    /// Enable prompt caching for Anthropic provider.
125    pub enable_caching: bool,
126}
127
128#[derive(Debug, Clone)]
129pub struct ChatResponse {
130    pub content: Vec<ContentBlock>,
131    pub stop_reason: StopReason,
132    pub usage: Usage,
133}
134
/// Token counts for a single provider turn.
///
/// `input_tokens` is the combined figure (cached and uncached portions
/// together). When a provider reports cache details separately they are
/// recorded in the two `cache_*` fields so callers can report cache
/// effectiveness. All counters default to zero.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Usage {
    /// Input tokens for the turn.
    pub input_tokens: u32,
    /// Output tokens for the turn.
    pub output_tokens: u32,
    /// Cache-creation input tokens, when the provider reports them.
    pub cache_creation_input_tokens: u32,
    /// Cache-read input tokens, when the provider reports them.
    pub cache_read_input_tokens: u32,
}
145
/// Why a chat turn ended.
#[derive(Clone, Debug, PartialEq)]
pub enum StopReason {
    /// The turn finished normally.
    EndTurn,
    /// The turn stopped for tool use.
    ToolUse,
    /// The output token limit was reached.
    MaxTokens,
}
152
153/// Incremental events emitted during a streaming chat turn.
154#[derive(Debug, Clone)]
155pub enum StreamEvent {
156    /// First byte received from the server — agent should stop any waiting spinner.
157    FirstByte,
158    /// Extended-thinking text delta (Anthropic thinking block).
159    ThinkingDelta(String),
160    /// Visible assistant text delta.
161    TextDelta(String),
162    /// A new tool_use block started.
163    ToolUseStart { id: String, name: String },
164    /// Incremental progress for the current tool_use block's JSON input.
165    /// `bytes_so_far` is the total accumulated size of the partial JSON
166    /// received for this block — useful for driving progress indicators
167    /// while the model streams large arguments (e.g. a full file body).
168    ToolInputDelta { bytes_so_far: usize },
169    /// Real-time usage update (output tokens so far).
170    Usage { output_tokens: u32 },
171    /// Final turn result — includes the full assembled content blocks.
172    Done(ChatResponse),
173    /// Fatal error during streaming.
174    Error(String),
175}
176
177#[async_trait]
178pub trait Provider: Send + Sync {
179    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
180
181    /// Best-effort context window size (in tokens) for the configured model.
182    /// `None` if the provider cannot infer it; callers should treat that as
183    /// "don't render a fullness bar".
184    fn context_size(&self) -> Option<u32> {
185        None
186    }
187
188    /// Get the model name for this provider.
189    fn model_name(&self) -> &str {
190        "unknown"
191    }
192
193    /// Stream a chat turn. Default impl wraps `chat` and emits one `Done` event,
194    /// so providers without native streaming still work (no incremental thinking).
195    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
196        let (tx, rx) = mpsc::channel(32);
197        let response = self.chat(request).await?;
198        let _ = tx.send(StreamEvent::FirstByte).await;
199        for block in &response.content {
200            if let ContentBlock::Text { text } = block {
201                let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
202            }
203        }
204        let _ = tx.send(StreamEvent::Done(response)).await;
205        Ok(rx)
206    }
207
208    /// Clone the provider into a boxed type.
209    fn clone_box(&self) -> Box<dyn Provider>;
210
211    /// Clone the provider into an Arc type (preferred for tool system).
212    fn clone_arc(&self) -> Arc<dyn Provider>;
213}
214
215impl Clone for Box<dyn Provider> {
216    fn clone(&self) -> Self {
217        self.clone_box()
218    }
219}
220
// ============================================================================
// Provider Factory
// ============================================================================
224
/// Selects which backend the provider factory functions construct.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderType {
    /// Build an `anthropic::AnthropicProvider`.
    Anthropic,
    /// Build an `openai::OpenAIProvider`.
    OpenAI,
}
231
232/// Create a provider instance based on type and configuration.
233/// This is the recommended way to obtain a Provider instance.
234pub fn create_provider(
235    provider_type: ProviderType,
236    api_key: String,
237    model: String,
238    base_url: Option<String>,
239) -> Result<Box<dyn Provider>> {
240    create_provider_with_headers(provider_type, api_key, model, base_url, None)
241}
242
243/// Create a provider with extra headers support.
244pub fn create_provider_with_headers(
245    provider_type: ProviderType,
246    api_key: String,
247    model: String,
248    base_url: Option<String>,
249    extra_headers: Option<std::collections::HashMap<String, String>>,
250) -> Result<Box<dyn Provider>> {
251    match provider_type {
252        ProviderType::Anthropic => {
253            let provider = anthropic::AnthropicProvider::with_headers(
254                api_key,
255                model,
256                base_url.unwrap_or_else(|| ANTHROPIC_DEFAULT_BASE_URL.to_string()),
257                extra_headers,
258            );
259            Ok(Box::new(provider))
260        }
261        ProviderType::OpenAI => {
262            let provider = openai::OpenAIProvider::with_headers(
263                api_key,
264                model,
265                base_url.unwrap_or_else(|| OPENAI_DEFAULT_BASE_URL.to_string()),
266                extra_headers,
267            );
268            Ok(Box::new(provider))
269        }
270    }
271}
272
273/// Create a minimal provider from environment variables (for background tasks).
274/// Uses global config for API key and base URL, suitable for non-blocking background operations.
275pub fn create_minimal_provider(model: &str) -> Box<dyn Provider> {
276    // Try to load .env first (background tasks may not have .env loaded)
277    let _ = dotenvy::dotenv();
278
279    // Get API key from env (try multiple env vars)
280    let api_key = std::env::var("API_KEY")
281        .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
282        .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
283        .or_else(|_| std::env::var("OPENAI_API_KEY"))
284        .unwrap_or_default();
285
286    // Get base URL from env
287    let base_url = std::env::var("BASE_URL")
288        .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
289        .ok();
290
291    // Infer provider type from model name
292    let provider_type = infer_provider_type(model);
293
294    // Create provider (ignore errors, return a default)
295    create_provider_with_headers(provider_type, api_key, model.to_string(), base_url, None)
296        .unwrap_or_else(|_| {
297            // Fallback: create a dummy provider that returns empty
298            // This won't actually work, but prevents crashes
299            panic!("Failed to create minimal provider for background task: no API key configured")
300        })
301}
302
303/// Infer provider type from model name.
304/// Returns Anthropic for Claude models, OpenAI for GPT models.
305pub fn infer_provider_type(model: &str) -> ProviderType {
306    let lower = model.to_lowercase();
307    if lower.contains("claude")
308        || lower.contains("opus")
309        || lower.contains("sonnet")
310        || lower.contains("haiku")
311    {
312        ProviderType::Anthropic
313    } else if lower.contains("gpt")
314        || lower.contains("o1")
315        || lower.contains("o3")
316        || lower.contains("o4")
317    {
318        ProviderType::OpenAI
319    } else {
320        // Default to Anthropic for unknown models
321        ProviderType::Anthropic
322    }
323}