matrixcode_core/providers/
openai.rs1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4
5use crate::tools::ToolDefinition;
6
7use super::{
8 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
9 Usage,
10};
11
12pub struct OpenAIProvider {
13 api_key: String,
14 model: String,
15 base_url: String,
16 client: reqwest::Client,
17}
18
19impl OpenAIProvider {
20 pub fn new(api_key: String, model: String, base_url: String) -> Self {
21 Self {
22 api_key,
23 model,
24 base_url,
25 client: reqwest::Client::new(),
26 }
27 }
28
29 fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
30 let mut result = Vec::new();
31
32 if let Some(sys) = system {
33 result.push(json!({"role": "system", "content": sys}));
34 }
35
36 for msg in messages {
37 match (&msg.role, &msg.content) {
38 (Role::System, _) => {}
39 (Role::User, MessageContent::Text(text)) => {
40 result.push(json!({"role": "user", "content": text}));
41 }
42 (Role::Assistant, MessageContent::Text(text)) => {
43 result.push(json!({"role": "assistant", "content": text}));
44 }
45 (Role::Assistant, MessageContent::Blocks(blocks)) => {
46 let mut tool_calls = Vec::new();
47 let mut text_parts = Vec::new();
48
49 for block in blocks {
50 match block {
51 ContentBlock::Text { text } => text_parts.push(text.clone()),
52 ContentBlock::ToolUse { id, name, input } => {
53 tool_calls.push(json!({
54 "id": id,
55 "type": "function",
56 "function": {
57 "name": name,
58 "arguments": input.to_string(),
59 }
60 }));
61 }
62 ContentBlock::Thinking { .. } => {}
63 _ => {}
64 }
65 }
66
67 let mut msg_obj = json!({"role": "assistant"});
68 if !text_parts.is_empty() {
69 msg_obj["content"] = json!(text_parts.join("\n"));
70 }
71 if !tool_calls.is_empty() {
72 msg_obj["tool_calls"] = json!(tool_calls);
73 }
74 result.push(msg_obj);
75 }
76 (Role::Tool, MessageContent::Blocks(blocks)) => {
77 for block in blocks {
78 if let ContentBlock::ToolResult { tool_use_id, content } = block {
79 result.push(json!({
80 "role": "tool",
81 "tool_call_id": tool_use_id,
82 "content": content,
83 }));
84 }
85 }
86 }
87 (Role::User, MessageContent::Blocks(blocks)) => {
88 let text: String = blocks
89 .iter()
90 .filter_map(|b| match b {
91 ContentBlock::Text { text } => Some(text.as_str()),
92 _ => None,
93 })
94 .collect::<Vec<_>>()
95 .join("\n");
96 result.push(json!({"role": "user", "content": text}));
97 }
98 _ => {}
99 }
100 }
101
102 result
103 }
104
105 fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
106 tools
107 .iter()
108 .map(|t| {
109 json!({
110 "type": "function",
111 "function": {
112 "name": t.name,
113 "description": t.description,
114 "parameters": t.parameters,
115 }
116 })
117 })
118 .collect()
119 }
120}
121
122#[async_trait]
123impl Provider for OpenAIProvider {
124 fn context_size(&self) -> Option<u32> {
125 context_window_for(&self.model)
126 }
127
128 fn clone_box(&self) -> Box<dyn Provider> {
129 Box::new(Self {
130 api_key: self.api_key.clone(),
131 model: self.model.clone(),
132 base_url: self.base_url.clone(),
133 client: reqwest::Client::new(),
134 })
135 }
136
137 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
138 let messages = self.convert_messages(&request.messages, request.system.as_deref());
139
140 let mut body = json!({
141 "model": self.model,
142 "messages": messages,
143 "max_completion_tokens": request.max_tokens,
144 });
145
146 if !request.tools.is_empty() {
147 body["tools"] = json!(self.convert_tools(&request.tools));
148 }
149
150 let url = format!("{}/chat/completions", self.base_url);
151 let response = self
152 .client
153 .post(&url)
154 .header("Authorization", format!("Bearer {}", self.api_key))
155 .header("Content-Type", "application/json")
156 .json(&body)
157 .send()
158 .await?;
159
160 let status = response.status();
161 let response_body: Value = response.json().await?;
162
163 if !status.is_success() {
164 let err_msg = response_body["error"]["message"]
165 .as_str()
166 .unwrap_or("unknown error");
167 anyhow::bail!("OpenAI API error ({}): {}", status, err_msg);
168 }
169
170 let choice = &response_body["choices"][0];
171 let message = &choice["message"];
172 let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
173
174 let stop_reason = match finish_reason {
175 "tool_calls" => StopReason::ToolUse,
176 "length" => StopReason::MaxTokens,
177 _ => StopReason::EndTurn,
178 };
179
180 let mut content = Vec::new();
181
182 let usage_blob = &response_body["usage"];
183 let usage = Usage {
184 input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
185 output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
186 cache_creation_input_tokens: 0,
187 cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
188 .as_u64()
189 .unwrap_or(0) as u32,
190 };
191
192 if let Some(text) = message["content"].as_str()
193 && !text.is_empty() {
194 content.push(ContentBlock::Text { text: text.to_string() });
195 }
196
197 if let Some(tool_calls) = message["tool_calls"].as_array() {
198 for tc in tool_calls {
199 let id = tc["id"].as_str().unwrap_or_default().to_string();
200 let name = tc["function"]["name"].as_str().unwrap_or_default().to_string();
201 let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
202 let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
203
204 content.push(ContentBlock::ToolUse { id, name, input });
205 }
206
207 if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
208 return Ok(ChatResponse {
209 content,
210 stop_reason: StopReason::ToolUse,
211 usage: usage.clone(),
212 });
213 }
214 }
215
216 Ok(ChatResponse {
217 content,
218 stop_reason,
219 usage,
220 })
221 }
222}
223
/// Best-effort context-window size, in tokens, for `model`.
///
/// A positive integer in the `CONTEXT_SIZE` environment variable overrides the lookup.
/// Otherwise the lower-cased model name is matched against known model families,
/// tolerating provider prefixes such as `openai/` or `meta-llama/`. Unrecognised models
/// fall back to 128K. Always returns `Some`.
fn context_window_for(model: &str) -> Option<u32> {
    let env_override = std::env::var("CONTEXT_SIZE")
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&n| n > 0);
    if let Some(n) = env_override {
        return Some(n);
    }

    let m = model.to_ascii_lowercase();
    // Last path segment, so e.g. "openai/o3-mini" is matched by its own name.
    let base = m.rsplit('/').next().unwrap_or(m.as_str());

    // OpenAI GPT family, most specific names first.
    if m.contains("gpt-4.1") {
        return Some(1_047_576);
    }
    if m.contains("gpt-4o") || m.contains("gpt-4-turbo") || m.contains("gpt-4.5") {
        return Some(128_000);
    }
    if m.contains("gpt-4-32k") {
        return Some(32_768);
    }
    if m.contains("gpt-4") {
        // Remaining turbo/preview GPT-4 snapshots (e.g. gpt-4-1106-preview) are 128K;
        // the original gpt-4 snapshots are 8K. (A previous `!contains("o")` test also
        // excluded prefixed names such as "openai/gpt-4".)
        let long = m.contains("turbo") || m.contains("preview");
        return Some(if long { 128_000 } else { 8_192 });
    }
    if m.contains("gpt-3.5-turbo-16k") {
        return Some(16_384);
    }
    if m.contains("gpt-3.5") {
        return Some(4_096);
    }

    // OpenAI o-series reasoning models.
    if m.contains("o1") || base.starts_with("o3") || base.starts_with("o4") {
        return Some(200_000);
    }

    if m.contains("deepseek") {
        if m.contains("v3") || m.contains("r1") {
            return Some(128_000);
        }
        return Some(64_000);
    }

    if m.contains("qwen") {
        let long = m.contains("qwen-max") || m.contains("qwen2.5-72b");
        return Some(if long { 128_000 } else { 32_000 });
    }

    if m.contains("llama-3") || m.contains("llama3") {
        // Llama 3.1 and later (including the 3.1-only 405B) have a 128K window; the
        // original Llama 3 release is 8K at every size (cf. Groq's "llama3-70b-8192").
        const LONG_CONTEXT: [&str; 6] = [
            "llama-3.1", "llama3.1", "llama-3.2", "llama3.2", "llama-3.3", "llama3.3",
        ];
        let long = LONG_CONTEXT.iter().any(|p| m.contains(*p)) || m.contains("405b");
        return Some(if long { 128_000 } else { 8_192 });
    }

    if m.contains("glm") {
        return Some(128_000);
    }

    Some(128_000)
}