use std::{collections::HashMap, future::Future, pin::Pin};

use anyhow::Context;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::{
        ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage,
        ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
        ChatCompletionRequestMessageContentPartImage, ChatCompletionRequestMessageContentPartText,
        ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
        ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
        ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
        ChatCompletionRequestUserMessageContentPart, ChatCompletionStreamOptions,
        ChatCompletionTool, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason,
        FunctionCall, FunctionObject, ImageUrl,
    },
};
use futures::StreamExt;
use tokio::sync::mpsc;
use tracing::{debug, warn};

use crate::provider::{
    ContentBlock, Message, Provider, Role, StopReason, StreamEvent, StreamEventType,
    ToolDefinition, Usage,
};
26
/// [`Provider`] implementation backed by the OpenAI chat-completions API.
pub struct OpenAIProvider {
    /// Configured async-openai HTTP client.
    client: Client<OpenAIConfig>,
    /// Currently selected model id (e.g. "gpt-4o").
    model: String,
    /// Lazily-populated cache of model ids returned by the models API.
    cached_models: std::sync::Mutex<Option<Vec<String>>>,
    /// Context-window sizes resolved so far, keyed by model id.
    context_windows: std::sync::Mutex<HashMap<String, u32>>,
}
33
/// Returns the context-window size (in tokens) for a known OpenAI model id,
/// or 0 when the model is not recognized.
///
/// Matching is by prefix, most specific first, so "gpt-4o"/"gpt-4-turbo"
/// resolve before the generic "gpt-4" entry. "gpt-4-1106-preview" is an
/// exact-match special case.
fn known_context_window(model: &str) -> u32 {
    // Exact match only: longer ids sharing this prefix fall through to the
    // generic "gpt-4" entry below.
    if model == "gpt-4-1106-preview" {
        return 128_000;
    }

    // Ordered so that more specific prefixes are tried first.
    const PREFIX_WINDOWS: &[(&str, u32)] = &[
        ("o1", 200_000),
        ("o3", 200_000),
        ("o4", 200_000),
        ("gpt-4o", 128_000),
        ("gpt-4-turbo", 128_000),
        ("gpt-4", 8_192),
        ("gpt-3.5", 16_385),
    ];

    PREFIX_WINDOWS
        .iter()
        .find(|(prefix, _)| model.starts_with(prefix))
        .map(|&(_, window)| window)
        .unwrap_or(0)
}
52
53impl OpenAIProvider {
54 pub fn new(model: impl Into<String>) -> Self {
55 Self {
56 client: Client::new(),
57 model: model.into(),
58 cached_models: std::sync::Mutex::new(None),
59 context_windows: std::sync::Mutex::new(HashMap::new()),
60 }
61 }
62 pub fn new_with_config(config: OpenAIConfig, model: impl Into<String>) -> Self {
63 Self {
64 client: Client::with_config(config),
65 model: model.into(),
66 cached_models: std::sync::Mutex::new(None),
67 context_windows: std::sync::Mutex::new(HashMap::new()),
68 }
69 }
70}
71
/// Accumulator for one streamed tool call.
///
/// OpenAI streams tool calls in fragments keyed by a per-choice index: the
/// id and function name arrive early, then the JSON `arguments` string
/// arrives in pieces that are concatenated here.
#[derive(Default)]
struct ToolCallAccum {
    /// Tool-call id assigned by the API (set from the first chunk carrying it).
    id: String,
    /// Function name (set from the first chunk carrying it).
    name: String,
    /// Concatenated JSON argument fragments received so far.
    arguments: String,
    /// Whether a `ToolUseStart` event has already been emitted for this call.
    started: bool,
}
79
80fn convert_messages(
81 messages: &[Message],
82 system: Option<&str>,
83) -> anyhow::Result<Vec<ChatCompletionRequestMessage>> {
84 let mut result: Vec<ChatCompletionRequestMessage> = Vec::new();
85
86 if let Some(sys) = system {
87 result.push(ChatCompletionRequestMessage::System(
88 ChatCompletionRequestSystemMessage {
89 content: ChatCompletionRequestSystemMessageContent::Text(sys.to_string()),
90 name: None,
91 },
92 ));
93 }
94
95 for msg in messages {
96 match msg.role {
97 Role::System => {
98 let text = extract_text_content(&msg.content);
99 result.push(ChatCompletionRequestMessage::System(
100 ChatCompletionRequestSystemMessage {
101 content: ChatCompletionRequestSystemMessageContent::Text(text),
102 name: None,
103 },
104 ));
105 }
106
107 Role::User => {
108 let mut tool_results: Vec<(String, String)> = Vec::new();
109 let mut texts: Vec<String> = Vec::new();
110 let mut images: Vec<(String, String)> = Vec::new();
111
112 for block in &msg.content {
113 match block {
114 ContentBlock::Text(t) => texts.push(t.clone()),
115 ContentBlock::Image { media_type, data } => {
116 images.push((media_type.clone(), data.clone()));
117 }
118 ContentBlock::ToolResult {
119 tool_use_id,
120 content,
121 ..
122 } => {
123 tool_results.push((tool_use_id.clone(), content.clone()));
124 }
125 _ => {}
126 }
127 }
128
129 for (id, content) in tool_results {
130 result.push(ChatCompletionRequestMessage::Tool(
131 ChatCompletionRequestToolMessage {
132 content: ChatCompletionRequestToolMessageContent::Text(content),
133 tool_call_id: id,
134 },
135 ));
136 }
137
138 if !images.is_empty() {
139 let mut parts: Vec<ChatCompletionRequestUserMessageContentPart> = Vec::new();
140 if !texts.is_empty() {
141 parts.push(ChatCompletionRequestUserMessageContentPart::Text(
142 ChatCompletionRequestMessageContentPartText {
143 text: texts.join("\n"),
144 },
145 ));
146 }
147 for (media_type, data) in images {
148 parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
149 ChatCompletionRequestMessageContentPartImage {
150 image_url: ImageUrl {
151 url: format!("data:{};base64,{}", media_type, data),
152 detail: None,
153 },
154 },
155 ));
156 }
157 result.push(ChatCompletionRequestMessage::User(
158 ChatCompletionRequestUserMessage {
159 content: ChatCompletionRequestUserMessageContent::Array(parts),
160 name: None,
161 },
162 ));
163 } else if !texts.is_empty() {
164 result.push(ChatCompletionRequestMessage::User(
165 ChatCompletionRequestUserMessage {
166 content: ChatCompletionRequestUserMessageContent::Text(
167 texts.join("\n"),
168 ),
169 name: None,
170 },
171 ));
172 }
173 }
174
175 Role::Assistant => {
176 let mut text_parts: Vec<String> = Vec::new();
177 let mut tool_calls: Vec<ChatCompletionMessageToolCall> = Vec::new();
178
179 for block in &msg.content {
180 match block {
181 ContentBlock::Text(t) => text_parts.push(t.clone()),
182 ContentBlock::ToolUse { id, name, input } => {
183 tool_calls.push(ChatCompletionMessageToolCall {
184 id: id.clone(),
185 r#type: ChatCompletionToolType::Function,
186 function: FunctionCall {
187 name: name.clone(),
188 arguments: serde_json::to_string(input).unwrap_or_default(),
189 },
190 });
191 }
192 _ => {}
193 }
194 }
195
196 let content = if text_parts.is_empty() {
197 None
198 } else {
199 Some(ChatCompletionRequestAssistantMessageContent::Text(
200 text_parts.join("\n"),
201 ))
202 };
203
204 result.push(ChatCompletionRequestMessage::Assistant(
205 ChatCompletionRequestAssistantMessage {
206 content,
207 name: None,
208 tool_calls: if tool_calls.is_empty() {
209 None
210 } else {
211 Some(tool_calls)
212 },
213 refusal: None,
214 ..Default::default()
215 },
216 ));
217 }
218 }
219 }
220
221 Ok(result)
222}
223
224fn extract_text_content(blocks: &[ContentBlock]) -> String {
225 blocks
226 .iter()
227 .filter_map(|b| {
228 if let ContentBlock::Text(t) = b {
229 Some(t.as_str())
230 } else {
231 None
232 }
233 })
234 .collect::<Vec<_>>()
235 .join("\n")
236}
237
238fn convert_tools(tools: &[ToolDefinition]) -> Vec<ChatCompletionTool> {
239 tools
240 .iter()
241 .map(|t| ChatCompletionTool {
242 r#type: ChatCompletionToolType::Function,
243 function: FunctionObject {
244 name: t.name.clone(),
245 description: Some(t.description.clone()),
246 parameters: Some(t.input_schema.clone()),
247 strict: None,
248 },
249 })
250 .collect()
251}
252
253fn map_finish_reason(reason: &FinishReason) -> StopReason {
254 match reason {
255 FinishReason::Stop => StopReason::EndTurn,
256 FinishReason::Length => StopReason::MaxTokens,
257 FinishReason::ToolCalls | FinishReason::FunctionCall => StopReason::ToolUse,
258 FinishReason::ContentFilter => StopReason::StopSequence,
259 }
260}
261
262impl Provider for OpenAIProvider {
263 fn name(&self) -> &str {
264 "openai"
265 }
266
267 fn model(&self) -> &str {
268 &self.model
269 }
270
271 fn set_model(&mut self, model: String) {
272 self.model = model;
273 }
274
275 fn available_models(&self) -> Vec<String> {
276 let cache = self.cached_models.lock().unwrap();
277 cache.clone().unwrap_or_default()
278 }
279
280 fn context_window(&self) -> u32 {
281 let cw = self.context_windows.lock().unwrap();
282 cw.get(&self.model).copied().unwrap_or(0)
283 }
284
285 fn fetch_context_window(
286 &self,
287 ) -> Pin<Box<dyn Future<Output = anyhow::Result<u32>> + Send + '_>> {
288 Box::pin(async move {
289 {
290 let cw = self.context_windows.lock().unwrap();
291 if let Some(&v) = cw.get(&self.model) {
292 return Ok(v);
293 }
294 }
295 let v = known_context_window(&self.model);
296 if v > 0 {
297 let mut cw = self.context_windows.lock().unwrap();
298 cw.insert(self.model.clone(), v);
299 }
300 Ok(v)
301 })
302 }
303
304 fn fetch_models(
305 &self,
306 ) -> Pin<Box<dyn Future<Output = anyhow::Result<Vec<String>>> + Send + '_>> {
307 let client = self.client.clone();
308 Box::pin(async move {
309 {
310 let cache = self.cached_models.lock().unwrap();
311 if let Some(ref models) = *cache {
312 return Ok(models.clone());
313 }
314 }
315
316 let resp = client.models().list().await;
317
318 match resp {
319 Ok(list) => {
320 let mut models: Vec<String> = list
321 .data
322 .into_iter()
323 .map(|m| m.id)
324 .filter(|id| {
325 id.starts_with("gpt-")
326 || id.starts_with("o1")
327 || id.starts_with("o3")
328 || id.starts_with("o4")
329 })
330 .collect();
331 models.sort();
332 models.dedup();
333
334 if models.is_empty() {
335 return Err(anyhow::anyhow!(
336 "OpenAI models API returned no matching models"
337 ));
338 }
339
340 let mut cache = self.cached_models.lock().unwrap();
341 *cache = Some(models.clone());
342 Ok(models)
343 }
344 Err(e) => Err(anyhow::anyhow!("Failed to fetch OpenAI models: {e}")),
345 }
346 })
347 }
348
349 fn stream(
350 &self,
351 messages: &[Message],
352 system: Option<&str>,
353 tools: &[ToolDefinition],
354 max_tokens: u32,
355 thinking_budget: u32,
356 ) -> Pin<
357 Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
358 > {
359 self.stream_with_model(
360 &self.model,
361 messages,
362 system,
363 tools,
364 max_tokens,
365 thinking_budget,
366 )
367 }
368
369 fn stream_with_model(
370 &self,
371 model: &str,
372 messages: &[Message],
373 system: Option<&str>,
374 tools: &[ToolDefinition],
375 max_tokens: u32,
376 _thinking_budget: u32,
377 ) -> Pin<
378 Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
379 > {
380 let messages = messages.to_vec();
381 let system = system.map(String::from);
382 let tools = tools.to_vec();
383 let model = model.to_string();
384 let client = self.client.clone();
385
386 Box::pin(async move {
387 let converted_messages = convert_messages(&messages, system.as_deref())
388 .context("Failed to convert messages")?;
389 let converted_tools = convert_tools(&tools);
390
391 let request = CreateChatCompletionRequest {
392 model: model.clone(),
393 messages: converted_messages,
394 max_completion_tokens: Some(max_tokens),
395 stream: Some(true),
396 tools: if converted_tools.is_empty() {
397 None
398 } else {
399 Some(converted_tools)
400 },
401 temperature: Some(1.0),
402 ..Default::default()
403 };
404
405 let mut oai_stream = client
406 .chat()
407 .create_stream(request)
408 .await
409 .context("Failed to create OpenAI stream")?;
410
411 let (tx, rx) = mpsc::unbounded_channel::<StreamEvent>();
412 let tx_clone = tx.clone();
413
414 tokio::spawn(async move {
415 let mut tool_accum: HashMap<u32, ToolCallAccum> = HashMap::new();
416 let mut total_output_tokens: u32 = 0;
417 let mut final_stop_reason: Option<StopReason> = None;
418
419 let _ = tx_clone.send(StreamEvent {
420 event_type: StreamEventType::MessageStart,
421 });
422
423 while let Some(result) = oai_stream.next().await {
424 match result {
425 Err(e) => {
426 warn!("OpenAI stream error: {e}");
427 let _ = tx_clone.send(StreamEvent {
428 event_type: StreamEventType::Error(e.to_string()),
429 });
430 return;
431 }
432 Ok(response) => {
433 if let Some(usage) = response.usage {
434 total_output_tokens = usage.completion_tokens;
435 }
436
437 for choice in response.choices {
438 if let Some(reason) = &choice.finish_reason {
439 final_stop_reason = Some(map_finish_reason(reason));
440
441 if matches!(
442 reason,
443 FinishReason::ToolCalls | FinishReason::FunctionCall
444 ) {
445 for accum in tool_accum.values() {
446 if accum.started {
447 let _ = tx_clone.send(StreamEvent {
448 event_type: StreamEventType::ToolUseEnd,
449 });
450 }
451 }
452 tool_accum.clear();
453 }
454 }
455
456 let delta = choice.delta;
457
458 if let Some(content) = delta.content
459 && !content.is_empty()
460 {
461 let _ = tx_clone.send(StreamEvent {
462 event_type: StreamEventType::TextDelta(content),
463 });
464 }
465
466 if let Some(tool_call_chunks) = delta.tool_calls {
467 for chunk in tool_call_chunks {
468 let idx = chunk.index;
469 let entry = tool_accum.entry(idx).or_default();
470
471 if let Some(id) = chunk.id
472 && !id.is_empty()
473 {
474 entry.id = id;
475 }
476
477 if let Some(func) = chunk.function {
478 if let Some(name) = func.name
479 && !name.is_empty()
480 {
481 entry.name = name;
482 }
483
484 if !entry.started
485 && !entry.id.is_empty()
486 && !entry.name.is_empty()
487 {
488 let _ = tx_clone.send(StreamEvent {
489 event_type: StreamEventType::ToolUseStart {
490 id: entry.id.clone(),
491 name: entry.name.clone(),
492 },
493 });
494 entry.started = true;
495 debug!(
496 "OpenAI tool use start: id={} name={}",
497 entry.id, entry.name
498 );
499 }
500
501 if let Some(args) = func.arguments
502 && !args.is_empty()
503 {
504 entry.arguments.push_str(&args);
505 let _ = tx_clone.send(StreamEvent {
506 event_type: StreamEventType::ToolUseInputDelta(
507 args,
508 ),
509 });
510 }
511 }
512 }
513 }
514 }
515 }
516 }
517 }
518
519 for accum in tool_accum.values() {
520 if accum.started {
521 let _ = tx_clone.send(StreamEvent {
522 event_type: StreamEventType::ToolUseEnd,
523 });
524 }
525 }
526
527 let stop = final_stop_reason.unwrap_or(StopReason::EndTurn);
528 let _ = tx_clone.send(StreamEvent {
529 event_type: StreamEventType::MessageEnd {
530 stop_reason: stop,
531 usage: Usage {
532 input_tokens: 0,
533 output_tokens: total_output_tokens,
534 cache_read_tokens: 0,
535 cache_write_tokens: 0,
536 },
537 },
538 });
539 });
540
541 Ok(rx)
542 })
543 }
544}