// matrixcode_core/providers/mod.rs

1pub mod anthropic;
2pub mod openai;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use tokio::sync::mpsc;
8
9use crate::tools::ToolDefinition;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Message {
13    pub role: Role,
14    pub content: MessageContent,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18#[serde(rename_all = "lowercase")]
19pub enum Role {
20    System,
21    User,
22    Assistant,
23    Tool,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27#[serde(untagged)]
28pub enum MessageContent {
29    Text(String),
30    Blocks(Vec<ContentBlock>),
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
34#[serde(tag = "type")]
35pub enum ContentBlock {
36    #[serde(rename = "text")]
37    Text { text: String },
38    #[serde(rename = "tool_use")]
39    ToolUse {
40        id: String,
41        name: String,
42        input: serde_json::Value,
43    },
44    #[serde(rename = "tool_result")]
45    ToolResult {
46        tool_use_id: String,
47        content: String,
48    },
49    /// Anthropic extended-thinking block. `signature` is required when sending
50    /// the block back to the API in a follow-up turn.
51    #[serde(rename = "thinking")]
52    Thinking {
53        thinking: String,
54        #[serde(skip_serializing_if = "Option::is_none")]
55        signature: Option<String>,
56    },
57    /// Server-side tool use (e.g., web_search_tool). The server executes
58    /// the tool and returns results directly without client intervention.
59    #[serde(rename = "server_tool_use")]
60    ServerToolUse {
61        id: String,
62        name: String,
63        input: serde_json::Value,
64    },
65    /// Result from a server-side web search tool.
66    #[serde(rename = "web_search_tool_result")]
67    WebSearchResult {
68        tool_use_id: String,
69        content: WebSearchContent,
70    },
71}
72
73/// Content returned by the server-side web search tool.
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
75pub struct WebSearchContent {
76    pub results: Vec<WebSearchResultItem>,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80pub struct WebSearchResultItem {
81    pub title: Option<String>,
82    pub url: String,
83    pub encrypted_content: Option<String>,
84    pub snippet: Option<String>,
85}
86
87/// Server-side tool definition. These tools are executed by the API provider
88/// rather than by the client. Currently only web_search_tool is supported.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct ServerTool {
91    #[serde(rename = "type")]
92    pub tool_type: String,
93    pub name: String,
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub max_uses: Option<u32>,
96}
97
98impl ServerTool {
99    /// Create a new web search server tool.
100    pub fn web_search(max_uses: Option<u32>) -> Self {
101        Self {
102            tool_type: "web_search_tool".to_string(),
103            name: "web_search".to_string(),
104            max_uses,
105        }
106    }
107}
108
109#[derive(Debug, Clone)]
110pub struct ChatRequest {
111    pub messages: Vec<Message>,
112    pub tools: Vec<ToolDefinition>,
113    pub system: Option<String>,
114    pub think: bool,
115    /// Maximum output tokens for the response.
116    pub max_tokens: u32,
117    /// Server-side tools that are executed by the API provider.
118    pub server_tools: Vec<ServerTool>,
119    /// Enable prompt caching for Anthropic provider.
120    pub enable_caching: bool,
121}
122
123#[derive(Debug, Clone)]
124pub struct ChatResponse {
125    pub content: Vec<ContentBlock>,
126    pub stop_reason: StopReason,
127    pub usage: Usage,
128}
129
/// Token counts for a single provider turn.
///
/// `input_tokens` is the combined total of cached and uncached input. When a
/// provider reports cache details separately they are recorded in the two
/// `cache_*` fields so callers can report cache effectiveness. All counters
/// are zero in the `Default` value.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Usage {
    /// Input tokens, cached and uncached portions combined.
    pub input_tokens: u32,
    /// Output tokens.
    pub output_tokens: u32,
    /// Input tokens reported as cache creation.
    pub cache_creation_input_tokens: u32,
    /// Input tokens reported as cache reads.
    pub cache_read_input_tokens: u32,
}
140
/// Reason a chat turn ended, as recorded in [`ChatResponse::stop_reason`].
#[derive(Clone, Debug, PartialEq)]
pub enum StopReason {
    /// The turn ended on its own.
    EndTurn,
    /// The turn ended with a tool call.
    ToolUse,
    /// The turn ended at the token limit.
    MaxTokens,
}
147
148/// Incremental events emitted during a streaming chat turn.
149#[derive(Debug, Clone)]
150pub enum StreamEvent {
151    /// First byte received from the server — agent should stop any waiting spinner.
152    FirstByte,
153    /// Extended-thinking text delta (Anthropic thinking block).
154    ThinkingDelta(String),
155    /// Visible assistant text delta.
156    TextDelta(String),
157    /// A new tool_use block started.
158    ToolUseStart { id: String, name: String },
159    /// Incremental progress for the current tool_use block's JSON input.
160    /// `bytes_so_far` is the total accumulated size of the partial JSON
161    /// received for this block — useful for driving progress indicators
162    /// while the model streams large arguments (e.g. a full file body).
163    ToolInputDelta { bytes_so_far: usize },
164    /// Real-time usage update (output tokens so far).
165    Usage { output_tokens: u32 },
166    /// Final turn result — includes the full assembled content blocks.
167    Done(ChatResponse),
168    /// Fatal error during streaming.
169    Error(String),
170}
171
172#[async_trait]
173pub trait Provider: Send + Sync {
174    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
175
176    /// Best-effort context window size (in tokens) for the configured model.
177    /// `None` if the provider cannot infer it; callers should treat that as
178    /// "don't render a fullness bar".
179    fn context_size(&self) -> Option<u32> {
180        None
181    }
182
183    /// Stream a chat turn. Default impl wraps `chat` and emits one `Done` event,
184    /// so providers without native streaming still work (no incremental thinking).
185    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
186        let (tx, rx) = mpsc::channel(32);
187        let response = self.chat(request).await?;
188        let _ = tx.send(StreamEvent::FirstByte).await;
189        for block in &response.content {
190            if let ContentBlock::Text { text } = block {
191                let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
192            }
193        }
194        let _ = tx.send(StreamEvent::Done(response)).await;
195        Ok(rx)
196    }
197
198    /// Clone the provider into a boxed type.
199    fn clone_box(&self) -> Box<dyn Provider>;
200}
201
202impl Clone for Box<dyn Provider> {
203    fn clone(&self) -> Self {
204        self.clone_box()
205    }
206}
207
208// ============================================================================
209// Provider Factory
210// ============================================================================
211
/// Selects which backend [`create_provider`] should build.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProviderType {
    /// Build the provider from the `anthropic` module.
    Anthropic,
    /// Build the provider from the `openai` module.
    OpenAI,
}
218
219/// Create a provider instance based on type and configuration.
220/// This is the recommended way to obtain a Provider instance.
221pub fn create_provider(
222    provider_type: ProviderType,
223    api_key: String,
224    model: String,
225    base_url: Option<String>,
226) -> Result<Box<dyn Provider>> {
227    match provider_type {
228        ProviderType::Anthropic => {
229            let provider = anthropic::AnthropicProvider::new(
230                api_key,
231                model,
232                base_url.unwrap_or_else(|| "https://api.anthropic.com".to_string()),
233            );
234            Ok(Box::new(provider))
235        }
236        ProviderType::OpenAI => {
237            let provider = openai::OpenAIProvider::new(
238                api_key,
239                model,
240                base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
241            );
242            Ok(Box::new(provider))
243        }
244    }
245}
246
247/// Infer provider type from model name.
248/// Returns Anthropic for Claude models, OpenAI for GPT models.
249pub fn infer_provider_type(model: &str) -> ProviderType {
250    let lower = model.to_lowercase();
251    if lower.contains("claude") || lower.contains("opus") || lower.contains("sonnet") || lower.contains("haiku") {
252        ProviderType::Anthropic
253    } else if lower.contains("gpt") || lower.contains("o1") || lower.contains("o3") || lower.contains("o4") {
254        ProviderType::OpenAI
255    } else {
256        // Default to Anthropic for unknown models
257        ProviderType::Anthropic
258    }
259}