Skip to main content

matrixcode_core/providers/
mod.rs

1pub mod anthropic;
2pub mod openai;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use tokio::sync::mpsc;
8
9use crate::tools::ToolDefinition;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Message {
13    pub role: Role,
14    pub content: MessageContent,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18#[serde(rename_all = "lowercase")]
19pub enum Role {
20    System,
21    User,
22    Assistant,
23    Tool,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27#[serde(untagged)]
28pub enum MessageContent {
29    Text(String),
30    Blocks(Vec<ContentBlock>),
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
34#[serde(tag = "type")]
35pub enum ContentBlock {
36    #[serde(rename = "text")]
37    Text { text: String },
38    #[serde(rename = "tool_use")]
39    ToolUse {
40        id: String,
41        name: String,
42        input: serde_json::Value,
43    },
44    #[serde(rename = "tool_result")]
45    ToolResult {
46        tool_use_id: String,
47        content: String,
48    },
49    /// Anthropic extended-thinking block. `signature` is required when sending
50    /// the block back to the API in a follow-up turn.
51    #[serde(rename = "thinking")]
52    Thinking {
53        thinking: String,
54        #[serde(skip_serializing_if = "Option::is_none")]
55        signature: Option<String>,
56    },
57    /// Server-side tool use (e.g., web_search_tool). The server executes
58    /// the tool and returns results directly without client intervention.
59    #[serde(rename = "server_tool_use")]
60    ServerToolUse {
61        id: String,
62        name: String,
63        input: serde_json::Value,
64    },
65    /// Result from a server-side web search tool.
66    #[serde(rename = "web_search_tool_result")]
67    WebSearchResult {
68        tool_use_id: String,
69        content: WebSearchContent,
70    },
71}
72
73/// Content returned by the server-side web search tool.
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
75pub struct WebSearchContent {
76    pub results: Vec<WebSearchResultItem>,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
80pub struct WebSearchResultItem {
81    pub title: Option<String>,
82    pub url: String,
83    pub encrypted_content: Option<String>,
84    pub snippet: Option<String>,
85}
86
87/// Server-side tool definition. These tools are executed by the API provider
88/// rather than by the client. Currently only web_search_tool is supported.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct ServerTool {
91    #[serde(rename = "type")]
92    pub tool_type: String,
93    pub name: String,
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub max_uses: Option<u32>,
96}
97
98impl ServerTool {
99    /// Create a new web search server tool.
100    pub fn web_search(max_uses: Option<u32>) -> Self {
101        Self {
102            tool_type: "web_search_tool".to_string(),
103            name: "web_search".to_string(),
104            max_uses,
105        }
106    }
107}
108
109#[derive(Debug, Clone)]
110pub struct ChatRequest {
111    pub messages: Vec<Message>,
112    pub tools: Vec<ToolDefinition>,
113    pub system: Option<String>,
114    pub think: bool,
115    /// Maximum output tokens for the response.
116    pub max_tokens: u32,
117    /// Server-side tools that are executed by the API provider.
118    pub server_tools: Vec<ServerTool>,
119    /// Enable prompt caching for Anthropic provider.
120    pub enable_caching: bool,
121}
122
123#[derive(Debug, Clone)]
124pub struct ChatResponse {
125    pub content: Vec<ContentBlock>,
126    pub stop_reason: StopReason,
127    pub usage: Usage,
128}
129
/// Token counts for a single provider turn.
///
/// `input_tokens` is the combined total of cached and uncached input. Where a
/// provider reports cache details separately, they are recorded in the two
/// `cache_*` fields so callers can report how effective caching was.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Usage {
    /// Input tokens, cached and uncached combined.
    pub input_tokens: u32,
    /// Output tokens.
    pub output_tokens: u32,
    /// Cache-creation input tokens, when the provider reports them.
    pub cache_creation_input_tokens: u32,
    /// Cache-read input tokens, when the provider reports them.
    pub cache_read_input_tokens: u32,
}
140
/// Reason recorded in [`ChatResponse::stop_reason`] for the end of a turn.
#[derive(Clone, Debug, PartialEq)]
pub enum StopReason {
    /// The turn ended normally.
    EndTurn,
    /// The turn stopped for tool use.
    ToolUse,
    /// The turn stopped at the token limit.
    MaxTokens,
}
147
148/// Incremental events emitted during a streaming chat turn.
149#[derive(Debug, Clone)]
150pub enum StreamEvent {
151    /// First byte received from the server — agent should stop any waiting spinner.
152    FirstByte,
153    /// Extended-thinking text delta (Anthropic thinking block).
154    ThinkingDelta(String),
155    /// Visible assistant text delta.
156    TextDelta(String),
157    /// A new tool_use block started.
158    ToolUseStart { id: String, name: String },
159    /// Incremental progress for the current tool_use block's JSON input.
160    /// `bytes_so_far` is the total accumulated size of the partial JSON
161    /// received for this block — useful for driving progress indicators
162    /// while the model streams large arguments (e.g. a full file body).
163    ToolInputDelta { bytes_so_far: usize },
164    /// Final turn result — includes the full assembled content blocks.
165    Done(ChatResponse),
166    /// Fatal error during streaming.
167    Error(String),
168}
169
170#[async_trait]
171pub trait Provider: Send + Sync {
172    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
173
174    /// Best-effort context window size (in tokens) for the configured model.
175    /// `None` if the provider cannot infer it; callers should treat that as
176    /// "don't render a fullness bar".
177    fn context_size(&self) -> Option<u32> {
178        None
179    }
180
181    /// Stream a chat turn. Default impl wraps `chat` and emits one `Done` event,
182    /// so providers without native streaming still work (no incremental thinking).
183    async fn chat_stream(&self, request: ChatRequest) -> Result<mpsc::Receiver<StreamEvent>> {
184        let (tx, rx) = mpsc::channel(32);
185        let response = self.chat(request).await?;
186        let _ = tx.send(StreamEvent::FirstByte).await;
187        for block in &response.content {
188            if let ContentBlock::Text { text } = block {
189                let _ = tx.send(StreamEvent::TextDelta(text.clone())).await;
190            }
191        }
192        let _ = tx.send(StreamEvent::Done(response)).await;
193        Ok(rx)
194    }
195
196    /// Clone the provider into a boxed type.
197    fn clone_box(&self) -> Box<dyn Provider>;
198}
199
200impl Clone for Box<dyn Provider> {
201    fn clone(&self) -> Self {
202        self.clone_box()
203    }
204}