matrixcode_core/providers/
openai.rs1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::constants::{
8 DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_READ_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS,
9};
10use crate::models::context_window_for;
11use crate::tools::ToolDefinition;
12
13use super::{
14 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
15 Usage,
16};
17
18pub struct OpenAIProvider {
19 api_key: String,
20 model: String,
21 base_url: String,
22 client: reqwest::Client,
23 extra_headers: Vec<(String, String)>,
25}
26
27impl OpenAIProvider {
28 pub fn new(api_key: String, model: String, base_url: String) -> Self {
29 Self::with_headers(api_key, model, base_url, None)
30 }
31
32 pub fn with_headers(
33 api_key: String,
34 model: String,
35 base_url: String,
36 extra_headers: Option<std::collections::HashMap<String, String>>,
37 ) -> Self {
38 let client = reqwest::Client::builder()
43 .timeout(Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS))
44 .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
45 .read_timeout(Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
46 .build()
47 .unwrap_or_else(|_| reqwest::Client::new());
48 let extra_headers: Vec<(String, String)> = extra_headers
49 .map(|h| h.into_iter().collect())
50 .unwrap_or_default();
51 Self {
52 api_key,
53 model,
54 base_url,
55 client,
56 extra_headers,
57 }
58 }
59
60 fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
61 let mut result = Vec::new();
62
63 if let Some(sys) = system {
64 result.push(json!({"role": "system", "content": sys}));
65 }
66
67 for msg in messages {
68 match (&msg.role, &msg.content) {
69 (Role::System, _) => {}
70 (Role::User, MessageContent::Text(text)) => {
71 result.push(json!({"role": "user", "content": text}));
72 }
73 (Role::Assistant, MessageContent::Text(text)) => {
74 result.push(json!({"role": "assistant", "content": text}));
75 }
76 (Role::Assistant, MessageContent::Blocks(blocks)) => {
77 let mut tool_calls = Vec::new();
78 let mut text_parts = Vec::new();
79 for block in blocks {
84 match block {
85 ContentBlock::Thinking { .. } => {
87 continue;
88 }
89 ContentBlock::Text { text } => text_parts.push(text.clone()),
90 ContentBlock::ToolUse { id, name, input } => {
91 tool_calls.push(json!({
92 "id": id,
93 "type": "function",
94 "function": {
95 "name": name,
96 "arguments": input.to_string(),
97 }
98 }));
99 }
100 _ => {}
101 }
102 }
103
104 let mut msg_obj = json!({"role": "assistant"});
105 if !text_parts.is_empty() {
107 msg_obj["content"] = json!(text_parts.join("\n"));
108 }
109 if !tool_calls.is_empty() {
110 msg_obj["tool_calls"] = json!(tool_calls);
111 }
112 result.push(msg_obj);
113 }
114 (Role::Tool, MessageContent::Blocks(blocks)) => {
115 self.push_tool_results(blocks, &mut result);
116 }
117 (Role::User, MessageContent::Blocks(blocks)) => {
118 if blocks
120 .iter()
121 .any(|b| matches!(b, ContentBlock::ToolResult { .. }))
122 {
123 self.push_tool_results(blocks, &mut result);
125 } else {
126 let text: String = blocks
128 .iter()
129 .filter_map(|b| match b {
130 ContentBlock::Text { text } => Some(text.as_str()),
131 _ => None,
132 })
133 .collect::<Vec<_>>()
134 .join("\n");
135 result.push(json!({"role": "user", "content": text}));
136 }
137 }
138 _ => {}
139 }
140 }
141
142 result
143 }
144
145 fn push_tool_results(&self, blocks: &[ContentBlock], result: &mut Vec<Value>) {
147 for block in blocks {
148 if let ContentBlock::ToolResult {
149 tool_use_id,
150 content,
151 } = block
152 {
153 result.push(json!({
154 "role": "tool",
155 "tool_call_id": tool_use_id,
156 "content": content,
157 }));
158 }
159 }
160 }
161
162 fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
163 tools
164 .iter()
165 .map(|t| {
166 json!({
167 "type": "function",
168 "function": {
169 "name": t.name,
170 "description": t.description,
171 "parameters": t.parameters,
172 }
173 })
174 })
175 .collect()
176 }
177}
178
179#[async_trait]
180impl Provider for OpenAIProvider {
181 fn context_size(&self) -> Option<u32> {
182 context_window_for(&self.model)
183 }
184
185 fn model_name(&self) -> &str {
186 &self.model
187 }
188
189 fn clone_box(&self) -> Box<dyn Provider> {
190 Box::new(Self {
191 api_key: self.api_key.clone(),
192 model: self.model.clone(),
193 base_url: self.base_url.clone(),
194 client: reqwest::Client::new(),
195 extra_headers: self.extra_headers.clone(),
196 })
197 }
198
199 fn clone_arc(&self) -> Arc<dyn Provider> {
200 Arc::new(Self {
201 api_key: self.api_key.clone(),
202 model: self.model.clone(),
203 base_url: self.base_url.clone(),
204 client: reqwest::Client::new(),
205 extra_headers: self.extra_headers.clone(),
206 })
207 }
208
209 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
210 let messages = self.convert_messages(&request.messages, request.system.as_deref());
211
212 let mut body = json!({
213 "model": self.model,
214 "messages": messages,
215 "max_completion_tokens": request.max_tokens,
216 });
217
218 if !request.tools.is_empty() {
219 body["tools"] = json!(self.convert_tools(&request.tools));
220 }
221
222 let url = format!("{}/chat/completions", self.base_url);
223
224 crate::debug::debug_log()
226 .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
227
228 let mut req = self
229 .client
230 .post(&url)
231 .header("Authorization", format!("Bearer {}", self.api_key))
232 .header("Content-Type", "application/json")
233 .json(&body);
234
235 for (name, value) in &self.extra_headers {
237 req = req.header(name, value);
238 }
239
240 let response = req
241 .send()
242 .await
243 .map_err(|e| anyhow::anyhow!("HTTP request failed: {} (url: {})", e, url))?;
244
245 let status = response.status();
246 let response_body: Value = response
247 .json()
248 .await
249 .map_err(|e| anyhow::anyhow!("Failed to parse response JSON: {}", e))?;
250
251 crate::debug::debug_log().api_response(
253 status.as_u16(),
254 &serde_json::to_string(&response_body).unwrap_or_default(),
255 );
256
257 if !status.is_success() {
258 let err_msg = response_body["error"]["message"]
259 .as_str()
260 .unwrap_or_else(|| response_body["error"].as_str().unwrap_or("unknown error"));
261 anyhow::bail!("API error ({}): {}", status, err_msg);
262 }
263
264 let choice = &response_body["choices"][0];
265 let message = &choice["message"];
266 let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
267
268 let stop_reason = match finish_reason {
269 "tool_calls" => StopReason::ToolUse,
270 "length" => StopReason::MaxTokens,
271 _ => StopReason::EndTurn,
272 };
273
274 let mut content = Vec::new();
275
276 let usage_blob = &response_body["usage"];
277 let usage = Usage {
278 input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
279 output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
280 cache_creation_input_tokens: 0,
281 cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
282 .as_u64()
283 .unwrap_or(0) as u32,
284 };
285
286 if let Some(reasoning) = message["reasoning_content"].as_str()
288 && !reasoning.is_empty()
289 {
290 content.push(ContentBlock::Thinking {
291 thinking: reasoning.to_string(),
292 signature: None,
293 });
294 }
295
296 if let Some(text) = message["content"].as_str()
297 && !text.is_empty()
298 {
299 content.push(ContentBlock::Text {
300 text: text.to_string(),
301 });
302 }
303
304 if let Some(tool_calls) = message["tool_calls"].as_array() {
305 for tc in tool_calls {
306 let id = tc["id"].as_str().unwrap_or_default().to_string();
307 let name = tc["function"]["name"]
308 .as_str()
309 .unwrap_or_default()
310 .to_string();
311 let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
312 let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
313
314 content.push(ContentBlock::ToolUse { id, name, input });
315 }
316
317 if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
318 return Ok(ChatResponse {
319 content,
320 stop_reason: StopReason::ToolUse,
321 usage: usage.clone(),
322 });
323 }
324 }
325
326 Ok(ChatResponse {
327 content,
328 stop_reason,
329 usage,
330 })
331 }
332}