1use async_trait::async_trait;
8use futures::StreamExt;
9use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
10use tokio::sync::mpsc;
11use tracing::debug;
12
13use super::message::{ContentBlock, Message, StopReason, Usage};
14use super::provider::{Provider, ProviderError, ProviderRequest};
15use super::stream::StreamEvent;
16
/// Provider backed by the OpenAI Chat Completions API (or any
/// OpenAI-compatible endpoint selected via `base_url`).
pub struct OpenAiProvider {
    // Reusable HTTP client; configured with a whole-request timeout in `new`.
    http: reqwest::Client,
    // API root without a trailing slash; `/chat/completions` is appended per request.
    base_url: String,
    // Secret used for the `Authorization: Bearer …` header.
    api_key: String,
}
23
impl OpenAiProvider {
    /// Creates a provider for the endpoint at `base_url`, authenticated with
    /// `api_key`. Trailing slashes are stripped from `base_url` so URL
    /// joining in `stream` stays well-formed.
    ///
    /// # Panics
    /// Panics if the `reqwest` client cannot be built (e.g. TLS backend
    /// initialization failure) — treated as an unrecoverable startup bug.
    pub fn new(base_url: &str, api_key: &str) -> Self {
        // Generous whole-request timeout: streamed completions can run long,
        // and reqwest's timeout covers the entire exchange, not per-chunk.
        let http = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .build()
            .expect("failed to build HTTP client");

        Self {
            http,
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key: api_key.to_string(),
        }
    }

