Skip to main content

matrixcode_core/providers/
mod.rs

1pub mod anthropic;
2pub mod openai;
3
4#[cfg(test)]
5mod tests;
6
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use tokio::sync::mpsc;
11
12use crate::tools::ToolDefinition;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Message {
16    pub role: Role,
17    pub content: MessageContent,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23    System,
24    User,
25    Assistant,
26    Tool,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30#[serde(untagged)]
31pub enum MessageContent {
32    Text(String),
33    Blocks(Vec<ContentBlock>),
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37#[serde(tag = "type")]
38pub enum ContentBlock {
39    #[serde(rename = "text")]
40    Text { text: String },
41    #[serde(rename = "tool_use")]
42    ToolUse {
43        id: String,
44        name: String,
45        input: serde_json::Value,
46    },
47    #[serde(rename = "tool_result")]
48    ToolResult {
49        tool_use_id: String,
50        content: String,
51    },
52    /// Anthropic extended-thinking block. `signature` is required when sending
53    /// the block back to the API in a follow-up turn.
54    #[serde(rename = "thinking")]
55    Thinking {
56        thinking: String,
57        #[serde(skip_serializing_if = "Option::is_none")]
58        signature: Option<String>,
59    },
60    /// Server-side tool use (e.g., web_search_tool). The server executes
61    /// the tool and returns results directly without client intervention.
62    #[serde(rename = "server_tool_use")]
63    ServerToolUse {
64        id: String,
65        name: String,
66        input: serde_json::Value,
67    },
68    /// Result from a server-side web search tool.
69    #[serde(rename = "web_search_tool_result")]
70    WebSearchResult {
71        tool_use_id: String,
72        content: WebSearchContent,
73    },
74}
75
76/// Content returned by the server-side web search tool.
77#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
78pub struct WebSearchContent {
79    pub results: Vec<WebSearchResultItem>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
83pub struct WebSearchResultItem {
84    pub title: Option<String>,
85    pub url: String,
86    pub encrypted_content: Option<String>,
87    pub snippet: Option<String>,
88}
89
90/// Server-side tool definition. These tools are executed by the API provider
91/// rather than by the client. Currently only web_search_tool is supported.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ServerTool {
94    #[serde(rename = "type")]
95    pub tool_type: String,
96    pub name: String,
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub max_uses: Option<u32>,
99}
100
101impl ServerTool {
102    /// Create a new web search server tool.
103    pub fn web_search(max_uses: Option<u32>) -> Self {
104        Self {
105            tool_type: "web_search_tool".to_string(),
106            name: "web_search".to_string(),
107            max_uses,
108        }
109    }
110}
111
112#[derive(Debug, Clone)]
113pub struct ChatRequest {
114    pub messages: Vec<Message>,
115    pub tools: Vec<ToolDefinition>,
116    pub system: Option<String>,
117    pub think: bool,
118    /// Maximum output tokens for the response.
119    pub max_tokens: u32,
120    /// Server-side tools that are executed by the API provider.
121    pub server_tools: Vec<ServerTool>,
122    /// Enable prompt caching for Anthropic provider.
123    pub enable_caching: bool,
124}
125
126#[derive(Debug, Clone)]
127pub struct ChatResponse {
128    pub content: Vec<ContentBlock>,
129    pub stop_reason: StopReason,
130    pub usage: Usage,
131}
132
/// Token counts for a single provider turn.
///
/// By contract, `input_tokens` is the combined size of the cached and uncached
/// prompt. The two cache counters break out the cached portion when a
/// provider reports it, so callers can show how effective caching was.
///
/// NOTE(review): Anthropic's raw `usage.input_tokens` excludes cache-creation
/// and cache-read tokens. Confirm the Anthropic provider adds them together
/// before filling `input_tokens`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Usage {
    /// Prompt tokens, cached and uncached combined.
    pub input_tokens: u32,
    /// Tokens generated in this turn.
    pub output_tokens: u32,
    /// Prompt tokens written to the provider's prompt cache.
    pub cache_creation_input_tokens: u32,
    /// Prompt tokens served from the provider's prompt cache.
    pub cache_read_input_tokens: u32,
}
143
/// Why the model stopped generating in a turn.
///
/// Each provider maps its native stop reasons onto these variants.
#[derive(Clone, Debug, PartialEq)]
pub enum StopReason {
    /// The turn ended normally.
    EndTurn,
    /// The model stopped to request tool use.
    ToolUse,
    /// The output-token budget was exhausted.
    MaxTokens,
}
150
151/// Incremental events emitted during a streaming chat turn.
152#[derive(Debug, Clone)]
153pub enum StreamEvent {
154    /// First byte received from the server — agent should stop any waiting spinner.
155    FirstByte,
156    /// Extended-thinking text delta (Anthropic thinking block).
157    ThinkingDelta(String),
158    /// Visible assistant text delta.
159    TextDelta(String),
160    /// A new tool_use block started.
161    ToolUseStart { id: String, name: String },
162    /// Incremental progress for the current tool_use block's JSON input.
163    /// `bytes_so_far` is the total accumulated size of the partial JSON
164    /// received for this block — useful for driving progress indicators
165    /// while the model streams large arguments (e.g. a full file body).
166    ToolInputDelta { bytes_so_far: usize },
167    /// Real-time usage update (output tokens so far).
168    Usage { output_tokens: u32 },
169    /// Final turn result — includes the full assembled content blocks.
170    Done(ChatResponse),
171    /// Fatal error during streaming.
172    Error(String),
173}
174
175#[async_trait]
176pub trait Provider: Send + Sync {
177    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
178
179    /// Best-effort context window size (in tokens) for the configured model.
180    /// `None` if the provider cannot infer it; callers should treat that as
181    /// "don't render a fullness bar".
182    fn context_size(&self) -> Option<u32> {
183        None
184    }
185
186    /// Stream a chat turn. Default impl wraps `chat` and emits one `Done` event,
187    /// so providers without native streaming still work (no incremental thinking).
188    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
189        let (tx, rx) = mpsc::channel(32);
190        let response = self.chat(request).await?;
191        let _ = tx.send(StreamEvent::FirstByte).await;
192        for block in &response.content {
193            if let ContentBlock::Text { text } = block {
194                let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
195            }
196        }
197        let _ = tx.send(StreamEvent::Done(response)).await;
198        Ok(rx)
199    }
200
201    /// Clone the provider into a boxed type.
202    fn clone_box(&self) -> Box<dyn Provider>;
203}
204
205impl Clone for Box<dyn Provider> {
206    fn clone(&self) -> Self {
207        self.clone_box()
208    }
209}
210
211// ============================================================================
212// Provider Factory
213// ============================================================================
214
/// Selects which concrete backend the provider factory constructs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProviderType {
    /// Built as `anthropic::AnthropicProvider`.
    Anthropic,
    /// Built as `openai::OpenAIProvider`.
    OpenAI,
}
221
222/// Create a provider instance based on type and configuration.
223/// This is the recommended way to obtain a Provider instance.
224pub fn create_provider(
225    provider_type: ProviderType,
226    api_key: String,
227    model: String,
228    base_url: Option<String>,
229) -> Result<Box<dyn Provider>> {
230    create_provider_with_headers(provider_type, api_key, model, base_url, None)
231}
232
233/// Create a provider with extra headers support.
234pub fn create_provider_with_headers(
235    provider_type: ProviderType,
236    api_key: String,
237    model: String,
238    base_url: Option<String>,
239    extra_headers: Option<std::collections::HashMap<String, String>>,
240) -> Result<Box<dyn Provider>> {
241    match provider_type {
242        ProviderType::Anthropic => {
243            let provider = anthropic::AnthropicProvider::with_headers(
244                api_key,
245                model,
246                base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
247                extra_headers,
248            );
249            Ok(Box::new(provider))
250        }
251        ProviderType::OpenAI => {
252            let provider = openai::OpenAIProvider::with_headers(
253                api_key,
254                model,
255                base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
256                extra_headers,
257            );
258            Ok(Box::new(provider))
259        }
260    }
261}
262
263/// Infer provider type from model name.
264/// Returns Anthropic for Claude models, OpenAI for GPT models.
265pub fn infer_provider_type(model: &str) -> ProviderType {
266    let lower = model.to_lowercase();
267    if lower.contains("claude")
268        || lower.contains("opus")
269        || lower.contains("sonnet")
270        || lower.contains("haiku")
271    {
272        ProviderType::Anthropic
273    } else if lower.contains("gpt")
274        || lower.contains("o1")
275        || lower.contains("o3")
276        || lower.contains("o4")
277    {
278        ProviderType::OpenAI
279    } else {
280        // Default to Anthropic for unknown models
281        ProviderType::Anthropic
282    }
283}