1use async_trait::async_trait;
8use futures::StreamExt;
9use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
10use tokio::sync::mpsc;
11use tracing::debug;
12
13use super::message::{ContentBlock, Message, StopReason, Usage};
14use super::provider::{Provider, ProviderError, ProviderRequest};
15use super::stream::StreamEvent;
16
/// LLM provider backed by the OpenAI Chat Completions API.
///
/// Holds a reusable HTTP client plus the connection parameters needed to
/// issue authenticated streaming requests.
pub struct OpenAiProvider {
    // Shared reqwest client (connection pooling; 300 s request timeout).
    http: reqwest::Client,
    // API root without a trailing slash, e.g. "https://api.openai.com/v1".
    base_url: String,
    // Bearer token sent in the Authorization header.
    api_key: String,
}
23
impl OpenAiProvider {
    /// Creates a provider for the given API root and bearer key.
    ///
    /// Trailing slashes are stripped from `base_url` so request paths join
    /// cleanly; the HTTP client is built with a 300-second request timeout.
    ///
    /// # Panics
    /// Panics if the underlying `reqwest` client cannot be constructed.
    pub fn new(base_url: &str, api_key: &str) -> Self {
        let http = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .build()
            .expect("failed to build HTTP client");

        Self {
            http,
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key: api_key.to_string(),
        }
    }

