Skip to main content

matrixcode_core/providers/
mod.rs

1pub mod anthropic;
2pub mod openai;
3
4#[cfg(test)]
5mod tests;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use tokio::sync::mpsc;
11
12use crate::tools::ToolDefinition;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Message {
16    pub role: Role,
17    pub content: MessageContent,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23    System,
24    User,
25    Assistant,
26    Tool,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30#[serde(untagged)]
31pub enum MessageContent {
32    Text(String),
33    Blocks(Vec<ContentBlock>),
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37#[serde(tag = "type")]
38pub enum ContentBlock {
39    #[serde(rename = "text")]
40    Text { text: String },
41    #[serde(rename = "tool_use")]
42    ToolUse {
43        id: String,
44        name: String,
45        input: serde_json::Value,
46    },
47    #[serde(rename = "tool_result")]
48    ToolResult {
49        tool_use_id: String,
50        content: String,
51    },
52    /// Anthropic extended-thinking block. `signature` is required when sending
53    /// the block back to the API in a follow-up turn.
54    #[serde(rename = "thinking")]
55    Thinking {
56        thinking: String,
57        #[serde(skip_serializing_if = "Option::is_none")]
58        signature: Option<String>,
59    },
60    /// Server-side tool use (e.g., web_search_tool). The server executes
61    /// the tool and returns results directly without client intervention.
62    #[serde(rename = "server_tool_use")]
63    ServerToolUse {
64        id: String,
65        name: String,
66        input: serde_json::Value,
67    },
68    /// Result from a server-side web search tool.
69    #[serde(rename = "web_search_tool_result")]
70    WebSearchResult {
71        tool_use_id: String,
72        content: WebSearchContent,
73    },
74}
75
76/// Content returned by the server-side web search tool.
77#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
78pub struct WebSearchContent {
79    pub results: Vec<WebSearchResultItem>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
83pub struct WebSearchResultItem {
84    pub title: Option<String>,
85    pub url: String,
86    pub encrypted_content: Option<String>,
87    pub snippet: Option<String>,
88}
89
90/// Server-side tool definition. These tools are executed by the API provider
91/// rather than by the client. Currently only web_search_tool is supported.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ServerTool {
94    #[serde(rename = "type")]
95    pub tool_type: String,
96    pub name: String,
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub max_uses: Option<u32>,
99}
100
101impl ServerTool {
102    /// Create a new web search server tool.
103    pub fn web_search(max_uses: Option<u32>) -> Self {
104        Self {
105            tool_type: "web_search_tool".to_string(),
106            name: "web_search".to_string(),
107            max_uses,
108        }
109    }
110}
111
112#[derive(Debug, Clone)]
113pub struct ChatRequest {
114    pub messages: Vec<Message>,
115    pub tools: Vec<ToolDefinition>,
116    pub system: Option<String>,
117    pub think: bool,
118    /// Maximum output tokens for the response.
119    pub max_tokens: u32,
120    /// Server-side tools that are executed by the API provider.
121    pub server_tools: Vec<ServerTool>,
122    /// Enable prompt caching for Anthropic provider.
123    pub enable_caching: bool,
124}
125
126#[derive(Debug, Clone)]
127pub struct ChatResponse {
128    pub content: Vec<ContentBlock>,
129    pub stop_reason: StopReason,
130    pub usage: Usage,
131}
132
/// Token counts for a single provider turn.
///
/// By design, `input_tokens` already covers the cached and uncached prompt
/// portions combined. When a provider reports cache details separately,
/// they are recorded in the two `cache_*` fields, so callers can report
/// cache effectiveness. All counters are zero by default.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Usage {
    /// Prompt tokens (cached + uncached).
    pub input_tokens: u32,
    /// Generated tokens.
    pub output_tokens: u32,
    /// Tokens written to the prompt cache, when reported.
    pub cache_creation_input_tokens: u32,
    /// Tokens served from the prompt cache, when reported.
    pub cache_read_input_tokens: u32,
}
143
/// Why the model stopped generating in a turn.
///
/// Variant meanings presumably mirror the provider stop reasons they are
/// named after (e.g. Anthropic `end_turn` / `tool_use` / `max_tokens`).
/// Confirm the mapping in the provider modules. The enum is fieldless, so it
/// derives `Copy`, `Eq`, and `Hash`. It can then be matched by value and
/// used as a map/set key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StopReason {
    EndTurn,
    ToolUse,
    MaxTokens,
}
150
151/// Incremental events emitted during a streaming chat turn.
152#[derive(Debug, Clone)]
153pub enum StreamEvent {
154    /// First byte received from the server — agent should stop any waiting spinner.
155    FirstByte,
156    /// Extended-thinking text delta (Anthropic thinking block).
157    ThinkingDelta(String),
158    /// Visible assistant text delta.
159    TextDelta(String),
160    /// A new tool_use block started.
161    ToolUseStart { id: String, name: String },
162    /// Incremental progress for the current tool_use block's JSON input.
163    /// `bytes_so_far` is the total accumulated size of the partial JSON
164    /// received for this block — useful for driving progress indicators
165    /// while the model streams large arguments (e.g. a full file body).
166    ToolInputDelta { bytes_so_far: usize },
167    /// Real-time usage update (output tokens so far).
168    Usage { output_tokens: u32 },
169    /// Final turn result — includes the full assembled content blocks.
170    Done(ChatResponse),
171    /// Fatal error during streaming.
172    Error(String),
173}
174
175#[async_trait]
176pub trait Provider: Send + Sync {
177    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
178
179    /// Best-effort context window size (in tokens) for the configured model.
180    /// `None` if the provider cannot infer it; callers should treat that as
181    /// "don't render a fullness bar".
182    fn context_size(&self) -> Option<u32> {
183        None
184    }
185
186    /// Get the model name for this provider.
187    fn model_name(&self) -> &str {
188        "unknown"
189    }
190
191    /// Stream a chat turn. Default impl wraps `chat` and emits one `Done` event,
192    /// so providers without native streaming still work (no incremental thinking).
193    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
194        let (tx, rx) = mpsc::channel(32);
195        let response = self.chat(request).await?;
196        let _ = tx.send(StreamEvent::FirstByte).await;
197        for block in &response.content {
198            if let ContentBlock::Text { text } = block {
199                let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
200            }
201        }
202        let _ = tx.send(StreamEvent::Done(response)).await;
203        Ok(rx)
204    }
205
206    /// Clone the provider into a boxed type.
207    fn clone_box(&self) -> Box<dyn Provider>;
208}
209
210impl Clone for Box<dyn Provider> {
211    fn clone(&self) -> Self {
212        self.clone_box()
213    }
214}
215
216// ============================================================================
217// Provider Factory
218// ============================================================================
219
/// Which backend implementation [`create_provider`] should construct.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderType {
    /// `anthropic::AnthropicProvider`.
    Anthropic,
    /// `openai::OpenAIProvider`.
    OpenAI,
}
226
227/// Create a provider instance based on type and configuration.
228/// This is the recommended way to obtain a Provider instance.
229pub fn create_provider(
230    provider_type: ProviderType,
231    api_key: String,
232    model: String,
233    base_url: Option<String>,
234) -> Result<Box<dyn Provider>> {
235    create_provider_with_headers(provider_type, api_key, model, base_url, None)
236}
237
238/// Create a provider with extra headers support.
239pub fn create_provider_with_headers(
240    provider_type: ProviderType,
241    api_key: String,
242    model: String,
243    base_url: Option<String>,
244    extra_headers: Option<std::collections::HashMap<String, String>>,
245) -> Result<Box<dyn Provider>> {
246    match provider_type {
247        ProviderType::Anthropic => {
248            let provider = anthropic::AnthropicProvider::with_headers(
249                api_key,
250                model,
251                base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
252                extra_headers,
253            );
254            Ok(Box::new(provider))
255        }
256        ProviderType::OpenAI => {
257            let provider = openai::OpenAIProvider::with_headers(
258                api_key,
259                model,
260                base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
261                extra_headers,
262            );
263            Ok(Box::new(provider))
264        }
265    }
266}
267
268/// Create a minimal provider from environment variables (for background tasks).
269/// Uses global config for API key and base URL, suitable for non-blocking background operations.
270pub fn create_minimal_provider(model: &str) -> Box<dyn Provider> {
271    // Try to load .env first (background tasks may not have .env loaded)
272    let _ = dotenvy::dotenv();
273
274    // Get API key from env (try multiple env vars)
275    let api_key = std::env::var("API_KEY")
276        .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
277        .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
278        .or_else(|_| std::env::var("OPENAI_API_KEY"))
279        .unwrap_or_default();
280
281    // Get base URL from env
282    let base_url = std::env::var("BASE_URL")
283        .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
284        .ok();
285
286    // Infer provider type from model name
287    let provider_type = infer_provider_type(model);
288
289    // Create provider (ignore errors, return a default)
290    create_provider_with_headers(provider_type, api_key, model.to_string(), base_url, None)
291        .unwrap_or_else(|_| {
292            // Fallback: create a dummy provider that returns empty
293            // This won't actually work, but prevents crashes
294            panic!("Failed to create minimal provider for background task: no API key configured")
295        })
296}
297
298/// Infer provider type from model name.
299/// Returns Anthropic for Claude models, OpenAI for GPT models.
300pub fn infer_provider_type(model: &str) -> ProviderType {
301    let lower = model.to_lowercase();
302    if lower.contains("claude")
303        || lower.contains("opus")
304        || lower.contains("sonnet")
305        || lower.contains("haiku")
306    {
307        ProviderType::Anthropic
308    } else if lower.contains("gpt")
309        || lower.contains("o1")
310        || lower.contains("o3")
311        || lower.contains("o4")
312    {
313        ProviderType::OpenAI
314    } else {
315        // Default to Anthropic for unknown models
316        ProviderType::Anthropic
317    }
318}