codetether_agent/provider/openai.rs

//! OpenAI provider implementation
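//!
//! Minimal usage sketch (assumes the API key is available in the
//! `OPENAI_API_KEY` environment variable; request and response types come
//! from the parent `provider` module):
//!
//! ```ignore
//! let provider = OpenAIProvider::new(std::env::var("OPENAI_API_KEY")?)?;
//! let models = provider.list_models().await?;
//! ```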

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;

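/// Chat-completion provider backed by the `async-openai` client.
///
/// The same struct serves OpenAI-compatible endpoints (see
/// [`OpenAIProvider::with_base_url`]); `provider_name` records which backend
/// this instance talks to.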
pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
}

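// Manual `Debug` impl: show a placeholder for the client rather than its
// configuration, which holds the API key.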
impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
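    /// Create a provider that targets the default OpenAI API endpoint.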
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
        })
    }

    /// Create with custom base URL (for OpenAI-compatible providers like Moonshot)
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.len(),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
        })
    }

    /// Return known models for specific OpenAI-compatible providers
    fn provider_default_models(&self) -> Vec<ModelInfo> {
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3-32b", "Qwen 3 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],
            "minimax" => vec![
                ("MiniMax-M1-80k", "MiniMax M1 80k"),
                ("MiniMax-Text-01", "MiniMax Text 01"),
            ],
            "zhipuai" => vec![],
            "novita" => vec![
                ("qwen/qwen3-coder-next", "Qwen 3 Coder Next"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

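    /// Map internal messages onto `async-openai` request messages.
    ///
    /// Text parts are joined with newlines; assistant tool calls and tool
    /// results become the corresponding OpenAI message variants. Other
    /// content parts are ignored here.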
    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

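    /// Map tool definitions onto OpenAI function-calling tool specs.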
    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

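    /// Return a hard-coded model catalog; no network request is made.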
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        // For non-OpenAI providers, return provider-specific model defaults.
        // Note: async-openai 0.32 does not expose a stable models list API across
        // all OpenAI-compatible endpoints.
        if self.provider_name != "openai" {
            return Ok(self.provider_default_models());
        }

        // OpenAI default models
        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

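    /// Non-streaming chat completion: sends messages (and tools, if any) and
    /// maps the first choice back into provider-agnostic types.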
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages);

        // Pass tools to the API if provided
        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(max) = request.max_tokens {
            req_builder.max_completion_tokens(max as u32);
        }

        let response = self.client.chat().create(req_builder.build()?).await?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                    });
                }
            }
        }

        // Determine finish reason based on response
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

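    /// Stream a completion as text chunks.
    ///
    /// Unlike `complete`, this path currently forwards only the model,
    /// messages, and temperature; tools and `max_tokens` are not sent.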
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

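        // Forward text deltas as they arrive; tool-call deltas and usage
        // chunks are not surfaced on this path.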
        Ok(stream
            .map(|result| match result {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return StreamChunk::Text(content.clone());
                        }
                    }
                    StreamChunk::Text(String::new())
                }
                Err(e) => StreamChunk::Error(e.to_string()),
            })
            .boxed())
    }
}