matrixcode_core/providers/
openai.rs1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4
5use crate::tools::ToolDefinition;
6
7use super::{
8 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
9 Usage,
10};
11
12pub struct OpenAIProvider {
13 api_key: String,
14 model: String,
15 base_url: String,
16 client: reqwest::Client,
17}
18
19impl OpenAIProvider {
20 pub fn new(api_key: String, model: String, base_url: String) -> Self {
21 Self {
22 api_key,
23 model,
24 base_url,
25 client: reqwest::Client::new(),
26 }
27 }
28
29 fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
30 let mut result = Vec::new();
31
32 if let Some(sys) = system {
33 result.push(json!({"role": "system", "content": sys}));
34 }
35
36 for msg in messages {
37 match (&msg.role, &msg.content) {
38 (Role::System, _) => {}
39 (Role::User, MessageContent::Text(text)) => {
40 result.push(json!({"role": "user", "content": text}));
41 }
42 (Role::Assistant, MessageContent::Text(text)) => {
43 result.push(json!({"role": "assistant", "content": text}));
44 }
45 (Role::Assistant, MessageContent::Blocks(blocks)) => {
46 let mut tool_calls = Vec::new();
47 let mut text_parts = Vec::new();
48
49 for block in blocks {
50 match block {
51 ContentBlock::Text { text } => text_parts.push(text.clone()),
52 ContentBlock::ToolUse { id, name, input } => {
53 tool_calls.push(json!({
54 "id": id,
55 "type": "function",
56 "function": {
57 "name": name,
58 "arguments": input.to_string(),
59 }
60 }));
61 }
62 ContentBlock::Thinking { .. } => {}
63 _ => {}
64 }
65 }
66
67 let mut msg_obj = json!({"role": "assistant"});
68 if !text_parts.is_empty() {
69 msg_obj["content"] = json!(text_parts.join("\n"));
70 }
71 if !tool_calls.is_empty() {
72 msg_obj["tool_calls"] = json!(tool_calls);
73 }
74 result.push(msg_obj);
75 }
76 (Role::Tool, MessageContent::Blocks(blocks)) => {
77 for block in blocks {
78 if let ContentBlock::ToolResult { tool_use_id, content } = block {
79 result.push(json!({
80 "role": "tool",
81 "tool_call_id": tool_use_id,
82 "content": content,
83 }));
84 }
85 }
86 }
87 (Role::User, MessageContent::Blocks(blocks)) => {
88 let tool_results: Vec<&ContentBlock> = blocks
90 .iter()
91 .filter(|b| matches!(b, ContentBlock::ToolResult { .. }))
92 .collect();
93
94 if !tool_results.is_empty() {
95 for block in blocks {
97 if let ContentBlock::ToolResult { tool_use_id, content } = block {
98 result.push(json!({
99 "role": "tool",
100 "tool_call_id": tool_use_id,
101 "content": content,
102 }));
103 }
104 }
105 } else {
106 let text: String = blocks
108 .iter()
109 .filter_map(|b| match b {
110 ContentBlock::Text { text } => Some(text.as_str()),
111 _ => None,
112 })
113 .collect::<Vec<_>>()
114 .join("\n");
115 result.push(json!({"role": "user", "content": text}));
116 }
117 }
118 _ => {}
119 }
120 }
121
122 result
123 }
124
125 fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
126 tools
127 .iter()
128 .map(|t| {
129 json!({
130 "type": "function",
131 "function": {
132 "name": t.name,
133 "description": t.description,
134 "parameters": t.parameters,
135 }
136 })
137 })
138 .collect()
139 }
140}
141
142#[async_trait]
143impl Provider for OpenAIProvider {
144 fn context_size(&self) -> Option<u32> {
145 context_window_for(&self.model)
146 }
147
148 fn clone_box(&self) -> Box<dyn Provider> {
149 Box::new(Self {
150 api_key: self.api_key.clone(),
151 model: self.model.clone(),
152 base_url: self.base_url.clone(),
153 client: reqwest::Client::new(),
154 })
155 }
156
157 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
158 let messages = self.convert_messages(&request.messages, request.system.as_deref());
159
160 let mut body = json!({
161 "model": self.model,
162 "messages": messages,
163 "max_completion_tokens": request.max_tokens,
164 });
165
166 if !request.tools.is_empty() {
167 body["tools"] = json!(self.convert_tools(&request.tools));
168 }
169
170 let url = format!("{}/chat/completions", self.base_url);
171 let response = self
172 .client
173 .post(&url)
174 .header("Authorization", format!("Bearer {}", self.api_key))
175 .header("Content-Type", "application/json")
176 .json(&body)
177 .send()
178 .await?;
179
180 let status = response.status();
181 let response_body: Value = response.json().await?;
182
183 if !status.is_success() {
184 let err_msg = response_body["error"]["message"]
185 .as_str()
186 .unwrap_or("unknown error");
187 anyhow::bail!("OpenAI API error ({}): {}", status, err_msg);
188 }
189
190 let choice = &response_body["choices"][0];
191 let message = &choice["message"];
192 let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
193
194 let stop_reason = match finish_reason {
195 "tool_calls" => StopReason::ToolUse,
196 "length" => StopReason::MaxTokens,
197 _ => StopReason::EndTurn,
198 };
199
200 let mut content = Vec::new();
201
202 let usage_blob = &response_body["usage"];
203 let usage = Usage {
204 input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
205 output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
206 cache_creation_input_tokens: 0,
207 cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
208 .as_u64()
209 .unwrap_or(0) as u32,
210 };
211
212 if let Some(text) = message["content"].as_str()
213 && !text.is_empty() {
214 content.push(ContentBlock::Text { text: text.to_string() });
215 }
216
217 if let Some(tool_calls) = message["tool_calls"].as_array() {
218 for tc in tool_calls {
219 let id = tc["id"].as_str().unwrap_or_default().to_string();
220 let name = tc["function"]["name"].as_str().unwrap_or_default().to_string();
221 let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
222 let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
223
224 content.push(ContentBlock::ToolUse { id, name, input });
225 }
226
227 if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
228 return Ok(ChatResponse {
229 content,
230 stop_reason: StopReason::ToolUse,
231 usage: usage.clone(),
232 });
233 }
234 }
235
236 Ok(ChatResponse {
237 content,
238 stop_reason,
239 usage,
240 })
241 }
242}
243
/// Best-effort context-window size, in tokens, for `model`.
///
/// A positive integer in the `CONTEXT_SIZE` environment variable overrides the lookup.
/// Otherwise the model name is matched case-insensitively by substring against known
/// families, with earlier rules taking precedence. Unrecognised models fall back to
/// 128k. Always returns `Some`.
fn context_window_for(model: &str) -> Option<u32> {
    if let Some(n) = std::env::var("CONTEXT_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0)
    {
        return Some(n);
    }
    let m = model.to_ascii_lowercase();

    // GPT-4.1 family (gpt-4.1, -mini, -nano). Must precede the generic "gpt-4" rule.
    if m.contains("gpt-4.1") {
        return Some(1_047_576);
    }
    if m.contains("gpt-4o") || m.contains("gpt-4-turbo") || m.contains("gpt-4.5") {
        return Some(128_000);
    }
    if m.contains("gpt-4-32k") {
        return Some(32_768);
    }
    // Dated GPT-4 Turbo / vision previews, e.g. gpt-4-1106-preview.
    if m.contains("gpt-4") && m.contains("preview") {
        return Some(128_000);
    }
    // Base GPT-4 and its dated snapshots (gpt-4-0613, ...), including vendor-prefixed
    // names such as "openai/gpt-4".
    if m.contains("gpt-4") {
        return Some(8_192);
    }
    if m.contains("gpt-3.5-turbo-16k") {
        return Some(16_384);
    }
    if m.contains("gpt-3.5") {
        return Some(4_096);
    }
    if m.contains("o1") {
        return Some(200_000);
    }
    if m.contains("deepseek") {
        return Some(if m.contains("v3") || m.contains("r1") {
            128_000
        } else {
            64_000
        });
    }
    if m.contains("qwen") {
        return Some(if m.contains("qwen-max") || m.contains("qwen2.5-72b") {
            128_000
        } else {
            32_000
        });
    }
    if m.contains("llama-3") || m.contains("llama3") {
        return Some(if m.contains("70b") || m.contains("405b") {
            128_000
        } else {
            8_192
        });
    }
    if m.contains("glm") {
        return Some(128_000);
    }
    Some(128_000)
}