matrixcode_core/providers/
openai.rs1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4
5use crate::models::context_window_for;
6use crate::tools::ToolDefinition;
7
8use super::{
9 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
10 Usage,
11};
12
13pub struct OpenAIProvider {
14 api_key: String,
15 model: String,
16 base_url: String,
17 client: reqwest::Client,
18 extra_headers: Vec<(String, String)>,
20}
21
22impl OpenAIProvider {
23 pub fn new(api_key: String, model: String, base_url: String) -> Self {
24 Self::with_headers(api_key, model, base_url, None)
25 }
26
27 pub fn with_headers(
28 api_key: String,
29 model: String,
30 base_url: String,
31 extra_headers: Option<std::collections::HashMap<String, String>>,
32 ) -> Self {
33 let client = reqwest::Client::builder()
34 .timeout(std::time::Duration::from_secs(120))
35 .connect_timeout(std::time::Duration::from_secs(10))
36 .build()
37 .unwrap_or_else(|_| reqwest::Client::new());
38 let extra_headers: Vec<(String, String)> = extra_headers
39 .map(|h| h.into_iter().collect())
40 .unwrap_or_default();
41 Self {
42 api_key,
43 model,
44 base_url,
45 client,
46 extra_headers,
47 }
48 }
49
50 fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
51 let mut result = Vec::new();
52
53 if let Some(sys) = system {
54 result.push(json!({"role": "system", "content": sys}));
55 }
56
57 for msg in messages {
58 match (&msg.role, &msg.content) {
59 (Role::System, _) => {}
60 (Role::User, MessageContent::Text(text)) => {
61 result.push(json!({"role": "user", "content": text}));
62 }
63 (Role::Assistant, MessageContent::Text(text)) => {
64 result.push(json!({"role": "assistant", "content": text}));
65 }
66 (Role::Assistant, MessageContent::Blocks(blocks)) => {
67 let mut tool_calls = Vec::new();
68 let mut text_parts = Vec::new();
69
70 for block in blocks {
71 match block {
72 ContentBlock::Text { text } => text_parts.push(text.clone()),
73 ContentBlock::ToolUse { id, name, input } => {
74 tool_calls.push(json!({
75 "id": id,
76 "type": "function",
77 "function": {
78 "name": name,
79 "arguments": input.to_string(),
80 }
81 }));
82 }
83 ContentBlock::Thinking { .. } => {}
84 _ => {}
85 }
86 }
87
88 let mut msg_obj = json!({"role": "assistant"});
89 if !text_parts.is_empty() {
90 msg_obj["content"] = json!(text_parts.join("\n"));
91 }
92 if !tool_calls.is_empty() {
93 msg_obj["tool_calls"] = json!(tool_calls);
94 }
95 result.push(msg_obj);
96 }
97 (Role::Tool, MessageContent::Blocks(blocks)) => {
98 self.push_tool_results(blocks, &mut result);
99 }
100 (Role::User, MessageContent::Blocks(blocks)) => {
101 if blocks
103 .iter()
104 .any(|b| matches!(b, ContentBlock::ToolResult { .. }))
105 {
106 self.push_tool_results(blocks, &mut result);
108 } else {
109 let text: String = blocks
111 .iter()
112 .filter_map(|b| match b {
113 ContentBlock::Text { text } => Some(text.as_str()),
114 _ => None,
115 })
116 .collect::<Vec<_>>()
117 .join("\n");
118 result.push(json!({"role": "user", "content": text}));
119 }
120 }
121 _ => {}
122 }
123 }
124
125 result
126 }
127
128 fn push_tool_results(&self, blocks: &[ContentBlock], result: &mut Vec<Value>) {
130 for block in blocks {
131 if let ContentBlock::ToolResult {
132 tool_use_id,
133 content,
134 } = block
135 {
136 result.push(json!({
137 "role": "tool",
138 "tool_call_id": tool_use_id,
139 "content": content,
140 }));
141 }
142 }
143 }
144
145 fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
146 tools
147 .iter()
148 .map(|t| {
149 json!({
150 "type": "function",
151 "function": {
152 "name": t.name,
153 "description": t.description,
154 "parameters": t.parameters,
155 }
156 })
157 })
158 .collect()
159 }
160}
161
162#[async_trait]
163impl Provider for OpenAIProvider {
164 fn context_size(&self) -> Option<u32> {
165 context_window_for(&self.model)
166 }
167
168 fn clone_box(&self) -> Box<dyn Provider> {
169 Box::new(Self {
170 api_key: self.api_key.clone(),
171 model: self.model.clone(),
172 base_url: self.base_url.clone(),
173 client: reqwest::Client::new(),
174 extra_headers: self.extra_headers.clone(),
175 })
176 }
177
178 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
179 let messages = self.convert_messages(&request.messages, request.system.as_deref());
180
181 let mut body = json!({
182 "model": self.model,
183 "messages": messages,
184 "max_completion_tokens": request.max_tokens,
185 });
186
187 if !request.tools.is_empty() {
188 body["tools"] = json!(self.convert_tools(&request.tools));
189 }
190
191 let url = format!("{}/chat/completions", self.base_url);
192 let mut req = self
193 .client
194 .post(&url)
195 .header("Authorization", format!("Bearer {}", self.api_key))
196 .header("Content-Type", "application/json")
197 .json(&body);
198
199 for (name, value) in &self.extra_headers {
201 req = req.header(name, value);
202 }
203
204 let response = req.send().await?;
205
206 let status = response.status();
207 let response_body: Value = response.json().await?;
208
209 if !status.is_success() {
210 let err_msg = response_body["error"]["message"]
211 .as_str()
212 .unwrap_or("unknown error");
213 anyhow::bail!("OpenAI API error ({}): {}", status, err_msg);
214 }
215
216 let choice = &response_body["choices"][0];
217 let message = &choice["message"];
218 let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
219
220 let stop_reason = match finish_reason {
221 "tool_calls" => StopReason::ToolUse,
222 "length" => StopReason::MaxTokens,
223 _ => StopReason::EndTurn,
224 };
225
226 let mut content = Vec::new();
227
228 let usage_blob = &response_body["usage"];
229 let usage = Usage {
230 input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
231 output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
232 cache_creation_input_tokens: 0,
233 cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
234 .as_u64()
235 .unwrap_or(0) as u32,
236 };
237
238 if let Some(text) = message["content"].as_str()
239 && !text.is_empty()
240 {
241 content.push(ContentBlock::Text {
242 text: text.to_string(),
243 });
244 }
245
246 if let Some(tool_calls) = message["tool_calls"].as_array() {
247 for tc in tool_calls {
248 let id = tc["id"].as_str().unwrap_or_default().to_string();
249 let name = tc["function"]["name"]
250 .as_str()
251 .unwrap_or_default()
252 .to_string();
253 let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
254 let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
255
256 content.push(ContentBlock::ToolUse { id, name, input });
257 }
258
259 if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
260 return Ok(ChatResponse {
261 content,
262 stop_reason: StopReason::ToolUse,
263 usage: usage.clone(),
264 });
265 }
266 }
267
268 Ok(ChatResponse {
269 content,
270 stop_reason,
271 usage,
272 })
273 }
274}