//! OpenAI provider implementation (`dot/provider/openai.rs`).
use std::{collections::HashMap, future::Future, pin::Pin};

use anyhow::Context;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::{
        ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage,
        ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
        ChatCompletionRequestMessageContentPartImage, ChatCompletionRequestMessageContentPartText,
        ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
        ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
        ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
        ChatCompletionRequestUserMessageContentPart, ChatCompletionStreamOptions,
        ChatCompletionTool, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason,
        FunctionCall, FunctionObject, ImageUrl,
    },
};
use futures::StreamExt;
use tokio::sync::mpsc;
use tracing::{debug, warn};

use crate::provider::{
    ContentBlock, Message, Provider, Role, StopReason, StreamEvent, StreamEventType,
    ToolDefinition, Usage,
};
26
/// Chat provider backed by the OpenAI API (via the `async-openai` client).
pub struct OpenAIProvider {
    // HTTP client; credentials/endpoint come from its `OpenAIConfig`.
    client: Client<OpenAIConfig>,
    // Currently selected model id (e.g. "gpt-4o").
    model: String,
    // Lazily populated by `fetch_models`; `None` until the first successful fetch.
    cached_models: std::sync::Mutex<Option<Vec<String>>>,
    // Per-model context-window sizes resolved by `fetch_context_window`.
    context_windows: std::sync::Mutex<HashMap<String, u32>>,
}
33
/// Best-effort context-window size (in tokens) for a known OpenAI model id.
///
/// Matching is by prefix, most specific first; unrecognized models yield 0.
fn known_context_window(model: &str) -> u32 {
    // Exact id that would otherwise be shadowed by the broader "gpt-4" prefix.
    if model == "gpt-4-1106-preview" {
        return 128_000;
    }
    // (prefix, window) pairs, ordered from most to least specific.
    const PREFIX_WINDOWS: &[(&str, u32)] = &[
        ("o1", 200_000),
        ("o3", 200_000),
        ("o4", 200_000),
        ("gpt-4o", 128_000),
        ("gpt-4-turbo", 128_000),
        ("gpt-4", 8_192),
        ("gpt-3.5", 16_385),
    ];
    PREFIX_WINDOWS
        .iter()
        .find(|(prefix, _)| model.starts_with(prefix))
        .map_or(0, |&(_, window)| window)
}
52
53impl OpenAIProvider {
54    pub fn new(model: impl Into<String>) -> Self {
55        Self {
56            client: Client::new(),
57            model: model.into(),
58            cached_models: std::sync::Mutex::new(None),
59            context_windows: std::sync::Mutex::new(HashMap::new()),
60        }
61    }
62    pub fn new_with_config(config: OpenAIConfig, model: impl Into<String>) -> Self {
63        Self {
64            client: Client::with_config(config),
65            model: model.into(),
66            cached_models: std::sync::Mutex::new(None),
67            context_windows: std::sync::Mutex::new(HashMap::new()),
68        }
69    }
70}
71
/// Accumulator for one streamed tool call, which OpenAI delivers as
/// fragments (id/name first, then argument-JSON pieces) keyed by index.
#[derive(Default)]
struct ToolCallAccum {
    // Tool-call id; usually arrives in the first fragment only.
    id: String,
    // Function name; usually arrives in the first fragment only.
    name: String,
    // Concatenated argument JSON assembled from the streamed pieces.
    arguments: String,
    // True once a ToolUseStart event has been emitted for this call.
    started: bool,
}
79
/// Converts provider-agnostic chat history into OpenAI request messages.
///
/// An optional `system` prompt is emitted first. A single `Message` may
/// expand into several OpenAI messages: user tool results become dedicated
/// `Tool` messages, while user text/images merge into one `User` message
/// (a content-part array when images are present, plain text otherwise).
fn convert_messages(
    messages: &[Message],
    system: Option<&str>,
) -> anyhow::Result<Vec<ChatCompletionRequestMessage>> {
    let mut result: Vec<ChatCompletionRequestMessage> = Vec::new();

    // The system prompt, when present, leads the conversation.
    if let Some(sys) = system {
        result.push(ChatCompletionRequestMessage::System(
            ChatCompletionRequestSystemMessage {
                content: ChatCompletionRequestSystemMessageContent::Text(sys.to_string()),
                name: None,
            },
        ));
    }

    for msg in messages {
        match msg.role {
            Role::System => {
                // Only text blocks are meaningful in a system message.
                let text = extract_text_content(&msg.content);
                result.push(ChatCompletionRequestMessage::System(
                    ChatCompletionRequestSystemMessage {
                        content: ChatCompletionRequestSystemMessageContent::Text(text),
                        name: None,
                    },
                ));
            }

            Role::User => {
                // Split the blocks into tool results, plain text, and images;
                // each category maps to OpenAI differently.
                let mut tool_results: Vec<(String, String)> = Vec::new();
                let mut texts: Vec<String> = Vec::new();
                let mut images: Vec<(String, String)> = Vec::new();

                for block in &msg.content {
                    match block {
                        ContentBlock::Text(t) => texts.push(t.clone()),
                        ContentBlock::Image { media_type, data } => {
                            images.push((media_type.clone(), data.clone()));
                        }
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            tool_results.push((tool_use_id.clone(), content.clone()));
                        }
                        // Other block kinds have no OpenAI user-message form.
                        _ => {}
                    }
                }

                // Tool results must appear as `Tool` messages answering the
                // assistant's earlier tool calls, before any new user input.
                for (id, content) in tool_results {
                    result.push(ChatCompletionRequestMessage::Tool(
                        ChatCompletionRequestToolMessage {
                            content: ChatCompletionRequestToolMessageContent::Text(content),
                            tool_call_id: id,
                        },
                    ));
                }

                if !images.is_empty() {
                    // Mixed content: use the content-part array form, with
                    // images inlined as base64 data URLs.
                    let mut parts: Vec<ChatCompletionRequestUserMessageContentPart> = Vec::new();
                    if !texts.is_empty() {
                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(
                            ChatCompletionRequestMessageContentPartText {
                                text: texts.join("\n"),
                            },
                        ));
                    }
                    for (media_type, data) in images {
                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
                            ChatCompletionRequestMessageContentPartImage {
                                image_url: ImageUrl {
                                    url: format!("data:{};base64,{}", media_type, data),
                                    detail: None,
                                },
                            },
                        ));
                    }
                    result.push(ChatCompletionRequestMessage::User(
                        ChatCompletionRequestUserMessage {
                            content: ChatCompletionRequestUserMessageContent::Array(parts),
                            name: None,
                        },
                    ));
                } else if !texts.is_empty() {
                    // Text-only: the simpler plain-string content form.
                    result.push(ChatCompletionRequestMessage::User(
                        ChatCompletionRequestUserMessage {
                            content: ChatCompletionRequestUserMessageContent::Text(
                                texts.join("\n"),
                            ),
                            name: None,
                        },
                    ));
                }
                // A message holding only tool results emits no User message.
            }

            Role::Assistant => {
                // Assistant turns carry optional text plus zero or more
                // tool calls, which OpenAI keeps in separate fields.
                let mut text_parts: Vec<String> = Vec::new();
                let mut tool_calls: Vec<ChatCompletionMessageToolCall> = Vec::new();

                for block in &msg.content {
                    match block {
                        ContentBlock::Text(t) => text_parts.push(t.clone()),
                        ContentBlock::ToolUse { id, name, input } => {
                            tool_calls.push(ChatCompletionMessageToolCall {
                                id: id.clone(),
                                r#type: ChatCompletionToolType::Function,
                                function: FunctionCall {
                                    name: name.clone(),
                                    // Serializing a JSON value should not fail;
                                    // fall back to "" rather than erroring out.
                                    arguments: serde_json::to_string(input).unwrap_or_default(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                // `content` stays None (not an empty string) when there is no
                // text, as OpenAI expects for pure tool-call turns.
                let content = if text_parts.is_empty() {
                    None
                } else {
                    Some(ChatCompletionRequestAssistantMessageContent::Text(
                        text_parts.join("\n"),
                    ))
                };

                result.push(ChatCompletionRequestMessage::Assistant(
                    ChatCompletionRequestAssistantMessage {
                        content,
                        name: None,
                        tool_calls: if tool_calls.is_empty() {
                            None
                        } else {
                            Some(tool_calls)
                        },
                        refusal: None,
                        ..Default::default()
                    },
                ));
            }
        }
    }

    Ok(result)
}
223
224fn extract_text_content(blocks: &[ContentBlock]) -> String {
225    blocks
226        .iter()
227        .filter_map(|b| {
228            if let ContentBlock::Text(t) = b {
229                Some(t.as_str())
230            } else {
231                None
232            }
233        })
234        .collect::<Vec<_>>()
235        .join("\n")
236}
237
238fn convert_tools(tools: &[ToolDefinition]) -> Vec<ChatCompletionTool> {
239    tools
240        .iter()
241        .map(|t| ChatCompletionTool {
242            r#type: ChatCompletionToolType::Function,
243            function: FunctionObject {
244                name: t.name.clone(),
245                description: Some(t.description.clone()),
246                parameters: Some(t.input_schema.clone()),
247                strict: None,
248            },
249        })
250        .collect()
251}
252
/// Translates OpenAI's finish reason into the provider-agnostic `StopReason`.
fn map_finish_reason(reason: &FinishReason) -> StopReason {
    match reason {
        FinishReason::Stop => StopReason::EndTurn,
        FinishReason::Length => StopReason::MaxTokens,
        // Legacy FunctionCall is treated the same as modern ToolCalls.
        FinishReason::ToolCalls | FinishReason::FunctionCall => StopReason::ToolUse,
        // NOTE(review): StopSequence looks like an approximation here —
        // confirm StopReason has no dedicated content-filter variant.
        FinishReason::ContentFilter => StopReason::StopSequence,
    }
}
261
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        "openai"
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn set_model(&mut self, model: String) {
        self.model = model;
    }

    /// Returns the model list cached by a prior `fetch_models` call
    /// (empty if none has succeeded yet).
    fn available_models(&self) -> Vec<String> {
        let cache = self.cached_models.lock().unwrap();
        cache.clone().unwrap_or_default()
    }

    /// Returns the cached context window for the active model, or 0 when it
    /// has not been resolved yet.
    fn context_window(&self) -> u32 {
        let cw = self.context_windows.lock().unwrap();
        cw.get(&self.model).copied().unwrap_or(0)
    }

    /// Resolves and caches the context window for the active model.
    ///
    /// The OpenAI API does not report context sizes, so this consults the
    /// static `known_context_window` table. Unknown models resolve to 0 and
    /// are not cached, so a later call can still succeed.
    fn fetch_context_window(
        &self,
    ) -> Pin<Box<dyn Future<Output = anyhow::Result<u32>> + Send + '_>> {
        Box::pin(async move {
            // Fast path: already resolved for this model.
            {
                let cw = self.context_windows.lock().unwrap();
                if let Some(&v) = cw.get(&self.model) {
                    return Ok(v);
                }
            }
            let v = known_context_window(&self.model);
            if v > 0 {
                let mut cw = self.context_windows.lock().unwrap();
                cw.insert(self.model.clone(), v);
            }
            Ok(v)
        })
    }

    /// Lists chat-capable models from the OpenAI models endpoint, filtered
    /// to the `gpt-*` / `o1` / `o3` / `o4` families, sorted and deduplicated.
    /// The result is cached for subsequent calls.
    ///
    /// # Errors
    /// Fails when the API call fails or returns no matching model.
    fn fetch_models(
        &self,
    ) -> Pin<Box<dyn Future<Output = anyhow::Result<Vec<String>>> + Send + '_>> {
        let client = self.client.clone();
        Box::pin(async move {
            // Serve from cache when a previous fetch already succeeded.
            {
                let cache = self.cached_models.lock().unwrap();
                if let Some(ref models) = *cache {
                    return Ok(models.clone());
                }
            }

            let resp = client.models().list().await;

            match resp {
                Ok(list) => {
                    let mut models: Vec<String> = list
                        .data
                        .into_iter()
                        .map(|m| m.id)
                        .filter(|id| {
                            id.starts_with("gpt-")
                                || id.starts_with("o1")
                                || id.starts_with("o3")
                                || id.starts_with("o4")
                        })
                        .collect();
                    models.sort();
                    models.dedup();

                    if models.is_empty() {
                        return Err(anyhow::anyhow!(
                            "OpenAI models API returned no matching models"
                        ));
                    }

                    let mut cache = self.cached_models.lock().unwrap();
                    *cache = Some(models.clone());
                    Ok(models)
                }
                Err(e) => Err(anyhow::anyhow!("Failed to fetch OpenAI models: {e}")),
            }
        })
    }

    /// Streams a completion using the provider's currently selected model.
    fn stream(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
        max_tokens: u32,
        thinking_budget: u32,
    ) -> Pin<
        Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
    > {
        self.stream_with_model(
            &self.model,
            messages,
            system,
            tools,
            max_tokens,
            thinking_budget,
        )
    }

    /// Streams a chat completion for `model`, translating OpenAI stream
    /// chunks into provider-agnostic `StreamEvent`s on the returned channel.
    ///
    /// `_thinking_budget` is ignored: this provider exposes no equivalent
    /// request knob.
    ///
    /// # Errors
    /// Fails when message conversion fails or the stream cannot be created;
    /// errors occurring mid-stream are reported as `StreamEventType::Error`.
    fn stream_with_model(
        &self,
        model: &str,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
        max_tokens: u32,
        _thinking_budget: u32,
    ) -> Pin<
        Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
    > {
        // Own all inputs up front so the future and the spawned pump task
        // are independent of the caller's borrows.
        let messages = messages.to_vec();
        let system = system.map(String::from);
        let tools = tools.to_vec();
        let model = model.to_string();
        let client = self.client.clone();

        Box::pin(async move {
            let converted_messages = convert_messages(&messages, system.as_deref())
                .context("Failed to convert messages")?;
            let converted_tools = convert_tools(&tools);

            let request = CreateChatCompletionRequest {
                model: model.clone(),
                messages: converted_messages,
                max_completion_tokens: Some(max_tokens),
                stream: Some(true),
                // FIX: without `include_usage`, OpenAI omits usage from
                // streaming responses entirely, so the token counts reported
                // in MessageEnd below would always be 0.
                stream_options: Some(ChatCompletionStreamOptions {
                    include_usage: true,
                }),
                tools: if converted_tools.is_empty() {
                    None
                } else {
                    Some(converted_tools)
                },
                temperature: Some(1.0),
                ..Default::default()
            };

            let mut oai_stream = client
                .chat()
                .create_stream(request)
                .await
                .context("Failed to create OpenAI stream")?;

            let (tx, rx) = mpsc::unbounded_channel::<StreamEvent>();

            // Pump the OpenAI stream into the channel on a background task;
            // the receiver is handed back to the caller immediately.
            tokio::spawn(async move {
                // Tool-call fragments are keyed by their chunk index and
                // accumulated here until complete.
                let mut tool_accum: HashMap<u32, ToolCallAccum> = HashMap::new();
                let mut total_input_tokens: u32 = 0;
                let mut total_output_tokens: u32 = 0;
                let mut final_stop_reason: Option<StopReason> = None;

                let _ = tx.send(StreamEvent {
                    event_type: StreamEventType::MessageStart,
                });

                while let Some(result) = oai_stream.next().await {
                    match result {
                        Err(e) => {
                            warn!("OpenAI stream error: {e}");
                            let _ = tx.send(StreamEvent {
                                event_type: StreamEventType::Error(e.to_string()),
                            });
                            return;
                        }
                        Ok(response) => {
                            // With `include_usage` set, the final chunk
                            // carries the authoritative token counts.
                            if let Some(usage) = response.usage {
                                // FIX: surface prompt tokens instead of
                                // hard-coding input_tokens to 0 below.
                                total_input_tokens = usage.prompt_tokens;
                                total_output_tokens = usage.completion_tokens;
                            }

                            for choice in response.choices {
                                if let Some(reason) = &choice.finish_reason {
                                    final_stop_reason = Some(map_finish_reason(reason));

                                    if matches!(
                                        reason,
                                        FinishReason::ToolCalls | FinishReason::FunctionCall
                                    ) {
                                        // Close out every tool call that was
                                        // announced with a ToolUseStart.
                                        for accum in tool_accum.values() {
                                            if accum.started {
                                                let _ = tx.send(StreamEvent {
                                                    event_type: StreamEventType::ToolUseEnd,
                                                });
                                            }
                                        }
                                        tool_accum.clear();
                                    }
                                }

                                let delta = choice.delta;

                                if let Some(content) = delta.content
                                    && !content.is_empty()
                                {
                                    let _ = tx.send(StreamEvent {
                                        event_type: StreamEventType::TextDelta(content),
                                    });
                                }

                                if let Some(tool_call_chunks) = delta.tool_calls {
                                    for chunk in tool_call_chunks {
                                        let idx = chunk.index;
                                        let entry = tool_accum.entry(idx).or_default();

                                        // Id and name usually arrive only in
                                        // the first fragment of a call.
                                        if let Some(id) = chunk.id
                                            && !id.is_empty()
                                        {
                                            entry.id = id;
                                        }

                                        if let Some(func) = chunk.function {
                                            if let Some(name) = func.name
                                                && !name.is_empty()
                                            {
                                                entry.name = name;
                                            }

                                            // Announce the call once both id
                                            // and name are known.
                                            if !entry.started
                                                && !entry.id.is_empty()
                                                && !entry.name.is_empty()
                                            {
                                                let _ = tx.send(StreamEvent {
                                                    event_type: StreamEventType::ToolUseStart {
                                                        id: entry.id.clone(),
                                                        name: entry.name.clone(),
                                                    },
                                                });
                                                entry.started = true;
                                                debug!(
                                                    "OpenAI tool use start: id={} name={}",
                                                    entry.id, entry.name
                                                );
                                            }

                                            // Argument JSON streams in pieces;
                                            // forward each piece as a delta.
                                            if let Some(args) = func.arguments
                                                && !args.is_empty()
                                            {
                                                entry.arguments.push_str(&args);
                                                let _ = tx.send(StreamEvent {
                                                    event_type: StreamEventType::ToolUseInputDelta(
                                                        args,
                                                    ),
                                                });
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }

                // The stream ended without a tool-call finish reason: close
                // any tool calls that are still open.
                for accum in tool_accum.values() {
                    if accum.started {
                        let _ = tx.send(StreamEvent {
                            event_type: StreamEventType::ToolUseEnd,
                        });
                    }
                }

                let stop = final_stop_reason.unwrap_or(StopReason::EndTurn);
                let _ = tx.send(StreamEvent {
                    event_type: StreamEventType::MessageEnd {
                        stop_reason: stop,
                        usage: Usage {
                            input_tokens: total_input_tokens,
                            output_tokens: total_output_tokens,
                            cache_read_tokens: 0,
                            cache_write_tokens: 0,
                        },
                    },
                });
            });

            Ok(rx)
        })
    }
544}