matrixcode_core/providers/
openai.rs1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4
5use crate::models::context_window_for;
6use crate::tools::ToolDefinition;
7
8use super::{
9 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
10 Usage,
11};
12
13pub struct OpenAIProvider {
14 api_key: String,
15 model: String,
16 base_url: String,
17 client: reqwest::Client,
18 extra_headers: Vec<(String, String)>,
20}
21
22impl OpenAIProvider {
23 pub fn new(api_key: String, model: String, base_url: String) -> Self {
24 Self::with_headers(api_key, model, base_url, None)
25 }
26
27 pub fn with_headers(
28 api_key: String,
29 model: String,
30 base_url: String,
31 extra_headers: Option<std::collections::HashMap<String, String>>,
32 ) -> Self {
33 let client = reqwest::Client::builder()
38 .timeout(std::time::Duration::from_secs(300))
39 .connect_timeout(std::time::Duration::from_secs(10))
40 .read_timeout(std::time::Duration::from_secs(60))
41 .build()
42 .unwrap_or_else(|_| reqwest::Client::new());
43 let extra_headers: Vec<(String, String)> = extra_headers
44 .map(|h| h.into_iter().collect())
45 .unwrap_or_default();
46 Self {
47 api_key,
48 model,
49 base_url,
50 client,
51 extra_headers,
52 }
53 }
54
55 fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
56 let mut result = Vec::new();
57
58 if let Some(sys) = system {
59 result.push(json!({"role": "system", "content": sys}));
60 }
61
62 for msg in messages {
63 match (&msg.role, &msg.content) {
64 (Role::System, _) => {}
65 (Role::User, MessageContent::Text(text)) => {
66 result.push(json!({"role": "user", "content": text}));
67 }
68 (Role::Assistant, MessageContent::Text(text)) => {
69 result.push(json!({"role": "assistant", "content": text}));
70 }
71 (Role::Assistant, MessageContent::Blocks(blocks)) => {
72 let mut tool_calls = Vec::new();
73 let mut text_parts = Vec::new();
74 let mut reasoning_parts = Vec::new();
75
76 for block in blocks {
77 match block {
78 ContentBlock::Thinking { thinking, .. } => {
79 reasoning_parts.push(thinking.clone());
81 }
82 ContentBlock::Text { text } => text_parts.push(text.clone()),
83 ContentBlock::ToolUse { id, name, input } => {
84 tool_calls.push(json!({
85 "id": id,
86 "type": "function",
87 "function": {
88 "name": name,
89 "arguments": input.to_string(),
90 }
91 }));
92 }
93 _ => {}
94 }
95 }
96
97 let mut msg_obj = json!({"role": "assistant"});
98 if !reasoning_parts.is_empty() {
100 msg_obj["reasoning_content"] = json!(reasoning_parts.join("\n"));
101 }
102 if !text_parts.is_empty() {
103 msg_obj["content"] = json!(text_parts.join("\n"));
104 }
105 if !tool_calls.is_empty() {
106 msg_obj["tool_calls"] = json!(tool_calls);
107 }
108 result.push(msg_obj);
109 }
110 (Role::Tool, MessageContent::Blocks(blocks)) => {
111 self.push_tool_results(blocks, &mut result);
112 }
113 (Role::User, MessageContent::Blocks(blocks)) => {
114 if blocks
116 .iter()
117 .any(|b| matches!(b, ContentBlock::ToolResult { .. }))
118 {
119 self.push_tool_results(blocks, &mut result);
121 } else {
122 let text: String = blocks
124 .iter()
125 .filter_map(|b| match b {
126 ContentBlock::Text { text } => Some(text.as_str()),
127 _ => None,
128 })
129 .collect::<Vec<_>>()
130 .join("\n");
131 result.push(json!({"role": "user", "content": text}));
132 }
133 }
134 _ => {}
135 }
136 }
137
138 result
139 }
140
141 fn push_tool_results(&self, blocks: &[ContentBlock], result: &mut Vec<Value>) {
143 for block in blocks {
144 if let ContentBlock::ToolResult {
145 tool_use_id,
146 content,
147 } = block
148 {
149 result.push(json!({
150 "role": "tool",
151 "tool_call_id": tool_use_id,
152 "content": content,
153 }));
154 }
155 }
156 }
157
158 fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
159 tools
160 .iter()
161 .map(|t| {
162 json!({
163 "type": "function",
164 "function": {
165 "name": t.name,
166 "description": t.description,
167 "parameters": t.parameters,
168 }
169 })
170 })
171 .collect()
172 }
173}
174
175#[async_trait]
176impl Provider for OpenAIProvider {
177 fn context_size(&self) -> Option<u32> {
178 context_window_for(&self.model)
179 }
180
181 fn model_name(&self) -> &str {
182 &self.model
183 }
184
185 fn clone_box(&self) -> Box<dyn Provider> {
186 Box::new(Self {
187 api_key: self.api_key.clone(),
188 model: self.model.clone(),
189 base_url: self.base_url.clone(),
190 client: reqwest::Client::new(),
191 extra_headers: self.extra_headers.clone(),
192 })
193 }
194
195 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
196 let messages = self.convert_messages(&request.messages, request.system.as_deref());
197
198 let mut body = json!({
199 "model": self.model,
200 "messages": messages,
201 "max_completion_tokens": request.max_tokens,
202 });
203
204 if !request.tools.is_empty() {
205 body["tools"] = json!(self.convert_tools(&request.tools));
206 }
207
208 let url = format!("{}/chat/completions", self.base_url);
209
210 crate::debug::debug_log().api_request(&url, &serde_json::to_string(&body).unwrap_or_default());
212
213 let mut req = self
214 .client
215 .post(&url)
216 .header("Authorization", format!("Bearer {}", self.api_key))
217 .header("Content-Type", "application/json")
218 .json(&body);
219
220 for (name, value) in &self.extra_headers {
222 req = req.header(name, value);
223 }
224
225 let response = req.send().await?;
226
227 let status = response.status();
228 let response_body: Value = response.json().await?;
229
230 crate::debug::debug_log().api_response(status.as_u16(), &serde_json::to_string(&response_body).unwrap_or_default());
232
233 if !status.is_success() {
234 let err_msg = response_body["error"]["message"]
235 .as_str()
236 .unwrap_or("unknown error");
237 anyhow::bail!("OpenAI API error ({}): {}", status, err_msg);
238 }
239
240 let choice = &response_body["choices"][0];
241 let message = &choice["message"];
242 let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
243
244 let stop_reason = match finish_reason {
245 "tool_calls" => StopReason::ToolUse,
246 "length" => StopReason::MaxTokens,
247 _ => StopReason::EndTurn,
248 };
249
250 let mut content = Vec::new();
251
252 let usage_blob = &response_body["usage"];
253 let usage = Usage {
254 input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
255 output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
256 cache_creation_input_tokens: 0,
257 cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
258 .as_u64()
259 .unwrap_or(0) as u32,
260 };
261
262 if let Some(reasoning) = message["reasoning_content"].as_str()
264 && !reasoning.is_empty()
265 {
266 content.push(ContentBlock::Thinking {
267 thinking: reasoning.to_string(),
268 signature: None,
269 });
270 }
271
272 if let Some(text) = message["content"].as_str()
273 && !text.is_empty()
274 {
275 content.push(ContentBlock::Text {
276 text: text.to_string(),
277 });
278 }
279
280 if let Some(tool_calls) = message["tool_calls"].as_array() {
281 for tc in tool_calls {
282 let id = tc["id"].as_str().unwrap_or_default().to_string();
283 let name = tc["function"]["name"]
284 .as_str()
285 .unwrap_or_default()
286 .to_string();
287 let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
288 let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
289
290 content.push(ContentBlock::ToolUse { id, name, input });
291 }
292
293 if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
294 return Ok(ChatResponse {
295 content,
296 stop_reason: StopReason::ToolUse,
297 usage: usage.clone(),
298 });
299 }
300 }
301
302 Ok(ChatResponse {
303 content,
304 stop_reason,
305 usage,
306 })
307 }
308}