use std::{collections::HashMap, future::Future, pin::Pin};

use anyhow::Context;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::{
        ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage,
        ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
        ChatCompletionRequestMessageContentPartImage, ChatCompletionRequestMessageContentPartText,
        ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
        ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
        ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
        ChatCompletionRequestUserMessageContentPart, ChatCompletionStreamOptions,
        ChatCompletionTool, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason,
        FunctionCall, FunctionObject, ImageUrl,
    },
};
use futures::StreamExt;
use tokio::sync::mpsc;
use tracing::{debug, warn};

use crate::provider::{
    ContentBlock, Message, Provider, Role, StopReason, StreamEvent, StreamEventType,
    ToolDefinition, Usage,
};
26
/// LLM provider backed by the OpenAI chat-completions API.
pub struct OpenAIProvider {
    /// HTTP client carrying the OpenAI API configuration (key, base URL, …).
    client: Client<OpenAIConfig>,
    /// Model id sent with each request (e.g. a "gpt-…" model).
    model: String,
    /// Lazily populated cache of model ids returned by the models API;
    /// `None` until `fetch_models` succeeds.
    cached_models: std::sync::Mutex<Option<Vec<String>>>,
}
32
33impl OpenAIProvider {
34 pub fn new(model: impl Into<String>) -> Self {
35 Self {
36 client: Client::new(),
37 model: model.into(),
38 cached_models: std::sync::Mutex::new(None),
39 }
40 }
41 pub fn new_with_config(config: OpenAIConfig, model: impl Into<String>) -> Self {
42 Self {
43 client: Client::with_config(config),
44 model: model.into(),
45 cached_models: std::sync::Mutex::new(None),
46 }
47 }
48}
49
/// Accumulator for one streamed tool call, built up from the fragmentary
/// tool-call chunks OpenAI emits (keyed by the chunk's `index` field).
#[derive(Default)]
struct ToolCallAccum {
    // Tool-call id; set from the first chunk that carries a non-empty id.
    id: String,
    // Function name; set from the first chunk that carries a non-empty name.
    name: String,
    // JSON argument string, concatenated across argument-delta chunks.
    arguments: String,
    // Whether a ToolUseStart event has already been emitted for this call
    // (requires both `id` and `name` to be known).
    started: bool,
}
57
58fn convert_messages(
59 messages: &[Message],
60 system: Option<&str>,
61) -> anyhow::Result<Vec<ChatCompletionRequestMessage>> {
62 let mut result: Vec<ChatCompletionRequestMessage> = Vec::new();
63
64 if let Some(sys) = system {
65 result.push(ChatCompletionRequestMessage::System(
66 ChatCompletionRequestSystemMessage {
67 content: ChatCompletionRequestSystemMessageContent::Text(sys.to_string()),
68 name: None,
69 },
70 ));
71 }
72
73 for msg in messages {
74 match msg.role {
75 Role::System => {
76 let text = extract_text_content(&msg.content);
77 result.push(ChatCompletionRequestMessage::System(
78 ChatCompletionRequestSystemMessage {
79 content: ChatCompletionRequestSystemMessageContent::Text(text),
80 name: None,
81 },
82 ));
83 }
84
85 Role::User => {
86 let mut tool_results: Vec<(String, String)> = Vec::new();
87 let mut texts: Vec<String> = Vec::new();
88 let mut images: Vec<(String, String)> = Vec::new();
89
90 for block in &msg.content {
91 match block {
92 ContentBlock::Text(t) => texts.push(t.clone()),
93 ContentBlock::Image { media_type, data } => {
94 images.push((media_type.clone(), data.clone()));
95 }
96 ContentBlock::ToolResult {
97 tool_use_id,
98 content,
99 ..
100 } => {
101 tool_results.push((tool_use_id.clone(), content.clone()));
102 }
103 _ => {}
104 }
105 }
106
107 for (id, content) in tool_results {
108 result.push(ChatCompletionRequestMessage::Tool(
109 ChatCompletionRequestToolMessage {
110 content: ChatCompletionRequestToolMessageContent::Text(content),
111 tool_call_id: id,
112 },
113 ));
114 }
115
116 if !images.is_empty() {
117 let mut parts: Vec<ChatCompletionRequestUserMessageContentPart> = Vec::new();
118 if !texts.is_empty() {
119 parts.push(ChatCompletionRequestUserMessageContentPart::Text(
120 ChatCompletionRequestMessageContentPartText {
121 text: texts.join("\n"),
122 },
123 ));
124 }
125 for (media_type, data) in images {
126 parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
127 ChatCompletionRequestMessageContentPartImage {
128 image_url: ImageUrl {
129 url: format!("data:{};base64,{}", media_type, data),
130 detail: None,
131 },
132 },
133 ));
134 }
135 result.push(ChatCompletionRequestMessage::User(
136 ChatCompletionRequestUserMessage {
137 content: ChatCompletionRequestUserMessageContent::Array(parts),
138 name: None,
139 },
140 ));
141 } else if !texts.is_empty() {
142 result.push(ChatCompletionRequestMessage::User(
143 ChatCompletionRequestUserMessage {
144 content: ChatCompletionRequestUserMessageContent::Text(
145 texts.join("\n"),
146 ),
147 name: None,
148 },
149 ));
150 }
151 }
152
153 Role::Assistant => {
154 let mut text_parts: Vec<String> = Vec::new();
155 let mut tool_calls: Vec<ChatCompletionMessageToolCall> = Vec::new();
156
157 for block in &msg.content {
158 match block {
159 ContentBlock::Text(t) => text_parts.push(t.clone()),
160 ContentBlock::ToolUse { id, name, input } => {
161 tool_calls.push(ChatCompletionMessageToolCall {
162 id: id.clone(),
163 r#type: ChatCompletionToolType::Function,
164 function: FunctionCall {
165 name: name.clone(),
166 arguments: serde_json::to_string(input).unwrap_or_default(),
167 },
168 });
169 }
170 _ => {}
171 }
172 }
173
174 let content = if text_parts.is_empty() {
175 None
176 } else {
177 Some(ChatCompletionRequestAssistantMessageContent::Text(
178 text_parts.join("\n"),
179 ))
180 };
181
182 result.push(ChatCompletionRequestMessage::Assistant(
183 ChatCompletionRequestAssistantMessage {
184 content,
185 name: None,
186 tool_calls: if tool_calls.is_empty() {
187 None
188 } else {
189 Some(tool_calls)
190 },
191 refusal: None,
192 ..Default::default()
193 },
194 ));
195 }
196 }
197 }
198
199 Ok(result)
200}
201
202fn extract_text_content(blocks: &[ContentBlock]) -> String {
203 blocks
204 .iter()
205 .filter_map(|b| {
206 if let ContentBlock::Text(t) = b {
207 Some(t.as_str())
208 } else {
209 None
210 }
211 })
212 .collect::<Vec<_>>()
213 .join("\n")
214}
215
216fn convert_tools(tools: &[ToolDefinition]) -> Vec<ChatCompletionTool> {
217 tools
218 .iter()
219 .map(|t| ChatCompletionTool {
220 r#type: ChatCompletionToolType::Function,
221 function: FunctionObject {
222 name: t.name.clone(),
223 description: Some(t.description.clone()),
224 parameters: Some(t.input_schema.clone()),
225 strict: None,
226 },
227 })
228 .collect()
229}
230
231fn map_finish_reason(reason: &FinishReason) -> StopReason {
232 match reason {
233 FinishReason::Stop => StopReason::EndTurn,
234 FinishReason::Length => StopReason::MaxTokens,
235 FinishReason::ToolCalls | FinishReason::FunctionCall => StopReason::ToolUse,
236 FinishReason::ContentFilter => StopReason::StopSequence,
237 }
238}
239
240impl Provider for OpenAIProvider {
241 fn name(&self) -> &str {
242 "openai"
243 }
244
245 fn model(&self) -> &str {
246 &self.model
247 }
248
249 fn set_model(&mut self, model: String) {
250 self.model = model;
251 }
252
253 fn available_models(&self) -> Vec<String> {
254 let cache = self.cached_models.lock().unwrap();
255 cache.clone().unwrap_or_default()
256 }
257
258 fn context_window(&self) -> u32 {
259 0
260 }
261
262 fn fetch_context_window(
263 &self,
264 ) -> Pin<Box<dyn Future<Output = anyhow::Result<u32>> + Send + '_>> {
265 Box::pin(async move { Ok(0) })
266 }
267
268 fn fetch_models(
269 &self,
270 ) -> Pin<Box<dyn Future<Output = anyhow::Result<Vec<String>>> + Send + '_>> {
271 let client = self.client.clone();
272 Box::pin(async move {
273 {
274 let cache = self.cached_models.lock().unwrap();
275 if let Some(ref models) = *cache {
276 return Ok(models.clone());
277 }
278 }
279
280 let resp = client.models().list().await;
281
282 match resp {
283 Ok(list) => {
284 let mut models: Vec<String> = list
285 .data
286 .into_iter()
287 .map(|m| m.id)
288 .filter(|id| {
289 id.starts_with("gpt-")
290 || id.starts_with("o1")
291 || id.starts_with("o3")
292 || id.starts_with("o4")
293 })
294 .collect();
295 models.sort();
296 models.dedup();
297
298 if models.is_empty() {
299 return Err(anyhow::anyhow!(
300 "OpenAI models API returned no matching models"
301 ));
302 }
303
304 let mut cache = self.cached_models.lock().unwrap();
305 *cache = Some(models.clone());
306 Ok(models)
307 }
308 Err(e) => Err(anyhow::anyhow!("Failed to fetch OpenAI models: {e}")),
309 }
310 })
311 }
312
313 fn stream(
314 &self,
315 messages: &[Message],
316 system: Option<&str>,
317 tools: &[ToolDefinition],
318 max_tokens: u32,
319 thinking_budget: u32,
320 ) -> Pin<
321 Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
322 > {
323 self.stream_with_model(
324 &self.model,
325 messages,
326 system,
327 tools,
328 max_tokens,
329 thinking_budget,
330 )
331 }
332
333 fn stream_with_model(
334 &self,
335 model: &str,
336 messages: &[Message],
337 system: Option<&str>,
338 tools: &[ToolDefinition],
339 max_tokens: u32,
340 _thinking_budget: u32,
341 ) -> Pin<
342 Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
343 > {
344 let messages = messages.to_vec();
345 let system = system.map(String::from);
346 let tools = tools.to_vec();
347 let model = model.to_string();
348 let client = self.client.clone();
349
350 Box::pin(async move {
351 let converted_messages = convert_messages(&messages, system.as_deref())
352 .context("Failed to convert messages")?;
353 let converted_tools = convert_tools(&tools);
354
355 let request = CreateChatCompletionRequest {
356 model: model.clone(),
357 messages: converted_messages,
358 max_completion_tokens: Some(max_tokens),
359 stream: Some(true),
360 tools: if converted_tools.is_empty() {
361 None
362 } else {
363 Some(converted_tools)
364 },
365 temperature: Some(1.0),
366 ..Default::default()
367 };
368
369 let mut oai_stream = client
370 .chat()
371 .create_stream(request)
372 .await
373 .context("Failed to create OpenAI stream")?;
374
375 let (tx, rx) = mpsc::unbounded_channel::<StreamEvent>();
376 let tx_clone = tx.clone();
377
378 tokio::spawn(async move {
379 let mut tool_accum: HashMap<u32, ToolCallAccum> = HashMap::new();
380 let mut total_output_tokens: u32 = 0;
381 let mut final_stop_reason: Option<StopReason> = None;
382
383 let _ = tx_clone.send(StreamEvent {
384 event_type: StreamEventType::MessageStart,
385 });
386
387 while let Some(result) = oai_stream.next().await {
388 match result {
389 Err(e) => {
390 warn!("OpenAI stream error: {e}");
391 let _ = tx_clone.send(StreamEvent {
392 event_type: StreamEventType::Error(e.to_string()),
393 });
394 return;
395 }
396 Ok(response) => {
397 if let Some(usage) = response.usage {
398 total_output_tokens = usage.completion_tokens;
399 }
400
401 for choice in response.choices {
402 if let Some(reason) = &choice.finish_reason {
403 final_stop_reason = Some(map_finish_reason(reason));
404
405 if matches!(
406 reason,
407 FinishReason::ToolCalls | FinishReason::FunctionCall
408 ) {
409 for accum in tool_accum.values() {
410 if accum.started {
411 let _ = tx_clone.send(StreamEvent {
412 event_type: StreamEventType::ToolUseEnd,
413 });
414 }
415 }
416 tool_accum.clear();
417 }
418 }
419
420 let delta = choice.delta;
421
422 if let Some(content) = delta.content
423 && !content.is_empty()
424 {
425 let _ = tx_clone.send(StreamEvent {
426 event_type: StreamEventType::TextDelta(content),
427 });
428 }
429
430 if let Some(tool_call_chunks) = delta.tool_calls {
431 for chunk in tool_call_chunks {
432 let idx = chunk.index;
433 let entry = tool_accum.entry(idx).or_default();
434
435 if let Some(id) = chunk.id
436 && !id.is_empty()
437 {
438 entry.id = id;
439 }
440
441 if let Some(func) = chunk.function {
442 if let Some(name) = func.name
443 && !name.is_empty()
444 {
445 entry.name = name;
446 }
447
448 if !entry.started
449 && !entry.id.is_empty()
450 && !entry.name.is_empty()
451 {
452 let _ = tx_clone.send(StreamEvent {
453 event_type: StreamEventType::ToolUseStart {
454 id: entry.id.clone(),
455 name: entry.name.clone(),
456 },
457 });
458 entry.started = true;
459 debug!(
460 "OpenAI tool use start: id={} name={}",
461 entry.id, entry.name
462 );
463 }
464
465 if let Some(args) = func.arguments
466 && !args.is_empty()
467 {
468 entry.arguments.push_str(&args);
469 let _ = tx_clone.send(StreamEvent {
470 event_type: StreamEventType::ToolUseInputDelta(
471 args,
472 ),
473 });
474 }
475 }
476 }
477 }
478 }
479 }
480 }
481 }
482
483 for accum in tool_accum.values() {
484 if accum.started {
485 let _ = tx_clone.send(StreamEvent {
486 event_type: StreamEventType::ToolUseEnd,
487 });
488 }
489 }
490
491 let stop = final_stop_reason.unwrap_or(StopReason::EndTurn);
492 let _ = tx_clone.send(StreamEvent {
493 event_type: StreamEventType::MessageEnd {
494 stop_reason: stop,
495 usage: Usage {
496 input_tokens: 0,
497 output_tokens: total_output_tokens,
498 cache_read_tokens: 0,
499 cache_write_tokens: 0,
500 },
501 },
502 });
503 });
504
505 Ok(rx)
506 })
507 }
508}