matrixcode_core/providers/openai.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4
5use crate::models::context_window_for;
6use crate::tools::ToolDefinition;
7
8use super::{
9    ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
10    Usage,
11};
12
13pub struct OpenAIProvider {
14    api_key: String,
15    model: String,
16    base_url: String,
17    client: reqwest::Client,
18    /// Extra headers from config
19    extra_headers: Vec<(String, String)>,
20}
21
22impl OpenAIProvider {
23    pub fn new(api_key: String, model: String, base_url: String) -> Self {
24        Self::with_headers(api_key, model, base_url, None)
25    }
26
27    pub fn with_headers(
28        api_key: String,
29        model: String,
30        base_url: String,
31        extra_headers: Option<std::collections::HashMap<String, String>>,
32    ) -> Self {
33        let client = reqwest::Client::builder()
34            .timeout(std::time::Duration::from_secs(120))
35            .connect_timeout(std::time::Duration::from_secs(10))
36            .build()
37            .unwrap_or_else(|_| reqwest::Client::new());
38        let extra_headers: Vec<(String, String)> = extra_headers
39            .map(|h| h.into_iter().collect())
40            .unwrap_or_default();
41        Self {
42            api_key,
43            model,
44            base_url,
45            client,
46            extra_headers,
47        }
48    }
49
50    fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
51        let mut result = Vec::new();
52
53        if let Some(sys) = system {
54            result.push(json!({"role": "system", "content": sys}));
55        }
56
57        for msg in messages {
58            match (&msg.role, &msg.content) {
59                (Role::System, _) => {}
60                (Role::User, MessageContent::Text(text)) => {
61                    result.push(json!({"role": "user", "content": text}));
62                }
63                (Role::Assistant, MessageContent::Text(text)) => {
64                    result.push(json!({"role": "assistant", "content": text}));
65                }
66                (Role::Assistant, MessageContent::Blocks(blocks)) => {
67                    let mut tool_calls = Vec::new();
68                    let mut text_parts = Vec::new();
69
70                    for block in blocks {
71                        match block {
72                            ContentBlock::Text { text } => text_parts.push(text.clone()),
73                            ContentBlock::ToolUse { id, name, input } => {
74                                tool_calls.push(json!({
75                                    "id": id,
76                                    "type": "function",
77                                    "function": {
78                                        "name": name,
79                                        "arguments": input.to_string(),
80                                    }
81                                }));
82                            }
83                            ContentBlock::Thinking { .. } => {}
84                            _ => {}
85                        }
86                    }
87
88                    let mut msg_obj = json!({"role": "assistant"});
89                    if !text_parts.is_empty() {
90                        msg_obj["content"] = json!(text_parts.join("\n"));
91                    }
92                    if !tool_calls.is_empty() {
93                        msg_obj["tool_calls"] = json!(tool_calls);
94                    }
95                    result.push(msg_obj);
96                }
97                (Role::Tool, MessageContent::Blocks(blocks)) => {
98                    self.push_tool_results(blocks, &mut result);
99                }
100                (Role::User, MessageContent::Blocks(blocks)) => {
101                    // Check if this is a tool result message (agent wraps tool results as User role)
102                    if blocks
103                        .iter()
104                        .any(|b| matches!(b, ContentBlock::ToolResult { .. }))
105                    {
106                        // Emit as OpenAI tool messages
107                        self.push_tool_results(blocks, &mut result);
108                    } else {
109                        // Regular user message with blocks
110                        let text: String = blocks
111                            .iter()
112                            .filter_map(|b| match b {
113                                ContentBlock::Text { text } => Some(text.as_str()),
114                                _ => None,
115                            })
116                            .collect::<Vec<_>>()
117                            .join("\n");
118                        result.push(json!({"role": "user", "content": text}));
119                    }
120                }
121                _ => {}
122            }
123        }
124
125        result
126    }
127
128    /// Push tool result blocks to message array
129    fn push_tool_results(&self, blocks: &[ContentBlock], result: &mut Vec<Value>) {
130        for block in blocks {
131            if let ContentBlock::ToolResult {
132                tool_use_id,
133                content,
134            } = block
135            {
136                result.push(json!({
137                    "role": "tool",
138                    "tool_call_id": tool_use_id,
139                    "content": content,
140                }));
141            }
142        }
143    }
144
145    fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
146        tools
147            .iter()
148            .map(|t| {
149                json!({
150                    "type": "function",
151                    "function": {
152                        "name": t.name,
153                        "description": t.description,
154                        "parameters": t.parameters,
155                    }
156                })
157            })
158            .collect()
159    }
160}
161
162#[async_trait]
163impl Provider for OpenAIProvider {
164    fn context_size(&self) -> Option<u32> {
165        context_window_for(&self.model)
166    }
167
168    fn clone_box(&self) -> Box<dyn Provider> {
169        Box::new(Self {
170            api_key: self.api_key.clone(),
171            model: self.model.clone(),
172            base_url: self.base_url.clone(),
173            client: reqwest::Client::new(),
174            extra_headers: self.extra_headers.clone(),
175        })
176    }
177
178    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
179        let messages = self.convert_messages(&request.messages, request.system.as_deref());
180
181        let mut body = json!({
182            "model": self.model,
183            "messages": messages,
184            "max_completion_tokens": request.max_tokens,
185        });
186
187        if !request.tools.is_empty() {
188            body["tools"] = json!(self.convert_tools(&request.tools));
189        }
190
191        let url = format!("{}/chat/completions", self.base_url);
192        let mut req = self
193            .client
194            .post(&url)
195            .header("Authorization", format!("Bearer {}", self.api_key))
196            .header("Content-Type", "application/json")
197            .json(&body);
198
199        // Add extra headers from config
200        for (name, value) in &self.extra_headers {
201            req = req.header(name, value);
202        }
203
204        let response = req.send().await?;
205
206        let status = response.status();
207        let response_body: Value = response.json().await?;
208
209        if !status.is_success() {
210            let err_msg = response_body["error"]["message"]
211                .as_str()
212                .unwrap_or("unknown error");
213            anyhow::bail!("OpenAI API error ({}): {}", status, err_msg);
214        }
215
216        let choice = &response_body["choices"][0];
217        let message = &choice["message"];
218        let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
219
220        let stop_reason = match finish_reason {
221            "tool_calls" => StopReason::ToolUse,
222            "length" => StopReason::MaxTokens,
223            _ => StopReason::EndTurn,
224        };
225
226        let mut content = Vec::new();
227
228        let usage_blob = &response_body["usage"];
229        let usage = Usage {
230            input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
231            output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
232            cache_creation_input_tokens: 0,
233            cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
234                .as_u64()
235                .unwrap_or(0) as u32,
236        };
237
238        if let Some(text) = message["content"].as_str()
239            && !text.is_empty()
240        {
241            content.push(ContentBlock::Text {
242                text: text.to_string(),
243            });
244        }
245
246        if let Some(tool_calls) = message["tool_calls"].as_array() {
247            for tc in tool_calls {
248                let id = tc["id"].as_str().unwrap_or_default().to_string();
249                let name = tc["function"]["name"]
250                    .as_str()
251                    .unwrap_or_default()
252                    .to_string();
253                let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
254                let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
255
256                content.push(ContentBlock::ToolUse { id, name, input });
257            }
258
259            if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
260                return Ok(ChatResponse {
261                    content,
262                    stop_reason: StopReason::ToolUse,
263                    usage: usage.clone(),
264                });
265            }
266        }
267
268        Ok(ChatResponse {
269            content,
270            stop_reason,
271            usage,
272        })
273    }
274}