matrixcode_core/providers/
openai.rs1use anyhow::Result;
2use async_trait::async_trait;
3use serde_json::{Value, json};
4
5use crate::models::context_window_for;
6use crate::tools::ToolDefinition;
7
8use super::{
9 ChatRequest, ChatResponse, ContentBlock, Message, MessageContent, Provider, Role, StopReason,
10 Usage,
11};
12
13pub struct OpenAIProvider {
14 api_key: String,
15 model: String,
16 base_url: String,
17 client: reqwest::Client,
18}
19
20impl OpenAIProvider {
21 pub fn new(api_key: String, model: String, base_url: String) -> Self {
22 let client = reqwest::Client::builder()
23 .timeout(std::time::Duration::from_secs(120))
24 .connect_timeout(std::time::Duration::from_secs(10))
25 .build()
26 .unwrap_or_else(|_| reqwest::Client::new());
27 Self {
28 api_key,
29 model,
30 base_url,
31 client,
32 }
33 }
34
35 fn convert_messages(&self, messages: &[Message], system: Option<&str>) -> Vec<Value> {
36 let mut result = Vec::new();
37
38 if let Some(sys) = system {
39 result.push(json!({"role": "system", "content": sys}));
40 }
41
42 for msg in messages {
43 match (&msg.role, &msg.content) {
44 (Role::System, _) => {}
45 (Role::User, MessageContent::Text(text)) => {
46 result.push(json!({"role": "user", "content": text}));
47 }
48 (Role::Assistant, MessageContent::Text(text)) => {
49 result.push(json!({"role": "assistant", "content": text}));
50 }
51 (Role::Assistant, MessageContent::Blocks(blocks)) => {
52 let mut tool_calls = Vec::new();
53 let mut text_parts = Vec::new();
54
55 for block in blocks {
56 match block {
57 ContentBlock::Text { text } => text_parts.push(text.clone()),
58 ContentBlock::ToolUse { id, name, input } => {
59 tool_calls.push(json!({
60 "id": id,
61 "type": "function",
62 "function": {
63 "name": name,
64 "arguments": input.to_string(),
65 }
66 }));
67 }
68 ContentBlock::Thinking { .. } => {}
69 _ => {}
70 }
71 }
72
73 let mut msg_obj = json!({"role": "assistant"});
74 if !text_parts.is_empty() {
75 msg_obj["content"] = json!(text_parts.join("\n"));
76 }
77 if !tool_calls.is_empty() {
78 msg_obj["tool_calls"] = json!(tool_calls);
79 }
80 result.push(msg_obj);
81 }
82 (Role::Tool, MessageContent::Blocks(blocks)) => {
83 self.push_tool_results(blocks, &mut result);
84 }
85 (Role::User, MessageContent::Blocks(blocks)) => {
86 if blocks
88 .iter()
89 .any(|b| matches!(b, ContentBlock::ToolResult { .. }))
90 {
91 self.push_tool_results(blocks, &mut result);
93 } else {
94 let text: String = blocks
96 .iter()
97 .filter_map(|b| match b {
98 ContentBlock::Text { text } => Some(text.as_str()),
99 _ => None,
100 })
101 .collect::<Vec<_>>()
102 .join("\n");
103 result.push(json!({"role": "user", "content": text}));
104 }
105 }
106 _ => {}
107 }
108 }
109
110 result
111 }
112
113 fn push_tool_results(&self, blocks: &[ContentBlock], result: &mut Vec<Value>) {
115 for block in blocks {
116 if let ContentBlock::ToolResult {
117 tool_use_id,
118 content,
119 } = block
120 {
121 result.push(json!({
122 "role": "tool",
123 "tool_call_id": tool_use_id,
124 "content": content,
125 }));
126 }
127 }
128 }
129
130 fn convert_tools(&self, tools: &[ToolDefinition]) -> Vec<Value> {
131 tools
132 .iter()
133 .map(|t| {
134 json!({
135 "type": "function",
136 "function": {
137 "name": t.name,
138 "description": t.description,
139 "parameters": t.parameters,
140 }
141 })
142 })
143 .collect()
144 }
145}
146
147#[async_trait]
148impl Provider for OpenAIProvider {
149 fn context_size(&self) -> Option<u32> {
150 context_window_for(&self.model)
151 }
152
153 fn clone_box(&self) -> Box<dyn Provider> {
154 Box::new(Self {
155 api_key: self.api_key.clone(),
156 model: self.model.clone(),
157 base_url: self.base_url.clone(),
158 client: reqwest::Client::new(),
159 })
160 }
161
162 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
163 let messages = self.convert_messages(&request.messages, request.system.as_deref());
164
165 let mut body = json!({
166 "model": self.model,
167 "messages": messages,
168 "max_completion_tokens": request.max_tokens,
169 });
170
171 if !request.tools.is_empty() {
172 body["tools"] = json!(self.convert_tools(&request.tools));
173 }
174
175 let url = format!("{}/chat/completions", self.base_url);
176 let response = self
177 .client
178 .post(&url)
179 .header("Authorization", format!("Bearer {}", self.api_key))
180 .header("Content-Type", "application/json")
181 .json(&body)
182 .send()
183 .await?;
184
185 let status = response.status();
186 let response_body: Value = response.json().await?;
187
188 if !status.is_success() {
189 let err_msg = response_body["error"]["message"]
190 .as_str()
191 .unwrap_or("unknown error");
192 anyhow::bail!("OpenAI API error ({}): {}", status, err_msg);
193 }
194
195 let choice = &response_body["choices"][0];
196 let message = &choice["message"];
197 let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop");
198
199 let stop_reason = match finish_reason {
200 "tool_calls" => StopReason::ToolUse,
201 "length" => StopReason::MaxTokens,
202 _ => StopReason::EndTurn,
203 };
204
205 let mut content = Vec::new();
206
207 let usage_blob = &response_body["usage"];
208 let usage = Usage {
209 input_tokens: usage_blob["prompt_tokens"].as_u64().unwrap_or(0) as u32,
210 output_tokens: usage_blob["completion_tokens"].as_u64().unwrap_or(0) as u32,
211 cache_creation_input_tokens: 0,
212 cache_read_input_tokens: usage_blob["prompt_tokens_details"]["cached_tokens"]
213 .as_u64()
214 .unwrap_or(0) as u32,
215 };
216
217 if let Some(text) = message["content"].as_str()
218 && !text.is_empty()
219 {
220 content.push(ContentBlock::Text {
221 text: text.to_string(),
222 });
223 }
224
225 if let Some(tool_calls) = message["tool_calls"].as_array() {
226 for tc in tool_calls {
227 let id = tc["id"].as_str().unwrap_or_default().to_string();
228 let name = tc["function"]["name"]
229 .as_str()
230 .unwrap_or_default()
231 .to_string();
232 let arguments = tc["function"]["arguments"].as_str().unwrap_or("{}");
233 let input: Value = serde_json::from_str(arguments).unwrap_or(json!({}));
234
235 content.push(ContentBlock::ToolUse { id, name, input });
236 }
237
238 if stop_reason == StopReason::EndTurn && !tool_calls.is_empty() {
239 return Ok(ChatResponse {
240 content,
241 stop_reason: StopReason::ToolUse,
242 usage: usage.clone(),
243 });
244 }
245 }
246
247 Ok(ChatResponse {
248 content,
249 stop_reason,
250 usage,
251 })
252 }
253}