1use async_trait::async_trait;
8use futures::StreamExt;
9use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
10use tokio::sync::mpsc;
11use tracing::debug;
12
13use super::message::{ContentBlock, Message, StopReason, Usage};
14use super::provider::{Provider, ProviderError, ProviderRequest};
15use super::stream::StreamEvent;
16
17pub struct OpenAiProvider {
18 http: reqwest::Client,
19 base_url: String,
20 api_key: String,
21}
22
23impl OpenAiProvider {
24 pub fn new(base_url: &str, api_key: &str) -> Self {
25 let http = reqwest::Client::builder()
26 .timeout(std::time::Duration::from_secs(300))
27 .build()
28 .expect("failed to build HTTP client");
29
30 Self {
31 http,
32 base_url: base_url.trim_end_matches('/').to_string(),
33 api_key: api_key.to_string(),
34 }
35 }
36
37 fn build_body(&self, request: &ProviderRequest) -> serde_json::Value {
39 let mut messages = Vec::new();
42
43 if !request.system_prompt.is_empty() {
45 messages.push(serde_json::json!({
46 "role": "system",
47 "content": request.system_prompt,
48 }));
49 }
50
51 for msg in &request.messages {
53 match msg {
54 Message::User(u) => {
55 let content = blocks_to_openai_content(&u.content);
56 messages.push(serde_json::json!({
57 "role": "user",
58 "content": content,
59 }));
60 }
61 Message::Assistant(a) => {
62 let mut msg_json = serde_json::json!({
63 "role": "assistant",
64 });
65
66 let tool_calls: Vec<serde_json::Value> = a
68 .content
69 .iter()
70 .filter_map(|b| match b {
71 ContentBlock::ToolUse { id, name, input } => Some(serde_json::json!({
72 "id": id,
73 "type": "function",
74 "function": {
75 "name": name,
76 "arguments": serde_json::to_string(input).unwrap_or_default(),
77 }
78 })),
79 _ => None,
80 })
81 .collect();
82
83 let text: String = a
85 .content
86 .iter()
87 .filter_map(|b| match b {
88 ContentBlock::Text { text } => Some(text.as_str()),
89 _ => None,
90 })
91 .collect::<Vec<_>>()
92 .join("");
93
94 msg_json["content"] = serde_json::Value::String(text);
96 if !tool_calls.is_empty() {
97 msg_json["tool_calls"] = serde_json::Value::Array(tool_calls);
98 }
99
100 messages.push(msg_json);
101 }
102 Message::System(_) => {} }
104 }
105
106 let mut final_messages = Vec::new();
109 for msg in messages {
110 if msg.get("role").and_then(|r| r.as_str()) == Some("user") {
111 if let Some(content) = msg.get("content")
113 && let Some(arr) = content.as_array()
114 {
115 let mut tool_results = Vec::new();
116 let mut other_content = Vec::new();
117
118 for block in arr {
119 if block.get("type").and_then(|t| t.as_str()) == Some("tool_result") {
120 tool_results.push(serde_json::json!({
121 "role": "tool",
122 "tool_call_id": block.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or(""),
123 "content": block.get("content").and_then(|v| v.as_str()).unwrap_or(""),
124 }));
125 } else {
126 other_content.push(block.clone());
127 }
128 }
129
130 if !tool_results.is_empty() {
131 for tr in tool_results {
133 final_messages.push(tr);
134 }
135 if !other_content.is_empty() {
136 let mut m = msg.clone();
137 m["content"] = serde_json::Value::Array(other_content);
138 final_messages.push(m);
139 }
140 continue;
141 }
142 }
143 }
144 final_messages.push(msg);
145 }
146
147 let tools: Vec<serde_json::Value> = request
149 .tools
150 .iter()
151 .map(|t| {
152 serde_json::json!({
153 "type": "function",
154 "function": {
155 "name": t.name,
156 "description": t.description,
157 "parameters": t.input_schema,
158 }
159 })
160 })
161 .collect();
162
163 let model_lower = request.model.to_lowercase();
165 let uses_new_token_param = model_lower.starts_with("o1")
166 || model_lower.starts_with("o3")
167 || model_lower.contains("gpt-5")
168 || model_lower.contains("gpt-4.1");
169
170 let mut body = serde_json::json!({
171 "model": request.model,
172 "messages": final_messages,
173 "stream": true,
174 "stream_options": { "include_usage": true },
175 });
176
177 if uses_new_token_param {
178 body["max_completion_tokens"] = serde_json::json!(request.max_tokens);
179 } else {
180 body["max_tokens"] = serde_json::json!(request.max_tokens);
181 }
182
183 if !tools.is_empty() {
184 body["tools"] = serde_json::Value::Array(tools);
185
186 use super::provider::ToolChoice;
188 match &request.tool_choice {
189 ToolChoice::Auto => {
190 body["tool_choice"] = serde_json::json!("auto");
191 }
192 ToolChoice::Any => {
193 body["tool_choice"] = serde_json::json!("required");
194 }
195 ToolChoice::None => {
196 body["tool_choice"] = serde_json::json!("none");
197 }
198 ToolChoice::Specific(name) => {
199 body["tool_choice"] = serde_json::json!({
200 "type": "function",
201 "function": { "name": name }
202 });
203 }
204 }
205 }
206 if let Some(temp) = request.temperature {
207 body["temperature"] = serde_json::json!(temp);
208 }
209
210 body
211 }
212}
213
214#[async_trait]
215impl Provider for OpenAiProvider {
216 fn name(&self) -> &str {
217 "openai"
218 }
219
220 async fn stream(
221 &self,
222 request: &ProviderRequest,
223 ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
224 let url = format!("{}/chat/completions", self.base_url);
225 let body = self.build_body(request);
226
227 let mut headers = HeaderMap::new();
228 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
229 headers.insert(
230 AUTHORIZATION,
231 HeaderValue::from_str(&format!("Bearer {}", self.api_key))
232 .map_err(|e| ProviderError::Auth(e.to_string()))?,
233 );
234
235 debug!("OpenAI request to {url}");
236
237 let response = self
238 .http
239 .post(&url)
240 .headers(headers)
241 .json(&body)
242 .send()
243 .await
244 .map_err(|e| ProviderError::Network(e.to_string()))?;
245
246 let status = response.status();
247 if !status.is_success() {
248 let body_text = response.text().await.unwrap_or_default();
249 return match status.as_u16() {
250 401 | 403 => Err(ProviderError::Auth(body_text)),
251 429 => Err(ProviderError::RateLimited {
252 retry_after_ms: 1000,
253 }),
254 529 => Err(ProviderError::Overloaded),
255 413 => Err(ProviderError::RequestTooLarge(body_text)),
256 _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
257 };
258 }
259
260 let (tx, rx) = mpsc::channel(64);
262 tokio::spawn(async move {
263 let mut byte_stream = response.bytes_stream();
264 let mut buffer = String::new();
265 let mut current_tool_id = String::new();
266 let mut current_tool_name = String::new();
267 let mut current_tool_args = String::new();
268 let mut usage = Usage::default();
269 let mut stop_reason: Option<StopReason> = None;
270
271 while let Some(chunk_result) = byte_stream.next().await {
272 let chunk = match chunk_result {
273 Ok(c) => c,
274 Err(e) => {
275 let _ = tx.send(StreamEvent::Error(e.to_string())).await;
276 break;
277 }
278 };
279
280 buffer.push_str(&String::from_utf8_lossy(&chunk));
281
282 while let Some(pos) = buffer.find("\n\n") {
283 let event_text = buffer[..pos].to_string();
284 buffer = buffer[pos + 2..].to_string();
285
286 for line in event_text.lines() {
287 let data = if let Some(d) = line.strip_prefix("data: ") {
288 d
289 } else {
290 continue;
291 };
292
293 if data == "[DONE]" {
294 if !current_tool_id.is_empty() {
296 let input: serde_json::Value =
297 serde_json::from_str(¤t_tool_args).unwrap_or_default();
298 let _ = tx
299 .send(StreamEvent::ContentBlockComplete(
300 ContentBlock::ToolUse {
301 id: current_tool_id.clone(),
302 name: current_tool_name.clone(),
303 input,
304 },
305 ))
306 .await;
307 current_tool_id.clear();
308 current_tool_name.clear();
309 current_tool_args.clear();
310 }
311
312 let _ = tx
313 .send(StreamEvent::Done {
314 usage: usage.clone(),
315 stop_reason: stop_reason.clone().or(Some(StopReason::EndTurn)),
316 })
317 .await;
318 return;
319 }
320
321 let parsed: serde_json::Value = match serde_json::from_str(data) {
322 Ok(v) => v,
323 Err(_) => continue,
324 };
325
326 let delta = match parsed
328 .get("choices")
329 .and_then(|c| c.get(0))
330 .and_then(|c| c.get("delta"))
331 {
332 Some(d) => d,
333 None => {
334 if let Some(u) = parsed.get("usage") {
336 usage.input_tokens = u
337 .get("prompt_tokens")
338 .and_then(|v| v.as_u64())
339 .unwrap_or(0);
340 usage.output_tokens = u
341 .get("completion_tokens")
342 .and_then(|v| v.as_u64())
343 .unwrap_or(0);
344 }
345 continue;
346 }
347 };
348
349 if let Some(content) = delta.get("content").and_then(|c| c.as_str())
351 && !content.is_empty()
352 {
353 debug!("OpenAI text delta: {}", &content[..content.len().min(80)]);
354 let _ = tx.send(StreamEvent::TextDelta(content.to_string())).await;
355 }
356
357 if let Some(finish) = parsed
359 .get("choices")
360 .and_then(|c| c.get(0))
361 .and_then(|c| c.get("finish_reason"))
362 .and_then(|f| f.as_str())
363 {
364 debug!("OpenAI finish_reason: {finish}");
365 match finish {
366 "stop" => {
367 stop_reason = Some(StopReason::EndTurn);
368 }
369 "tool_calls" => {
370 stop_reason = Some(StopReason::ToolUse);
371 }
372 "length" => {
373 stop_reason = Some(StopReason::MaxTokens);
374 }
375 _ => {}
376 }
377 }
378
379 if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array())
381 {
382 for tc in tool_calls {
383 if let Some(func) = tc.get("function") {
384 if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
385 if !current_tool_id.is_empty()
387 && !current_tool_args.is_empty()
388 {
389 let input: serde_json::Value =
391 serde_json::from_str(¤t_tool_args)
392 .unwrap_or_default();
393 let _ = tx
394 .send(StreamEvent::ContentBlockComplete(
395 ContentBlock::ToolUse {
396 id: current_tool_id.clone(),
397 name: current_tool_name.clone(),
398 input,
399 },
400 ))
401 .await;
402 }
403 current_tool_id = tc
404 .get("id")
405 .and_then(|i| i.as_str())
406 .unwrap_or("")
407 .to_string();
408 current_tool_name = name.to_string();
409 current_tool_args.clear();
410 }
411 if let Some(args) =
412 func.get("arguments").and_then(|a| a.as_str())
413 {
414 current_tool_args.push_str(args);
415 }
416 }
417 }
418 }
419 }
420 }
421 }
422
423 if !current_tool_id.is_empty() {
425 let input: serde_json::Value =
426 serde_json::from_str(¤t_tool_args).unwrap_or_default();
427 let _ = tx
428 .send(StreamEvent::ContentBlockComplete(ContentBlock::ToolUse {
429 id: current_tool_id,
430 name: current_tool_name,
431 input,
432 }))
433 .await;
434 }
435
436 let _ = tx
437 .send(StreamEvent::Done {
438 usage,
439 stop_reason: Some(StopReason::EndTurn),
440 })
441 .await;
442 });
443
444 Ok(rx)
445 }
446}
447
448fn blocks_to_openai_content(blocks: &[ContentBlock]) -> serde_json::Value {
450 if blocks.len() == 1
451 && let ContentBlock::Text { text } = &blocks[0]
452 {
453 return serde_json::Value::String(text.clone());
454 }
455
456 let parts: Vec<serde_json::Value> = blocks
457 .iter()
458 .map(|b| match b {
459 ContentBlock::Text { text } => serde_json::json!({
460 "type": "text",
461 "text": text,
462 }),
463 ContentBlock::Image { media_type, data } => serde_json::json!({
464 "type": "image_url",
465 "image_url": {
466 "url": format!("data:{media_type};base64,{data}"),
467 }
468 }),
469 ContentBlock::ToolResult {
470 tool_use_id,
471 content,
472 is_error,
473 ..
474 } => serde_json::json!({
475 "type": "tool_result",
476 "tool_use_id": tool_use_id,
477 "content": content,
478 "is_error": is_error,
479 }),
480 ContentBlock::Thinking { thinking, .. } => serde_json::json!({
481 "type": "text",
482 "text": thinking,
483 }),
484 ContentBlock::ToolUse { name, input, .. } => serde_json::json!({
485 "type": "text",
486 "text": format!("[Tool call: {name}({input})]"),
487 }),
488 ContentBlock::Document { title, .. } => serde_json::json!({
489 "type": "text",
490 "text": format!("[Document: {}]", title.as_deref().unwrap_or("untitled")),
491 }),
492 })
493 .collect();
494
495 serde_json::Value::Array(parts)
496}