    /// Translates a provider-agnostic `ProviderRequest` into an OpenAI
    /// `/chat/completions` JSON body with streaming enabled.
    ///
    /// Three translations happen here:
    /// 1. Internal `Message` values become OpenAI chat messages; assistant
    ///    `ToolUse` blocks become `tool_calls` entries.
    /// 2. `tool_result` blocks embedded in user content are split out into
    ///    standalone `role: "tool"` messages (OpenAI's required shape).
    /// 3. Tool definitions and tool choice are mapped to the OpenAI
    ///    function-tool schema.
    fn build_body(&self, request: &ProviderRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        // System prompt, when present, leads the conversation.
        if !request.system_prompt.is_empty() {
            messages.push(serde_json::json!({
                "role": "system",
                "content": request.system_prompt,
            }));
        }

        for msg in &request.messages {
            match msg {
                Message::User(u) => {
                    let content = blocks_to_openai_content(&u.content);
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": content,
                    }));
                }
                Message::Assistant(a) => {
                    let mut msg_json = serde_json::json!({
                        "role": "assistant",
                    });

                    // ToolUse blocks -> OpenAI `tool_calls`. Note the API wants
                    // `arguments` as a JSON-encoded *string*, not an object.
                    let tool_calls: Vec<serde_json::Value> = a
                        .content
                        .iter()
                        .filter_map(|b| match b {
                            ContentBlock::ToolUse { id, name, input } => Some(serde_json::json!({
                                "id": id,
                                "type": "function",
                                "function": {
                                    "name": name,
                                    "arguments": serde_json::to_string(input).unwrap_or_default(),
                                }
                            })),
                            _ => None,
                        })
                        .collect();

                    // All assistant text blocks collapse to one string.
                    let text: String = a
                        .content
                        .iter()
                        .filter_map(|b| match b {
                            ContentBlock::Text { text } => Some(text.as_str()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("");

                    msg_json["content"] = serde_json::Value::String(text);
                    if !tool_calls.is_empty() {
                        msg_json["tool_calls"] = serde_json::Value::Array(tool_calls);
                    }

                    messages.push(msg_json);
                }
                // System content is carried by `system_prompt` above, so any
                // System message in the history is intentionally dropped.
                Message::System(_) => {}
            }
        }

        // Second pass: lift `tool_result` blocks out of user messages into
        // standalone `role: "tool"` messages. They are emitted *before* any
        // remaining user content so they directly follow the assistant
        // message whose tool_calls they answer.
        let mut final_messages = Vec::new();
        for msg in messages {
            if msg.get("role").and_then(|r| r.as_str()) == Some("user") {
                // Only array-form content can carry tool_result blocks; a
                // plain string passes straight through below.
                if let Some(content) = msg.get("content")
                    && let Some(arr) = content.as_array()
                {
                    let mut tool_results = Vec::new();
                    let mut other_content = Vec::new();

                    for block in arr {
                        if block.get("type").and_then(|t| t.as_str()) == Some("tool_result") {
                            tool_results.push(serde_json::json!({
                                "role": "tool",
                                "tool_call_id": block.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or(""),
                                "content": block.get("content").and_then(|v| v.as_str()).unwrap_or(""),
                            }));
                        } else {
                            other_content.push(block.clone());
                        }
                    }

                    if !tool_results.is_empty() {
                        for tr in tool_results {
                            final_messages.push(tr);
                        }
                        // Re-emit the user message only if anything besides
                        // tool results remained.
                        if !other_content.is_empty() {
                            let mut m = msg.clone();
                            m["content"] = serde_json::Value::Array(other_content);
                            final_messages.push(m);
                        }
                        continue;
                    }
                }
            }
            final_messages.push(msg);
        }

        // Internal tool definitions -> OpenAI function-tool schema.
        let tools: Vec<serde_json::Value> = request
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.input_schema,
                    }
                })
            })
            .collect();

        // Newer model families replaced `max_tokens` with
        // `max_completion_tokens`; choose the parameter from the model name.
        let model_lower = request.model.to_lowercase();
        let uses_new_token_param = model_lower.starts_with("o1")
            || model_lower.starts_with("o3")
            || model_lower.contains("gpt-5")
            || model_lower.contains("gpt-4.1");

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": final_messages,
            "stream": true,
            // Ask the API to append a final usage-only chunk to the stream.
            "stream_options": { "include_usage": true },
        });

        if uses_new_token_param {
            body["max_completion_tokens"] = serde_json::json!(request.max_tokens);
        } else {
            body["max_tokens"] = serde_json::json!(request.max_tokens);
        }

        if !tools.is_empty() {
            body["tools"] = serde_json::Value::Array(tools);

            use super::provider::ToolChoice;
            match &request.tool_choice {
                ToolChoice::Auto => {
                    body["tool_choice"] = serde_json::json!("auto");
                }
                ToolChoice::Any => {
                    // OpenAI spells "must call some tool" as "required".
                    body["tool_choice"] = serde_json::json!("required");
                }
                ToolChoice::None => {
                    body["tool_choice"] = serde_json::json!("none");
                }
                ToolChoice::Specific(name) => {
                    body["tool_choice"] = serde_json::json!({
                        "type": "function",
                        "function": { "name": name }
                    });
                }
            }
        }
        if let Some(temp) = request.temperature {
            body["temperature"] = serde_json::json!(temp);
        }

        body
    }
}
214
215#[async_trait]
216impl Provider for OpenAiProvider {
217 fn name(&self) -> &str {
218 "openai"
219 }
220
221 async fn stream(
222 &self,
223 request: &ProviderRequest,
224 ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
225 let url = format!("{}/chat/completions", self.base_url);
226 let body = self.build_body(request);
227
228 let mut headers = HeaderMap::new();
229 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
230 headers.insert(
231 AUTHORIZATION,
232 HeaderValue::from_str(&format!("Bearer {}", self.api_key))
233 .map_err(|e| ProviderError::Auth(e.to_string()))?,
234 );
235
236 debug!("OpenAI request to {url}");
237
238 let response = self
239 .http
240 .post(&url)
241 .headers(headers)
242 .json(&body)
243 .send()
244 .await
245 .map_err(|e| ProviderError::Network(e.to_string()))?;
246
247 let status = response.status();
248 if !status.is_success() {
249 let body_text = response.text().await.unwrap_or_default();
250 return match status.as_u16() {
251 401 | 403 => Err(ProviderError::Auth(body_text)),
252 429 => Err(ProviderError::RateLimited {
253 retry_after_ms: 1000,
254 }),
255 529 => Err(ProviderError::Overloaded),
256 413 => Err(ProviderError::RequestTooLarge(body_text)),
257 _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
258 };
259 }
260
261 let (tx, rx) = mpsc::channel(64);
263 let cancel = request.cancel.clone();
264 tokio::spawn(async move {
265 let mut byte_stream = response.bytes_stream();
266 let mut buffer = String::new();
267 let mut current_tool_id = String::new();
268 let mut current_tool_name = String::new();
269 let mut current_tool_args = String::new();
270 let mut usage = Usage::default();
271 let mut stop_reason: Option<StopReason> = None;
272
273 loop {
274 let chunk_result = tokio::select! {
278 biased;
279 _ = cancel.cancelled() => return,
280 chunk = byte_stream.next() => match chunk {
281 Some(c) => c,
282 None => break,
283 },
284 };
285 let chunk = match chunk_result {
286 Ok(c) => c,
287 Err(e) => {
288 let _ = tx.send(StreamEvent::Error(e.to_string())).await;
289 break;
290 }
291 };
292
293 buffer.push_str(&String::from_utf8_lossy(&chunk));
294
295 while let Some(pos) = buffer.find("\n\n") {
296 let event_text = buffer[..pos].to_string();
297 buffer = buffer[pos + 2..].to_string();
298
299 for line in event_text.lines() {
300 let data = if let Some(d) = line.strip_prefix("data: ") {
301 d
302 } else {
303 continue;
304 };
305
306 if data == "[DONE]" {
307 if !current_tool_id.is_empty() {
309 let input: serde_json::Value =
310 serde_json::from_str(¤t_tool_args).unwrap_or_default();
311 let _ = tx
312 .send(StreamEvent::ContentBlockComplete(
313 ContentBlock::ToolUse {
314 id: current_tool_id.clone(),
315 name: current_tool_name.clone(),
316 input,
317 },
318 ))
319 .await;
320 current_tool_id.clear();
321 current_tool_name.clear();
322 current_tool_args.clear();
323 }
324
325 let _ = tx
326 .send(StreamEvent::Done {
327 usage: usage.clone(),
328 stop_reason: stop_reason.clone().or(Some(StopReason::EndTurn)),
329 })
330 .await;
331 return;
332 }
333
334 let parsed: serde_json::Value = match serde_json::from_str(data) {
335 Ok(v) => v,
336 Err(_) => continue,
337 };
338
339 let delta = match parsed
341 .get("choices")
342 .and_then(|c| c.get(0))
343 .and_then(|c| c.get("delta"))
344 {
345 Some(d) => d,
346 None => {
347 if let Some(u) = parsed.get("usage") {
349 usage.input_tokens = u
350 .get("prompt_tokens")
351 .and_then(|v| v.as_u64())
352 .unwrap_or(0);
353 usage.output_tokens = u
354 .get("completion_tokens")
355 .and_then(|v| v.as_u64())
356 .unwrap_or(0);
357 }
358 continue;
359 }
360 };
361
362 if let Some(content) = delta.get("content").and_then(|c| c.as_str())
364 && !content.is_empty()
365 {
366 debug!("OpenAI text delta: {}", &content[..content.len().min(80)]);
367 let _ = tx.send(StreamEvent::TextDelta(content.to_string())).await;
368 }
369
370 if let Some(finish) = parsed
372 .get("choices")
373 .and_then(|c| c.get(0))
374 .and_then(|c| c.get("finish_reason"))
375 .and_then(|f| f.as_str())
376 {
377 debug!("OpenAI finish_reason: {finish}");
378 match finish {
379 "stop" => {
380 stop_reason = Some(StopReason::EndTurn);
381 }
382 "tool_calls" => {
383 stop_reason = Some(StopReason::ToolUse);
384 }
385 "length" => {
386 stop_reason = Some(StopReason::MaxTokens);
387 }
388 _ => {}
389 }
390 }
391
392 if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array())
394 {
395 for tc in tool_calls {
396 if let Some(func) = tc.get("function") {
397 if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
398 if !current_tool_id.is_empty()
400 && !current_tool_args.is_empty()
401 {
402 let input: serde_json::Value =
404 serde_json::from_str(¤t_tool_args)
405 .unwrap_or_default();
406 let _ = tx
407 .send(StreamEvent::ContentBlockComplete(
408 ContentBlock::ToolUse {
409 id: current_tool_id.clone(),
410 name: current_tool_name.clone(),
411 input,
412 },
413 ))
414 .await;
415 }
416 current_tool_id = tc
417 .get("id")
418 .and_then(|i| i.as_str())
419 .unwrap_or("")
420 .to_string();
421 current_tool_name = name.to_string();
422 current_tool_args.clear();
423 }
424 if let Some(args) =
425 func.get("arguments").and_then(|a| a.as_str())
426 {
427 current_tool_args.push_str(args);
428 }
429 }
430 }
431 }
432 }
433 }
434 }
435
436 if !current_tool_id.is_empty() {
438 let input: serde_json::Value =
439 serde_json::from_str(¤t_tool_args).unwrap_or_default();
440 let _ = tx
441 .send(StreamEvent::ContentBlockComplete(ContentBlock::ToolUse {
442 id: current_tool_id,
443 name: current_tool_name,
444 input,
445 }))
446 .await;
447 }
448
449 let _ = tx
450 .send(StreamEvent::Done {
451 usage,
452 stop_reason: Some(StopReason::EndTurn),
453 })
454 .await;
455 });
456
457 Ok(rx)
458 }
459}
460
461fn blocks_to_openai_content(blocks: &[ContentBlock]) -> serde_json::Value {
463 if blocks.len() == 1
464 && let ContentBlock::Text { text } = &blocks[0]
465 {
466 return serde_json::Value::String(text.clone());
467 }
468
469 let parts: Vec<serde_json::Value> = blocks
470 .iter()
471 .map(|b| match b {
472 ContentBlock::Text { text } => serde_json::json!({
473 "type": "text",
474 "text": text,
475 }),
476 ContentBlock::Image { media_type, data } => serde_json::json!({
477 "type": "image_url",
478 "image_url": {
479 "url": format!("data:{media_type};base64,{data}"),
480 }
481 }),
482 ContentBlock::ToolResult {
483 tool_use_id,
484 content,
485 is_error,
486 ..
487 } => serde_json::json!({
488 "type": "tool_result",
489 "tool_use_id": tool_use_id,
490 "content": content,
491 "is_error": is_error,
492 }),
493 ContentBlock::Thinking { thinking, .. } => serde_json::json!({
494 "type": "text",
495 "text": thinking,
496 }),
497 ContentBlock::ToolUse { name, input, .. } => serde_json::json!({
498 "type": "text",
499 "text": format!("[Tool call: {name}({input})]"),
500 }),
501 ContentBlock::Document { title, .. } => serde_json::json!({
502 "type": "text",
503 "text": format!("[Document: {}]", title.as_deref().unwrap_or("untitled")),
504 }),
505 })
506 .collect();
507
508 serde_json::Value::Array(parts)
509}