codetether_agent/provider/openai.rs

//! OpenAI provider implementation

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;

pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
}

impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
        })
    }

    /// Create with custom base URL (for OpenAI-compatible providers like Moonshot)
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.len(),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
        })
    }
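
    // Illustrative usage (not part of this impl): constructing a provider for an
    // OpenAI-compatible endpoint. The environment variable and base URL below are
    // assumed placeholder values, not ones read anywhere in this file.
    //
    //     let provider = OpenAIProvider::with_base_url(
    //         std::env::var("CEREBRAS_API_KEY")?,
    //         "https://api.cerebras.ai/v1".to_string(),
    //         "cerebras",
    //     )?;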

    /// Return known models for specific OpenAI-compatible providers
    fn provider_default_models(&self) -> Vec<ModelInfo> {
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3-32b", "Qwen 3 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],
            "minimax" => vec![
                ("MiniMax-M2.5", "MiniMax M2.5"),
                ("MiniMax-M2.5-lightning", "MiniMax M2.5 Lightning"),
                ("MiniMax-M2.1", "MiniMax M2.1"),
                ("MiniMax-M2.1-lightning", "MiniMax M2.1 Lightning"),
                ("MiniMax-M2", "MiniMax M2"),
            ],
            "zhipuai" => vec![],
            "novita" => vec![
                ("qwen/qwen3-coder-next", "Qwen 3 Coder Next"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }
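
    // Illustrative message sequence this conversion handles (field names are taken
    // from the `super` module; the literal values are hypothetical): an assistant
    // `ToolCall` part is expected to be followed by a `Role::Tool` message whose
    // `ToolResult` carries the same id, so the API can pair the call with its result.
    //
    //     Message { role: Role::Assistant, content: vec![ContentPart::ToolCall {
    //         id: "call_1".into(), name: "read_file".into(),
    //         arguments: r#"{"path":"src/main.rs"}"#.into(),
    //     }] },
    //     Message { role: Role::Tool, content: vec![ContentPart::ToolResult {
    //         tool_call_id: "call_1".into(), content: "fn main() {}".into(),
    //     }] },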

    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }
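
    // Illustrative `ToolDefinition` shape this conversion expects (the tool name,
    // description, and JSON schema below are hypothetical examples, not tools
    // defined in this file):
    //
    //     ToolDefinition {
    //         name: "read_file".to_string(),
    //         description: "Read a UTF-8 file from disk".to_string(),
    //         parameters: serde_json::json!({
    //             "type": "object",
    //             "properties": { "path": { "type": "string" } },
    //             "required": ["path"]
    //         }),
    //     }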

    fn is_minimax_chat_setting_error(error: &str) -> bool {
        let normalized = error.to_ascii_lowercase();
        normalized.contains("invalid chat setting")
            || normalized.contains("(2013)")
            || normalized.contains("code: 2013")
            || normalized.contains("\"2013\"")
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        // For non-OpenAI providers, return provider-specific model defaults.
        // Note: async-openai 0.32 does not expose a stable models list API across
        // all OpenAI-compatible endpoints.
        if self.provider_name != "openai" {
            return Ok(self.provider_default_models());
        }

        // OpenAI default models
        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        // Clone the converted messages so the MiniMax fallback below can reuse them.
        req_builder.model(&request.model).messages(messages.clone());

        // Pass tools to the API if provided
        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(top_p) = request.top_p {
            req_builder.top_p(top_p);
        }
        if let Some(max) = request.max_tokens {
            // OpenAI's chat API prefers `max_completion_tokens` (its reasoning models
            // reject `max_tokens`); most OpenAI-compatible endpoints still expect
            // plain `max_tokens`.
            if self.provider_name == "openai" {
                req_builder.max_completion_tokens(max as u32);
            } else {
                req_builder.max_tokens(max as u32);
            }
        }

        let primary_request = req_builder.build()?;
        let response = match self.client.chat().create(primary_request).await {
            Ok(response) => response,
            Err(err)
                if self.provider_name == "minimax"
                    && Self::is_minimax_chat_setting_error(&err.to_string()) =>
            {
                tracing::warn!(
                    provider = "minimax",
                    error = %err,
                    "MiniMax rejected chat settings; retrying with conservative defaults"
                );

                let mut fallback_builder = CreateChatCompletionRequestArgs::default();
                fallback_builder.model(&request.model).messages(messages);
                self.client.chat().create(fallback_builder.build()?).await?
            }
            Err(err) => return Err(err.into()),
        };

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                    });
                }
            }
        }

        // Determine finish reason based on response
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }
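
    // Illustrative calling pattern (an assumption about how a caller drives this
    // provider, not logic from this file): when `complete` returns
    // `FinishReason::ToolCalls`, the caller executes the requested tools, appends
    // the assistant message plus `Role::Tool` results, and calls `complete` again.
    //
    //     loop {
    //         let resp = provider.complete(req.clone()).await?;
    //         req.messages.push(resp.message.clone());
    //         match resp.finish_reason {
    //             FinishReason::ToolCalls => { /* run tools, push Role::Tool results */ }
    //             _ => break,
    //         }
    //     }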

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

        Ok(stream
            .map(|result| match result {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return StreamChunk::Text(content.clone());
                        }
                    }
                    // Deltas without text content (e.g. role-only or tool-call deltas)
                    // are forwarded as empty text chunks.
                    StreamChunk::Text(String::new())
                }
                Err(e) => StreamChunk::Error(e.to_string()),
            })
            .boxed())
    }
}
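
// Illustrative consumer of `complete_stream` (not part of this file): the returned
// stream is polled with `StreamExt::next`, and only `Text` / `Error` chunks are
// produced by this provider.
//
//     use futures::StreamExt;
//
//     let mut stream = provider.complete_stream(request).await?;
//     while let Some(chunk) = stream.next().await {
//         match chunk {
//             StreamChunk::Text(t) => print!("{t}"),
//             StreamChunk::Error(e) => eprintln!("stream error: {e}"),
//             _ => {}
//         }
//     }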

#[cfg(test)]
mod tests {
    use super::OpenAIProvider;

    #[test]
    fn detects_minimax_chat_setting_error_variants() {
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "bad_request_error: invalid params, invalid chat setting (2013)"
        ));
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "code: 2013 invalid params"
        ));
        assert!(!OpenAIProvider::is_minimax_chat_setting_error(
            "rate limit exceeded"
        ));
    }
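
    // Illustrative additional test (a sketch added here, not taken from the original
    // file): models returned for an OpenAI-compatible provider should carry the
    // provider name passed to `with_base_url`. The key and URL are dummy values.
    #[test]
    fn provider_default_models_use_provider_name() {
        let provider = OpenAIProvider::with_base_url(
            "test-key".to_string(),
            "https://example.invalid/v1".to_string(),
            "cerebras",
        )
        .expect("constructing the provider should not fail");
        let models = provider.provider_default_models();
        assert!(!models.is_empty());
        assert!(models.iter().all(|m| m.provider == "cerebras"));
    }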
}