//! OpenAI chat-completions provider (`dot/provider/openai.rs`).
1use std::{collections::HashMap, future::Future, pin::Pin};
2
3use anyhow::Context;
4use async_openai::{
5    Client,
6    config::OpenAIConfig,
7    types::{
8        ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage,
9        ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
10        ChatCompletionRequestMessageContentPartImage, ChatCompletionRequestMessageContentPartText,
11        ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
12        ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
13        ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
14        ChatCompletionRequestUserMessageContentPart, ChatCompletionTool, ChatCompletionToolType,
15        CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionObject, ImageUrl,
16    },
17};
18use futures::StreamExt;
19use tokio::sync::mpsc;
20use tracing::{debug, warn};
21
22use crate::provider::{
23    ContentBlock, Message, Provider, Role, StopReason, StreamEvent, StreamEventType,
24    ToolDefinition, Usage,
25};
26
/// [`Provider`] implementation backed by the OpenAI Chat Completions API.
pub struct OpenAIProvider {
    // Underlying async_openai HTTP client.
    client: Client<OpenAIConfig>,
    // Currently selected model id (e.g. "gpt-4o"); mutable via `set_model`.
    model: String,
    // Model list cached by `fetch_models`; `None` until the first successful fetch.
    cached_models: std::sync::Mutex<Option<Vec<String>>>,
}
32
33impl OpenAIProvider {
34    pub fn new(model: impl Into<String>) -> Self {
35        Self {
36            client: Client::new(),
37            model: model.into(),
38            cached_models: std::sync::Mutex::new(None),
39        }
40    }
41    pub fn new_with_config(config: OpenAIConfig, model: impl Into<String>) -> Self {
42        Self {
43            client: Client::with_config(config),
44            model: model.into(),
45            cached_models: std::sync::Mutex::new(None),
46        }
47    }
48}
49
/// Accumulator for one streamed tool call: OpenAI delivers the id, name, and
/// argument JSON in incremental chunks keyed by tool-call index.
#[derive(Default)]
struct ToolCallAccum {
    // Tool-call id; arrives (possibly empty at first) in early chunks.
    id: String,
    // Function name; also arrives incrementally.
    name: String,
    // Concatenated JSON argument fragments.
    arguments: String,
    // True once a ToolUseStart event has been emitted for this call,
    // so the matching ToolUseEnd is only sent for started calls.
    started: bool,
}
57
/// Convert provider-agnostic [`Message`]s (plus an optional system prompt)
/// into `async_openai` chat request messages.
///
/// Mapping rules, in order:
/// * `system`, when given, becomes the first message (System).
/// * `Role::System` messages keep only their text blocks.
/// * `Role::User` messages are split: each `ToolResult` block becomes its own
///   Tool message keyed by `tool_call_id` (emitted before any user content),
///   then remaining text/image blocks are combined into one User message —
///   multipart when images are present, plain text otherwise. Other block
///   kinds are dropped.
/// * `Role::Assistant` messages join text blocks with newlines and turn
///   `ToolUse` blocks into `tool_calls` entries with JSON-encoded arguments.
fn convert_messages(
    messages: &[Message],
    system: Option<&str>,
) -> anyhow::Result<Vec<ChatCompletionRequestMessage>> {
    let mut result: Vec<ChatCompletionRequestMessage> = Vec::new();

    // System prompt always leads the conversation.
    if let Some(sys) = system {
        result.push(ChatCompletionRequestMessage::System(
            ChatCompletionRequestSystemMessage {
                content: ChatCompletionRequestSystemMessageContent::Text(sys.to_string()),
                name: None,
            },
        ));
    }

    for msg in messages {
        match msg.role {
            Role::System => {
                // Non-text blocks in a system message are silently discarded.
                let text = extract_text_content(&msg.content);
                result.push(ChatCompletionRequestMessage::System(
                    ChatCompletionRequestSystemMessage {
                        content: ChatCompletionRequestSystemMessageContent::Text(text),
                        name: None,
                    },
                ));
            }

            Role::User => {
                // Partition the user's content blocks by kind.
                let mut tool_results: Vec<(String, String)> = Vec::new();
                let mut texts: Vec<String> = Vec::new();
                let mut images: Vec<(String, String)> = Vec::new();

                for block in &msg.content {
                    match block {
                        ContentBlock::Text(t) => texts.push(t.clone()),
                        ContentBlock::Image { media_type, data } => {
                            images.push((media_type.clone(), data.clone()));
                        }
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            tool_results.push((tool_use_id.clone(), content.clone()));
                        }
                        _ => {}
                    }
                }

                // Tool results must be standalone Tool messages referencing the
                // originating tool_call_id; they precede any user text/images.
                for (id, content) in tool_results {
                    result.push(ChatCompletionRequestMessage::Tool(
                        ChatCompletionRequestToolMessage {
                            content: ChatCompletionRequestToolMessageContent::Text(content),
                            tool_call_id: id,
                        },
                    ));
                }

                if !images.is_empty() {
                    // Multipart user message: one optional text part plus one
                    // image part per image.
                    let mut parts: Vec<ChatCompletionRequestUserMessageContentPart> = Vec::new();
                    if !texts.is_empty() {
                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(
                            ChatCompletionRequestMessageContentPartText {
                                text: texts.join("\n"),
                            },
                        ));
                    }
                    for (media_type, data) in images {
                        // Image blocks carry base64 payloads; encode as data URLs.
                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
                            ChatCompletionRequestMessageContentPartImage {
                                image_url: ImageUrl {
                                    url: format!("data:{};base64,{}", media_type, data),
                                    detail: None,
                                },
                            },
                        ));
                    }
                    result.push(ChatCompletionRequestMessage::User(
                        ChatCompletionRequestUserMessage {
                            content: ChatCompletionRequestUserMessageContent::Array(parts),
                            name: None,
                        },
                    ));
                } else if !texts.is_empty() {
                    // Text-only user message. A message with neither text nor
                    // images (e.g. only tool results) produces no User entry.
                    result.push(ChatCompletionRequestMessage::User(
                        ChatCompletionRequestUserMessage {
                            content: ChatCompletionRequestUserMessageContent::Text(
                                texts.join("\n"),
                            ),
                            name: None,
                        },
                    ));
                }
            }

            Role::Assistant => {
                let mut text_parts: Vec<String> = Vec::new();
                let mut tool_calls: Vec<ChatCompletionMessageToolCall> = Vec::new();

                for block in &msg.content {
                    match block {
                        ContentBlock::Text(t) => text_parts.push(t.clone()),
                        ContentBlock::ToolUse { id, name, input } => {
                            tool_calls.push(ChatCompletionMessageToolCall {
                                id: id.clone(),
                                r#type: ChatCompletionToolType::Function,
                                function: FunctionCall {
                                    name: name.clone(),
                                    // Serialization of a serde_json::Value should not
                                    // fail; falls back to "" rather than erroring.
                                    arguments: serde_json::to_string(input).unwrap_or_default(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                // OpenAI accepts an assistant message with no content when it
                // only carries tool calls.
                let content = if text_parts.is_empty() {
                    None
                } else {
                    Some(ChatCompletionRequestAssistantMessageContent::Text(
                        text_parts.join("\n"),
                    ))
                };

                result.push(ChatCompletionRequestMessage::Assistant(
                    ChatCompletionRequestAssistantMessage {
                        content,
                        name: None,
                        tool_calls: if tool_calls.is_empty() {
                            None
                        } else {
                            Some(tool_calls)
                        },
                        refusal: None,
                        ..Default::default()
                    },
                ));
            }
        }
    }

    Ok(result)
}
201
202fn extract_text_content(blocks: &[ContentBlock]) -> String {
203    blocks
204        .iter()
205        .filter_map(|b| {
206            if let ContentBlock::Text(t) = b {
207                Some(t.as_str())
208            } else {
209                None
210            }
211        })
212        .collect::<Vec<_>>()
213        .join("\n")
214}
215
216fn convert_tools(tools: &[ToolDefinition]) -> Vec<ChatCompletionTool> {
217    tools
218        .iter()
219        .map(|t| ChatCompletionTool {
220            r#type: ChatCompletionToolType::Function,
221            function: FunctionObject {
222                name: t.name.clone(),
223                description: Some(t.description.clone()),
224                parameters: Some(t.input_schema.clone()),
225                strict: None,
226            },
227        })
228        .collect()
229}
230
/// Map an OpenAI finish reason onto our provider-agnostic [`StopReason`].
fn map_finish_reason(reason: &FinishReason) -> StopReason {
    match reason {
        FinishReason::Stop => StopReason::EndTurn,
        FinishReason::Length => StopReason::MaxTokens,
        FinishReason::ToolCalls | FinishReason::FunctionCall => StopReason::ToolUse,
        // NOTE(review): content-filter stops are reported as StopSequence —
        // presumably the closest available variant; confirm StopReason has no
        // dedicated content-filter case.
        FinishReason::ContentFilter => StopReason::StopSequence,
    }
}
239
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        "openai"
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn set_model(&mut self, model: String) {
        self.model = model;
    }

    /// Models cached by a previous `fetch_models` call; empty until the
    /// first successful fetch.
    fn available_models(&self) -> Vec<String> {
        let cache = self.cached_models.lock().unwrap();
        cache.clone().unwrap_or_default()
    }

    // 0 signals "unknown": this provider has no static window-size table.
    fn context_window(&self) -> u32 {
        0
    }

    fn fetch_context_window(
        &self,
    ) -> Pin<Box<dyn Future<Output = anyhow::Result<u32>> + Send + '_>> {
        // Also always "unknown" — the OpenAI API does not expose window size here.
        Box::pin(async move { Ok(0) })
    }

    /// List chat-capable models from the OpenAI models API, caching the
    /// result in `self.cached_models`. Subsequent calls return the cache.
    fn fetch_models(
        &self,
    ) -> Pin<Box<dyn Future<Output = anyhow::Result<Vec<String>>> + Send + '_>> {
        // NOTE(review): `client` is cloned although the async block already
        // borrows `self` (for the cache) — the clone looks redundant; confirm.
        let client = self.client.clone();
        Box::pin(async move {
            // Fast path: return the cached list if present. The lock is held
            // only for this scope so it is never held across an await.
            {
                let cache = self.cached_models.lock().unwrap();
                if let Some(ref models) = *cache {
                    return Ok(models.clone());
                }
            }

            let resp = client.models().list().await;

            match resp {
                Ok(list) => {
                    // Keep only chat-style model families; other ids
                    // (embeddings, audio, ...) are filtered out.
                    let mut models: Vec<String> = list
                        .data
                        .into_iter()
                        .map(|m| m.id)
                        .filter(|id| {
                            id.starts_with("gpt-")
                                || id.starts_with("o1")
                                || id.starts_with("o3")
                                || id.starts_with("o4")
                        })
                        .collect();
                    models.sort();
                    models.dedup();

                    // An empty list after filtering is treated as an error so
                    // callers don't cache a useless result.
                    if models.is_empty() {
                        return Err(anyhow::anyhow!(
                            "OpenAI models API returned no matching models"
                        ));
                    }

                    // Benign race: two concurrent fetches may both hit the
                    // API; the last writer wins.
                    let mut cache = self.cached_models.lock().unwrap();
                    *cache = Some(models.clone());
                    Ok(models)
                }
                Err(e) => Err(anyhow::anyhow!("Failed to fetch OpenAI models: {e}")),
            }
        })
    }

    /// Stream a completion using the currently configured model; delegates
    /// to [`Self::stream_with_model`].
    fn stream(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
        max_tokens: u32,
        thinking_budget: u32,
    ) -> Pin<
        Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
    > {
        self.stream_with_model(
            &self.model,
            messages,
            system,
            tools,
            max_tokens,
            thinking_budget,
        )
    }

    /// Start a streaming chat completion against `model` and return a channel
    /// of [`StreamEvent`]s. A background task translates OpenAI stream chunks
    /// into MessageStart / TextDelta / ToolUse* / MessageEnd events.
    ///
    /// `_thinking_budget` is ignored — the Chat Completions request built here
    /// carries no reasoning/thinking parameter.
    fn stream_with_model(
        &self,
        model: &str,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
        max_tokens: u32,
        _thinking_budget: u32,
    ) -> Pin<
        Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
    > {
        // Own everything up front so the future and spawned task are 'static
        // with respect to the inputs.
        let messages = messages.to_vec();
        let system = system.map(String::from);
        let tools = tools.to_vec();
        let model = model.to_string();
        let client = self.client.clone();

        Box::pin(async move {
            let converted_messages = convert_messages(&messages, system.as_deref())
                .context("Failed to convert messages")?;
            let converted_tools = convert_tools(&tools);

            let request = CreateChatCompletionRequest {
                model: model.clone(),
                messages: converted_messages,
                max_completion_tokens: Some(max_tokens),
                stream: Some(true),
                tools: if converted_tools.is_empty() {
                    None
                } else {
                    Some(converted_tools)
                },
                temperature: Some(1.0),
                ..Default::default()
            };

            let mut oai_stream = client
                .chat()
                .create_stream(request)
                .await
                .context("Failed to create OpenAI stream")?;

            let (tx, rx) = mpsc::unbounded_channel::<StreamEvent>();
            // NOTE(review): `tx` is never used after this clone; moving `tx`
            // into the task directly would be equivalent.
            let tx_clone = tx.clone();

            // Background task: drain the OpenAI stream and forward events.
            // Send errors are ignored (`let _ =`) — the receiver may be dropped.
            tokio::spawn(async move {
                // Per-index accumulators for incrementally streamed tool calls.
                let mut tool_accum: HashMap<u32, ToolCallAccum> = HashMap::new();
                let mut total_output_tokens: u32 = 0;
                let mut final_stop_reason: Option<StopReason> = None;

                let _ = tx_clone.send(StreamEvent {
                    event_type: StreamEventType::MessageStart,
                });

                while let Some(result) = oai_stream.next().await {
                    match result {
                        Err(e) => {
                            // Abort on the first stream error; no MessageEnd
                            // (and no ToolUseEnd) is emitted after an Error.
                            warn!("OpenAI stream error: {e}");
                            let _ = tx_clone.send(StreamEvent {
                                event_type: StreamEventType::Error(e.to_string()),
                            });
                            return;
                        }
                        Ok(response) => {
                            // NOTE(review): usage is only read when present on a
                            // chunk; the request does not set
                            // stream_options.include_usage — confirm usage is
                            // actually delivered, otherwise this stays 0.
                            if let Some(usage) = response.usage {
                                total_output_tokens = usage.completion_tokens;
                            }

                            for choice in response.choices {
                                if let Some(reason) = &choice.finish_reason {
                                    final_stop_reason = Some(map_finish_reason(reason));

                                    // On a tool-call finish, close out every
                                    // started tool call and reset the
                                    // accumulators so the post-loop flush below
                                    // doesn't emit duplicate ToolUseEnd events.
                                    if matches!(
                                        reason,
                                        FinishReason::ToolCalls | FinishReason::FunctionCall
                                    ) {
                                        for accum in tool_accum.values() {
                                            if accum.started {
                                                let _ = tx_clone.send(StreamEvent {
                                                    event_type: StreamEventType::ToolUseEnd,
                                                });
                                            }
                                        }
                                        tool_accum.clear();
                                    }
                                }

                                let delta = choice.delta;

                                if let Some(content) = delta.content
                                    && !content.is_empty()
                                {
                                    let _ = tx_clone.send(StreamEvent {
                                        event_type: StreamEventType::TextDelta(content),
                                    });
                                }

                                if let Some(tool_call_chunks) = delta.tool_calls {
                                    for chunk in tool_call_chunks {
                                        // Chunks for the same logical tool call
                                        // share an index; merge them.
                                        let idx = chunk.index;
                                        let entry = tool_accum.entry(idx).or_default();

                                        if let Some(id) = chunk.id
                                            && !id.is_empty()
                                        {
                                            entry.id = id;
                                        }

                                        if let Some(func) = chunk.function {
                                            if let Some(name) = func.name
                                                && !name.is_empty()
                                            {
                                                entry.name = name;
                                            }

                                            // Emit ToolUseStart exactly once,
                                            // as soon as both id and name are known.
                                            if !entry.started
                                                && !entry.id.is_empty()
                                                && !entry.name.is_empty()
                                            {
                                                let _ = tx_clone.send(StreamEvent {
                                                    event_type: StreamEventType::ToolUseStart {
                                                        id: entry.id.clone(),
                                                        name: entry.name.clone(),
                                                    },
                                                });
                                                entry.started = true;
                                                debug!(
                                                    "OpenAI tool use start: id={} name={}",
                                                    entry.id, entry.name
                                                );
                                            }

                                            // Forward argument fragments as they
                                            // arrive; also accumulated locally.
                                            if let Some(args) = func.arguments
                                                && !args.is_empty()
                                            {
                                                entry.arguments.push_str(&args);
                                                let _ = tx_clone.send(StreamEvent {
                                                    event_type: StreamEventType::ToolUseInputDelta(
                                                        args,
                                                    ),
                                                });
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }

                // Flush any tool calls still open when the stream ended without
                // a tool-call finish reason (those were cleared above).
                for accum in tool_accum.values() {
                    if accum.started {
                        let _ = tx_clone.send(StreamEvent {
                            event_type: StreamEventType::ToolUseEnd,
                        });
                    }
                }

                // Default to EndTurn if the stream never reported a finish reason.
                let stop = final_stop_reason.unwrap_or(StopReason::EndTurn);
                let _ = tx_clone.send(StreamEvent {
                    event_type: StreamEventType::MessageEnd {
                        stop_reason: stop,
                        usage: Usage {
                            // Input/cache token counts are not tracked here.
                            input_tokens: 0,
                            output_tokens: total_output_tokens,
                            cache_read_tokens: 0,
                            cache_write_tokens: 0,
                        },
                    },
                });
            });

            Ok(rx)
        })
    }
}