    /// Translates a provider-agnostic [`ProviderRequest`] into the JSON body
    /// expected by OpenAI's `/chat/completions` endpoint.
    ///
    /// Conversion steps:
    /// 1. Prepend the system prompt (if non-empty) as a `system` message.
    /// 2. Map user/assistant messages; assistant tool-use blocks become
    ///    OpenAI `tool_calls`, assistant text blocks are concatenated into a
    ///    single `content` string.
    /// 3. Second pass: `tool_result` blocks embedded in user messages are
    ///    lifted out into standalone `role: "tool"` messages, as the OpenAI
    ///    wire format requires.
    /// 4. Attach tools, the token-limit parameter, tool choice, and
    ///    temperature.
    fn build_body(&self, request: &ProviderRequest) -> serde_json::Value {
        let mut messages = Vec::new();

        if !request.system_prompt.is_empty() {
            messages.push(serde_json::json!({
                "role": "system",
                "content": request.system_prompt,
            }));
        }

        for msg in &request.messages {
            match msg {
                Message::User(u) => {
                    let content = blocks_to_openai_content(&u.content);
                    messages.push(serde_json::json!({
                        "role": "user",
                        "content": content,
                    }));
                }
                Message::Assistant(a) => {
                    let mut msg_json = serde_json::json!({
                        "role": "assistant",
                    });

                    // Tool-use blocks map to OpenAI's `tool_calls` array;
                    // `arguments` must be a JSON-encoded *string*, not an
                    // object.
                    let tool_calls: Vec<serde_json::Value> = a
                        .content
                        .iter()
                        .filter_map(|b| match b {
                            ContentBlock::ToolUse { id, name, input } => Some(serde_json::json!({
                                "id": id,
                                "type": "function",
                                "function": {
                                    "name": name,
                                    "arguments": serde_json::to_string(input).unwrap_or_default(),
                                }
                            })),
                            _ => None,
                        })
                        .collect();

                    // All text blocks collapse into one assistant string.
                    let text: String = a
                        .content
                        .iter()
                        .filter_map(|b| match b {
                            ContentBlock::Text { text } => Some(text.as_str()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("");

                    msg_json["content"] = serde_json::Value::String(text);
                    if !tool_calls.is_empty() {
                        msg_json["tool_calls"] = serde_json::Value::Array(tool_calls);
                    }

                    messages.push(msg_json);
                }
                // The system prompt is handled above; stray system messages
                // in the history are dropped.
                Message::System(_) => {} }
        }

        // Second pass: OpenAI expects tool results as separate messages with
        // role "tool", not embedded inside user content arrays.
        let mut final_messages = Vec::new();
        for msg in messages {
            if msg.get("role").and_then(|r| r.as_str()) == Some("user") {
                if let Some(content) = msg.get("content")
                    && let Some(arr) = content.as_array()
                {
                    let mut tool_results = Vec::new();
                    let mut other_content = Vec::new();

                    for block in arr {
                        if block.get("type").and_then(|t| t.as_str()) == Some("tool_result") {
                            tool_results.push(serde_json::json!({
                                "role": "tool",
                                "tool_call_id": block.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or(""),
                                "content": block.get("content").and_then(|v| v.as_str()).unwrap_or(""),
                            }));
                        } else {
                            other_content.push(block.clone());
                        }
                    }

                    if !tool_results.is_empty() {
                        // Tool messages go first so they directly follow the
                        // assistant turn that issued the calls; any remaining
                        // user content follows as its own message.
                        for tr in tool_results {
                            final_messages.push(tr);
                        }
                        if !other_content.is_empty() {
                            let mut m = msg.clone();
                            m["content"] = serde_json::Value::Array(other_content);
                            final_messages.push(m);
                        }
                        continue;
                    }
                }
            }
            final_messages.push(msg);
        }

        let tools: Vec<serde_json::Value> = request
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.input_schema,
                    }
                })
            })
            .collect();

        // Newer model families renamed `max_tokens` to
        // `max_completion_tokens`; pick the parameter by model name.
        let model_lower = request.model.to_lowercase();
        let uses_new_token_param = model_lower.starts_with("o1")
            || model_lower.starts_with("o3")
            || model_lower.contains("gpt-5")
            || model_lower.contains("gpt-4.1");

        let mut body = serde_json::json!({
            "model": request.model,
            "messages": final_messages,
            "stream": true,
            // Ask the API to append a final usage chunk to the stream.
            "stream_options": { "include_usage": true },
        });

        if uses_new_token_param {
            body["max_completion_tokens"] = serde_json::json!(request.max_tokens);
        } else {
            body["max_tokens"] = serde_json::json!(request.max_tokens);
        }

        if !tools.is_empty() {
            body["tools"] = serde_json::Value::Array(tools);

            use super::provider::ToolChoice;
            match &request.tool_choice {
                ToolChoice::Auto => {
                    body["tool_choice"] = serde_json::json!("auto");
                }
                ToolChoice::Any => {
                    // OpenAI's spelling for "must call some tool".
                    body["tool_choice"] = serde_json::json!("required");
                }
                ToolChoice::None => {
                    body["tool_choice"] = serde_json::json!("none");
                }
                ToolChoice::Specific(name) => {
                    body["tool_choice"] = serde_json::json!({
                        "type": "function",
                        "function": { "name": name }
                    });
                }
            }
        }
        if let Some(temp) = request.temperature {
            body["temperature"] = serde_json::json!(temp);
        }

        body
    }
}
214
215#[async_trait]
216impl Provider for OpenAiProvider {
217 fn name(&self) -> &str {
218 "openai"
219 }
220
221 async fn stream(
222 &self,
223 request: &ProviderRequest,
224 ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
225 let url = format!("{}/chat/completions", self.base_url);
226 let body = self.build_body(request);
227
228 let mut headers = HeaderMap::new();
229 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
230 headers.insert(
231 AUTHORIZATION,
232 HeaderValue::from_str(&format!("Bearer {}", self.api_key))
233 .map_err(|e| ProviderError::Auth(e.to_string()))?,
234 );
235
236 debug!("OpenAI request to {url}");
237
238 let response = self
239 .http
240 .post(&url)
241 .headers(headers)
242 .json(&body)
243 .send()
244 .await
245 .map_err(|e| ProviderError::Network(e.to_string()))?;
246
247 let status = response.status();
248 if !status.is_success() {
249 let body_text = response.text().await.unwrap_or_default();
250 return match status.as_u16() {
251 401 | 403 => Err(ProviderError::Auth(body_text)),
252 429 => Err(ProviderError::RateLimited {
253 retry_after_ms: 1000,
254 }),
255 529 => Err(ProviderError::Overloaded),
256 413 => Err(ProviderError::RequestTooLarge(body_text)),
257 _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
258 };
259 }
260
261 let (tx, rx) = mpsc::channel(64);
263 tokio::spawn(async move {
264 let mut byte_stream = response.bytes_stream();
265 let mut buffer = String::new();
266 let mut current_tool_id = String::new();
267 let mut current_tool_name = String::new();
268 let mut current_tool_args = String::new();
269 let mut usage = Usage::default();
270 let mut stop_reason: Option<StopReason> = None;
271
272 while let Some(chunk_result) = byte_stream.next().await {
273 let chunk = match chunk_result {
274 Ok(c) => c,
275 Err(e) => {
276 let _ = tx.send(StreamEvent::Error(e.to_string())).await;
277 break;
278 }
279 };
280
281 buffer.push_str(&String::from_utf8_lossy(&chunk));
282
283 while let Some(pos) = buffer.find("\n\n") {
284 let event_text = buffer[..pos].to_string();
285 buffer = buffer[pos + 2..].to_string();
286
287 for line in event_text.lines() {
288 let data = if let Some(d) = line.strip_prefix("data: ") {
289 d
290 } else {
291 continue;
292 };
293
294 if data == "[DONE]" {
295 if !current_tool_id.is_empty() {
297 let input: serde_json::Value =
298 serde_json::from_str(¤t_tool_args).unwrap_or_default();
299 let _ = tx
300 .send(StreamEvent::ContentBlockComplete(
301 ContentBlock::ToolUse {
302 id: current_tool_id.clone(),
303 name: current_tool_name.clone(),
304 input,
305 },
306 ))
307 .await;
308 current_tool_id.clear();
309 current_tool_name.clear();
310 current_tool_args.clear();
311 }
312
313 let _ = tx
314 .send(StreamEvent::Done {
315 usage: usage.clone(),
316 stop_reason: stop_reason.clone().or(Some(StopReason::EndTurn)),
317 })
318 .await;
319 return;
320 }
321
322 let parsed: serde_json::Value = match serde_json::from_str(data) {
323 Ok(v) => v,
324 Err(_) => continue,
325 };
326
327 let delta = match parsed
329 .get("choices")
330 .and_then(|c| c.get(0))
331 .and_then(|c| c.get("delta"))
332 {
333 Some(d) => d,
334 None => {
335 if let Some(u) = parsed.get("usage") {
337 usage.input_tokens = u
338 .get("prompt_tokens")
339 .and_then(|v| v.as_u64())
340 .unwrap_or(0);
341 usage.output_tokens = u
342 .get("completion_tokens")
343 .and_then(|v| v.as_u64())
344 .unwrap_or(0);
345 }
346 continue;
347 }
348 };
349
350 if let Some(content) = delta.get("content").and_then(|c| c.as_str())
352 && !content.is_empty()
353 {
354 debug!("OpenAI text delta: {}", &content[..content.len().min(80)]);
355 let _ = tx.send(StreamEvent::TextDelta(content.to_string())).await;
356 }
357
358 if let Some(finish) = parsed
360 .get("choices")
361 .and_then(|c| c.get(0))
362 .and_then(|c| c.get("finish_reason"))
363 .and_then(|f| f.as_str())
364 {
365 debug!("OpenAI finish_reason: {finish}");
366 match finish {
367 "stop" => {
368 stop_reason = Some(StopReason::EndTurn);
369 }
370 "tool_calls" => {
371 stop_reason = Some(StopReason::ToolUse);
372 }
373 "length" => {
374 stop_reason = Some(StopReason::MaxTokens);
375 }
376 _ => {}
377 }
378 }
379
380 if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array())
382 {
383 for tc in tool_calls {
384 if let Some(func) = tc.get("function") {
385 if let Some(name) = func.get("name").and_then(|n| n.as_str()) {
386 if !current_tool_id.is_empty()
388 && !current_tool_args.is_empty()
389 {
390 let input: serde_json::Value =
392 serde_json::from_str(¤t_tool_args)
393 .unwrap_or_default();
394 let _ = tx
395 .send(StreamEvent::ContentBlockComplete(
396 ContentBlock::ToolUse {
397 id: current_tool_id.clone(),
398 name: current_tool_name.clone(),
399 input,
400 },
401 ))
402 .await;
403 }
404 current_tool_id = tc
405 .get("id")
406 .and_then(|i| i.as_str())
407 .unwrap_or("")
408 .to_string();
409 current_tool_name = name.to_string();
410 current_tool_args.clear();
411 }
412 if let Some(args) =
413 func.get("arguments").and_then(|a| a.as_str())
414 {
415 current_tool_args.push_str(args);
416 }
417 }
418 }
419 }
420 }
421 }
422 }
423
424 if !current_tool_id.is_empty() {
426 let input: serde_json::Value =
427 serde_json::from_str(¤t_tool_args).unwrap_or_default();
428 let _ = tx
429 .send(StreamEvent::ContentBlockComplete(ContentBlock::ToolUse {
430 id: current_tool_id,
431 name: current_tool_name,
432 input,
433 }))
434 .await;
435 }
436
437 let _ = tx
438 .send(StreamEvent::Done {
439 usage,
440 stop_reason: Some(StopReason::EndTurn),
441 })
442 .await;
443 });
444
445 Ok(rx)
446 }
447}
448
449fn blocks_to_openai_content(blocks: &[ContentBlock]) -> serde_json::Value {
451 if blocks.len() == 1
452 && let ContentBlock::Text { text } = &blocks[0]
453 {
454 return serde_json::Value::String(text.clone());
455 }
456
457 let parts: Vec<serde_json::Value> = blocks
458 .iter()
459 .map(|b| match b {
460 ContentBlock::Text { text } => serde_json::json!({
461 "type": "text",
462 "text": text,
463 }),
464 ContentBlock::Image { media_type, data } => serde_json::json!({
465 "type": "image_url",
466 "image_url": {
467 "url": format!("data:{media_type};base64,{data}"),
468 }
469 }),
470 ContentBlock::ToolResult {
471 tool_use_id,
472 content,
473 is_error,
474 ..
475 } => serde_json::json!({
476 "type": "tool_result",
477 "tool_use_id": tool_use_id,
478 "content": content,
479 "is_error": is_error,
480 }),
481 ContentBlock::Thinking { thinking, .. } => serde_json::json!({
482 "type": "text",
483 "text": thinking,
484 }),
485 ContentBlock::ToolUse { name, input, .. } => serde_json::json!({
486 "type": "text",
487 "text": format!("[Tool call: {name}({input})]"),
488 }),
489 ContentBlock::Document { title, .. } => serde_json::json!({
490 "type": "text",
491 "text": format!("[Document: {}]", title.as_deref().unwrap_or("untitled")),
492 }),
493 })
494 .collect();
495
496 serde_json::Value::Array(parts)
497}