Skip to main content

codetether_agent/provider/
openai.rs

1//! OpenAI provider implementation
2
3use super::{
4    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5    Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::Result;
8use async_openai::{
9    Client,
10    config::OpenAIConfig,
11    types::chat::{
12        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
13        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
14        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
15        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
16        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
17        FunctionObjectArgs,
18    },
19};
20use async_trait::async_trait;
21use futures::StreamExt;
22
23pub struct OpenAIProvider {
24    client: Client<OpenAIConfig>,
25    provider_name: String,
26}
27
28impl std::fmt::Debug for OpenAIProvider {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("OpenAIProvider")
31            .field("provider_name", &self.provider_name)
32            .field("client", &"<async_openai::Client>")
33            .finish()
34    }
35}
36
37impl OpenAIProvider {
38    pub fn new(api_key: String) -> Result<Self> {
39        tracing::debug!(
40            provider = "openai",
41            api_key_len = api_key.len(),
42            "Creating OpenAI provider"
43        );
44        let config = OpenAIConfig::new().with_api_key(api_key);
45        Ok(Self {
46            client: Client::with_config(config),
47            provider_name: "openai".to_string(),
48        })
49    }
50
51    /// Create with custom base URL (for OpenAI-compatible providers like Moonshot)
52    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
53        tracing::debug!(
54            provider = provider_name,
55            base_url = %base_url,
56            api_key_len = api_key.len(),
57            "Creating OpenAI-compatible provider"
58        );
59        let config = OpenAIConfig::new()
60            .with_api_key(api_key)
61            .with_api_base(base_url);
62        Ok(Self {
63            client: Client::with_config(config),
64            provider_name: provider_name.to_string(),
65        })
66    }
67
68    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
69        let mut result = Vec::new();
70
71        for msg in messages {
72            let content = msg
73                .content
74                .iter()
75                .filter_map(|p| match p {
76                    ContentPart::Text { text } => Some(text.clone()),
77                    _ => None,
78                })
79                .collect::<Vec<_>>()
80                .join("\n");
81
82            match msg.role {
83                Role::System => {
84                    result.push(
85                        ChatCompletionRequestSystemMessageArgs::default()
86                            .content(content)
87                            .build()?
88                            .into(),
89                    );
90                }
91                Role::User => {
92                    result.push(
93                        ChatCompletionRequestUserMessageArgs::default()
94                            .content(content)
95                            .build()?
96                            .into(),
97                    );
98                }
99                Role::Assistant => {
100                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
101                        .content
102                        .iter()
103                        .filter_map(|p| match p {
104                            ContentPart::ToolCall {
105                                id,
106                                name,
107                                arguments,
108                            } => Some(ChatCompletionMessageToolCalls::Function(
109                                ChatCompletionMessageToolCall {
110                                    id: id.clone(),
111                                    function: FunctionCall {
112                                        name: name.clone(),
113                                        arguments: arguments.clone(),
114                                    },
115                                },
116                            )),
117                            _ => None,
118                        })
119                        .collect();
120
121                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
122                    if !content.is_empty() {
123                        builder.content(content);
124                    }
125                    if !tool_calls.is_empty() {
126                        builder.tool_calls(tool_calls);
127                    }
128                    result.push(builder.build()?.into());
129                }
130                Role::Tool => {
131                    for part in &msg.content {
132                        if let ContentPart::ToolResult {
133                            tool_call_id,
134                            content,
135                        } = part
136                        {
137                            result.push(
138                                ChatCompletionRequestToolMessageArgs::default()
139                                    .tool_call_id(tool_call_id.clone())
140                                    .content(content.clone())
141                                    .build()?
142                                    .into(),
143                            );
144                        }
145                    }
146                }
147            }
148        }
149
150        Ok(result)
151    }
152
153    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
154        let mut result = Vec::new();
155        for tool in tools {
156            result.push(ChatCompletionTools::Function(ChatCompletionTool {
157                function: FunctionObjectArgs::default()
158                    .name(&tool.name)
159                    .description(&tool.description)
160                    .parameters(tool.parameters.clone())
161                    .build()?,
162            }));
163        }
164        Ok(result)
165    }
166}
167
168#[async_trait]
169impl Provider for OpenAIProvider {
170    fn name(&self) -> &str {
171        &self.provider_name
172    }
173
174    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
175        // Return commonly used models
176        Ok(vec![
177            ModelInfo {
178                id: "gpt-4o".to_string(),
179                name: "GPT-4o".to_string(),
180                provider: "openai".to_string(),
181                context_window: 128_000,
182                max_output_tokens: Some(16_384),
183                supports_vision: true,
184                supports_tools: true,
185                supports_streaming: true,
186                input_cost_per_million: Some(2.5),
187                output_cost_per_million: Some(10.0),
188            },
189            ModelInfo {
190                id: "gpt-4o-mini".to_string(),
191                name: "GPT-4o Mini".to_string(),
192                provider: "openai".to_string(),
193                context_window: 128_000,
194                max_output_tokens: Some(16_384),
195                supports_vision: true,
196                supports_tools: true,
197                supports_streaming: true,
198                input_cost_per_million: Some(0.15),
199                output_cost_per_million: Some(0.6),
200            },
201            ModelInfo {
202                id: "o1".to_string(),
203                name: "o1".to_string(),
204                provider: "openai".to_string(),
205                context_window: 200_000,
206                max_output_tokens: Some(100_000),
207                supports_vision: true,
208                supports_tools: true,
209                supports_streaming: true,
210                input_cost_per_million: Some(15.0),
211                output_cost_per_million: Some(60.0),
212            },
213        ])
214    }
215
216    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
217        let messages = Self::convert_messages(&request.messages)?;
218        let tools = Self::convert_tools(&request.tools)?;
219
220        let mut req_builder = CreateChatCompletionRequestArgs::default();
221        req_builder.model(&request.model).messages(messages);
222
223        // Pass tools to the API if provided
224        if !tools.is_empty() {
225            req_builder.tools(tools);
226        }
227        if let Some(temp) = request.temperature {
228            req_builder.temperature(temp);
229        }
230        if let Some(max) = request.max_tokens {
231            req_builder.max_completion_tokens(max as u32);
232        }
233
234        let response = self.client.chat().create(req_builder.build()?).await?;
235
236        let choice = response
237            .choices
238            .first()
239            .ok_or_else(|| anyhow::anyhow!("No choices"))?;
240
241        let mut content = Vec::new();
242        let mut has_tool_calls = false;
243
244        if let Some(text) = &choice.message.content {
245            content.push(ContentPart::Text { text: text.clone() });
246        }
247        if let Some(tool_calls) = &choice.message.tool_calls {
248            has_tool_calls = !tool_calls.is_empty();
249            for tc in tool_calls {
250                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
251                    content.push(ContentPart::ToolCall {
252                        id: func_call.id.clone(),
253                        name: func_call.function.name.clone(),
254                        arguments: func_call.function.arguments.clone(),
255                    });
256                }
257            }
258        }
259
260        // Determine finish reason based on response
261        let finish_reason = if has_tool_calls {
262            FinishReason::ToolCalls
263        } else {
264            match choice.finish_reason {
265                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
266                Some(OpenAIFinishReason::Length) => FinishReason::Length,
267                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
268                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
269                _ => FinishReason::Stop,
270            }
271        };
272
273        Ok(CompletionResponse {
274            message: Message {
275                role: Role::Assistant,
276                content,
277            },
278            usage: Usage {
279                prompt_tokens: response
280                    .usage
281                    .as_ref()
282                    .map(|u| u.prompt_tokens as usize)
283                    .unwrap_or(0),
284                completion_tokens: response
285                    .usage
286                    .as_ref()
287                    .map(|u| u.completion_tokens as usize)
288                    .unwrap_or(0),
289                total_tokens: response
290                    .usage
291                    .as_ref()
292                    .map(|u| u.total_tokens as usize)
293                    .unwrap_or(0),
294                ..Default::default()
295            },
296            finish_reason,
297        })
298    }
299
300    async fn complete_stream(
301        &self,
302        request: CompletionRequest,
303    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
304        tracing::debug!(
305            provider = %self.provider_name,
306            model = %request.model,
307            message_count = request.messages.len(),
308            "Starting streaming completion request"
309        );
310
311        let messages = Self::convert_messages(&request.messages)?;
312
313        let mut req_builder = CreateChatCompletionRequestArgs::default();
314        req_builder
315            .model(&request.model)
316            .messages(messages)
317            .stream(true);
318
319        if let Some(temp) = request.temperature {
320            req_builder.temperature(temp);
321        }
322
323        let stream = self
324            .client
325            .chat()
326            .create_stream(req_builder.build()?)
327            .await?;
328
329        Ok(stream
330            .map(|result| match result {
331                Ok(response) => {
332                    if let Some(choice) = response.choices.first() {
333                        if let Some(content) = &choice.delta.content {
334                            return StreamChunk::Text(content.clone());
335                        }
336                    }
337                    StreamChunk::Text(String::new())
338                }
339                Err(e) => StreamChunk::Error(e.to_string()),
340            })
341            .boxed())
342    }
343}