matrixcode_core/providers/
openai.rs1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4use std::sync::Arc;
5use std::time::Duration;
6
7use crate::constants::{
8 DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_READ_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS,
9};
10use crate::models::context_window_for;
11use crate::tools::ToolDefinition;
12
13use super::{
14 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
15 Usage,
16};
17
18pub struct OpenAIProvider {
19 api_key: String,
20 model: String,
21 base_url: String,
22 client: reqwest::Client,
23 extra_headers: Vec<(String, String)>,
25}
26
27impl OpenAIProvider {
28 pub fn new(api_key: String, model: String, base_url: String) -> Self {
29 Self::with_headers(api_key, model, base_url, None)
30 }
31
32 pub fn with_headers(
33 api_key: String,
34 model: String,
35 base_url: String,
36 extra_headers: Option<std::collections::HashMap<String, String>>,
37 ) -> Self {
38 let client = reqwest::Client::builder()
43 .timeout(Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS))
44 .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
45 .read_timeout(Duration::from_secs(DEFAULT_READ_TIMEOUT_SECS))
46 .build()
47 .unwrap_or_else(|_| reqwest::Client::new());
48 let extra_headers: Vec<(String, String)> = extra_headers
49 .map(|h| h.into_iter().collect())
50 .unwrap_or_default();
51 Self {
52 api_key,
53 model,
54 base_url,
55 client,
56 extra_headers,
57 }
58 }
59
60 fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
61 let mut result = Vec::new();
62
63 if let Some(sys) = system {
64 result.push(json!({"role": "system", "content": sys}));
65 }
66
67 for msg in messages {
68 match (&msg.role, &msg.content) {
69 (Role::System, _) => {}
70 (Role::User, MessageContent::Text(text)) => {
71 result.push(json!({"role": "user", "content": text}));
72 }
73 (Role::Assistant, MessageContent::Text(text)) => {
74 result.push(json!({"role": "assistant", "content": text}));
75 }
76 (Role::Assistant, MessageContent::Blocks(blocks)) => {
77 let mut tool_calls = Vec::new();
78 let mut text_parts = Vec::new();
79 let mut reasoning_parts = Vec::new();
80
81 for block in blocks {
82 match block {
83 ContentBlock::Thinking { thinking, .. } => {
84 reasoning_parts.push(thinking.clone());
86 }
87 ContentBlock::Text { text } => text_parts.push(text.clone()),
88 ContentBlock::ToolUse { id, name, input } => {
89 tool_calls.push(json!({
90 "id": id,
91 "type": "function",
92 "function": {
93 "name": name,
94 "arguments": input.to_string(),
95 }
96 }));
97 }
98 _ => {}
99 }
100 }
101
102 let mut msg_obj = json!({"role": "assistant"});
103 if !reasoning_parts.is_empty() {
105 msg_obj["reasoning_content"] = json!(reasoning_parts.join("\n"));
106 }
107 if !text_parts.is_empty() {
108 msg_obj["content"] = json!(text_parts.join("\n"));
109 }
110 if !tool_calls.is_empty() {
111 msg_obj["tool_calls"] = json!(tool_calls);
112 }
113 result.push(msg_obj);
114 }
115 (Role::Tool, MessageContent::Blocks(blocks)) => {
116 self.push_tool_results(blocks, &mut result);
117 }
118 (Role::User, MessageContent::Blocks(blocks)) => {
119 if blocks
121 .iter()
122 .any(|b| matches!(b, ContentBlock::ToolResult { .. }))
123 {
124 self.push_tool_results(blocks, &mut result);
126 } else {
127 let text: String = blocks
129 .iter()
130 .filter_map(|b| match b {
131 ContentBlock::Text { text } => Some(text.as_str()),
132 _ => None,
133 })
134 .collect::<Vec<_>>()
135 .join("\n");
136 result.push(json!({"role": "user", "content": text}));
137 }
138 }
139 _ => {}
140 }
141 }
142
143 result
144 }
145
146 fn push_tool_results(&self, blocks: &[ContentBlock], result: &mut Vec<Value>) {
148 for block in blocks {
149 if let ContentBlock::ToolResult {
150 tool_use_id,
151 content,
152 } = block
153 {
154 result.push(json!({
155 "role": "tool",
156 "tool_call_id": tool_use_id,
157 "content": content,
158 }));
159 }
160 }
161 }
162
163 fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
164 tools
165 .iter()
166 .map(|t| {
167 json!({
168 "type": "function",
169 "function": {
170 "name": t.name,
171 "description": t.description,
172 "parameters": t.parameters,
173 }
174 })
175 })
176 .collect()
177 }
178}
179
180#[async_trait]
181impl Provider for OpenAIProvider {
182 fn context_size(&self) -> Option<u32> {
183 context_window_for(&self.model)
184 }
185
186 fn model_name(&self) -> &str {
187 &self.model
188 }
189
190 fn clone_box(&self) -> Box<dyn Provider> {
191 Box::new(Self {
192 api_key: self.api_key.clone(),
193 model: self.model.clone(),
194 base_url: self.base_url.clone(),
195 client: reqwest::Client::new(),
196 extra_headers: self.extra_headers.clone(),
197 })
198 }
199
200 fn clone_arc(&self) -> Arc<dyn Provider> {
201 Arc::new(Self {
202 api_key: self.api_key.clone(),
203 model: self.model.clone(),
204 base_url: self.base_url.clone(),
205 client: reqwest::Client::new(),
206 extra_headers: self.extra_headers.clone(),
207 })
208 }
209
210 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
211 let messages = self.convert_messages(&request.messages, request.system.as_deref());
212
213 let mut body = json!({
214 "model": self.model,
215 "messages": messages,
216 "max_completion_tokens": request.max_tokens,
217 });
218
219 if !request.tools.is_empty() {
220 body["tools"] = json!(self.convert_tools(&request.tools));
221 }
222
223 let url = format!("{}/chat/completions", self.base_url);
224
225 crate::debug::debug_log()
227 .api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
228
229 let mut req = self
230 .client
231 .post(&url)
232 .header("Authorization", format!("Bearer {}", self.api_key))
233 .header("Content-Type", "application/json")
234 .json(&body);
235
236 for (name, value) in &self.extra_headers {
238 req = req.header(name, value);
239 }
240
241 let response = req
242 .send()
243 .await
244 .map_err(|e| anyhow::anyhow!("HTTP request failed: {} (url: {})", e, url))?;
245
246 let status = response.status();
247 let response_body: Value = response
248 .json()
249 .await
250 .map_err(|e| anyhow::anyhow!("Failed to parse response JSON: {}", e))?;
251
252 crate::debug::debug_log().api_response(
254 status.as_u16(),
255 &serde_json::to_string(&response_body).unwrap_or_default(),
256 );
257
258 if !status.is_success() {
259 let err_msg = response_body["error"]["message"]
260 .as_str()
261 .unwrap_or_else(|| response_body["error"].as_str().unwrap_or("unknown error"));
262 anyhow::bail!("API error ({}): {}", status, err_msg);
263 }
264
265 let choice = &response_body["choices"][0];
266 let message = &choice["message"];
267 let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
268
269 let stop_reason = match finish_reason {
270 "tool_calls" => StopReason::ToolUse,
271 "length" => StopReason::MaxTokens,
272 _ => StopReason::EndTurn,
273 };
274
275 let mut content = Vec::new();
276
277 let usage_blob = &response_body["usage"];
278 let usage = Usage {
279 input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
280 output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
281 cache_creation_input_tokens: 0,
282 cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
283 .as_u64()
284 .unwrap_or(0) as u32,
285 };
286
287 if let Some(reasoning) = message["reasoning_content"].as_str()
289 && !reasoning.is_empty()
290 {
291 content.push(ContentBlock::Thinking {
292 thinking: reasoning.to_string(),
293 signature: None,
294 });
295 }
296
297 if let Some(text) = message["content"].as_str()
298 && !text.is_empty()
299 {
300 content.push(ContentBlock::Text {
301 text: text.to_string(),
302 });
303 }
304
305 if let Some(tool_calls) = message["tool_calls"].as_array() {
306 for tc in tool_calls {
307 let id = tc["id"].as_str().unwrap_or_default().to_string();
308 let name = tc["function"]["name"]
309 .as_str()
310 .unwrap_or_default()
311 .to_string();
312 let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
313 let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
314
315 content.push(ContentBlock::ToolUse { id, name, input });
316 }
317
318 if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
319 return Ok(ChatResponse {
320 content,
321 stop_reason: StopReason::ToolUse,
322 usage: usage.clone(),
323 });
324 }
325 }
326
327 Ok(ChatResponse {
328 content,
329 stop_reason,
330 usage,
331 })
332 }
333}