Skip to main content

codetether_agent/provider/
openai.rs

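//! OpenAI provider implementation.
//!
//! The same provider also serves OpenAI-compatible backends (e.g. Moonshot, MiniMax,
//! Cerebras, Novita) through a custom base URL.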
2
3use super::{
4    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5    Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::Result;
8use async_openai::{
9    Client,
10    config::OpenAIConfig,
11    types::chat::{
12        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
13        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
14        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
15        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
16        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
17        FunctionObjectArgs,
18    },
19};
20use async_trait::async_trait;
21use futures::StreamExt;
22
23pub struct OpenAIProvider {
24    client: Client<OpenAIConfig>,
25    provider_name: String,
26}
27
28impl std::fmt::Debug for OpenAIProvider {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("OpenAIProvider")
31            .field("provider_name", &self.provider_name)
32            .field("client", &"<async_openai::Client>")
33            .finish()
34    }
35}
36
37impl OpenAIProvider {
38    pub fn new(api_key: String) -> Result<Self> {
39        tracing::debug!(
40            provider = "openai",
41            api_key_len = api_key.len(),
42            "Creating OpenAI provider"
43        );
44        let config = OpenAIConfig::new().with_api_key(api_key);
45        Ok(Self {
46            client: Client::with_config(config),
47            provider_name: "openai".to_string(),
48        })
49    }
50
51    /// Create with custom base URL (for OpenAI-compatible providers like Moonshot)
52    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
53        tracing::debug!(
54            provider = provider_name,
55            base_url = %base_url,
56            api_key_len = api_key.len(),
57            "Creating OpenAI-compatible provider"
58        );
59        let config = OpenAIConfig::new()
60            .with_api_key(api_key)
61            .with_api_base(base_url);
62        Ok(Self {
63            client: Client::with_config(config),
64            provider_name: provider_name.to_string(),
65        })
66    }
67
68    /// Return known models for specific OpenAI-compatible providers
69    fn provider_default_models(&self) -> Vec<ModelInfo> {
70        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
71            "cerebras" => vec![
72                ("llama3.1-8b", "Llama 3.1 8B"),
73                ("llama-3.3-70b", "Llama 3.3 70B"),
74                ("qwen-3-32b", "Qwen 3 32B"),
75                ("gpt-oss-120b", "GPT-OSS 120B"),
76            ],
77
78            "minimax" => vec![
79                ("MiniMax-M2.5", "MiniMax M2.5"),
80                ("MiniMax-M2.5-highspeed", "MiniMax M2.5 Highspeed"),
81                ("MiniMax-M2.1", "MiniMax M2.1"),
82                ("MiniMax-M2.1-highspeed", "MiniMax M2.1 Highspeed"),
83                ("MiniMax-M2", "MiniMax M2"),
84            ],
85            "zhipuai" => vec![],
86            "novita" => vec![
87                ("qwen/qwen3-coder-next", "Qwen 3 Coder Next"),
88                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
89                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
90                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
91            ],
92            _ => vec![],
93        };
94
95        models
96            .into_iter()
97            .map(|(id, name)| ModelInfo {
98                id: id.to_string(),
99                name: name.to_string(),
100                provider: self.provider_name.clone(),
101                context_window: 128_000,
102                max_output_tokens: Some(16_384),
103                supports_vision: false,
104                supports_tools: true,
105                supports_streaming: true,
106                input_cost_per_million: None,
107                output_cost_per_million: None,
108            })
109            .collect()
110    }
111
112    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
113        let mut result = Vec::new();
114
115        for msg in messages {
116            let content = msg
117                .content
118                .iter()
119                .filter_map(|p| match p {
120                    ContentPart::Text { text } => Some(text.clone()),
121                    _ => None,
122                })
123                .collect::<Vec<_>>()
124                .join("\n");
125
126            match msg.role {
127                Role::System => {
128                    result.push(
129                        ChatCompletionRequestSystemMessageArgs::default()
130                            .content(content)
131                            .build()?
132                            .into(),
133                    );
134                }
135                Role::User => {
136                    result.push(
137                        ChatCompletionRequestUserMessageArgs::default()
138                            .content(content)
139                            .build()?
140                            .into(),
141                    );
142                }
143                Role::Assistant => {
144                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
145                        .content
146                        .iter()
147                        .filter_map(|p| match p {
148                            ContentPart::ToolCall {
149                                id,
150                                name,
151                                arguments,
152                                ..
153                            } => Some(ChatCompletionMessageToolCalls::Function(
154                                ChatCompletionMessageToolCall {
155                                    id: id.clone(),
156                                    function: FunctionCall {
157                                        name: name.clone(),
158                                        arguments: arguments.clone(),
159                                    },
160                                },
161                            )),
162                            _ => None,
163                        })
164                        .collect();
165
166                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
167                    if !content.is_empty() {
168                        builder.content(content);
169                    }
170                    if !tool_calls.is_empty() {
171                        builder.tool_calls(tool_calls);
172                    }
173                    result.push(builder.build()?.into());
174                }
175                Role::Tool => {
176                    for part in &msg.content {
177                        if let ContentPart::ToolResult {
178                            tool_call_id,
179                            content,
180                        } = part
181                        {
182                            result.push(
183                                ChatCompletionRequestToolMessageArgs::default()
184                                    .tool_call_id(tool_call_id.clone())
185                                    .content(content.clone())
186                                    .build()?
187                                    .into(),
188                            );
189                        }
190                    }
191                }
192            }
193        }
194
195        Ok(result)
196    }
197
198    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
199        let mut result = Vec::new();
200        for tool in tools {
201            result.push(ChatCompletionTools::Function(ChatCompletionTool {
202                function: FunctionObjectArgs::default()
203                    .name(&tool.name)
204                    .description(&tool.description)
205                    .parameters(tool.parameters.clone())
206                    .build()?,
207            }));
208        }
209        Ok(result)
210    }
211
212    fn is_minimax_chat_setting_error(error: &str) -> bool {
213        let normalized = error.to_ascii_lowercase();
214        normalized.contains("invalid chat setting")
215            || normalized.contains("(2013)")
216            || normalized.contains("code: 2013")
217            || normalized.contains("\"2013\"")
218    }
219}
220
221#[async_trait]
222impl Provider for OpenAIProvider {
223    fn name(&self) -> &str {
224        &self.provider_name
225    }
226
227    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
228        // For non-OpenAI providers, return provider-specific model defaults.
229        // Note: async-openai 0.32 does not expose a stable models list API across
230        // all OpenAI-compatible endpoints.
231        if self.provider_name != "openai" {
232            return Ok(self.provider_default_models());
233        }
234
235        // OpenAI default models
236        Ok(vec![
237            ModelInfo {
238                id: "gpt-4o".to_string(),
239                name: "GPT-4o".to_string(),
240                provider: "openai".to_string(),
241                context_window: 128_000,
242                max_output_tokens: Some(16_384),
243                supports_vision: true,
244                supports_tools: true,
245                supports_streaming: true,
246                input_cost_per_million: Some(2.5),
247                output_cost_per_million: Some(10.0),
248            },
249            ModelInfo {
250                id: "gpt-4o-mini".to_string(),
251                name: "GPT-4o Mini".to_string(),
252                provider: "openai".to_string(),
253                context_window: 128_000,
254                max_output_tokens: Some(16_384),
255                supports_vision: true,
256                supports_tools: true,
257                supports_streaming: true,
258                input_cost_per_million: Some(0.15),
259                output_cost_per_million: Some(0.6),
260            },
261            ModelInfo {
262                id: "o1".to_string(),
263                name: "o1".to_string(),
264                provider: "openai".to_string(),
265                context_window: 200_000,
266                max_output_tokens: Some(100_000),
267                supports_vision: true,
268                supports_tools: true,
269                supports_streaming: true,
270                input_cost_per_million: Some(15.0),
271                output_cost_per_million: Some(60.0),
272            },
273        ])
274    }
275
276    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
277        let messages = Self::convert_messages(&request.messages)?;
278        let tools = Self::convert_tools(&request.tools)?;
279
280        let mut req_builder = CreateChatCompletionRequestArgs::default();
281        req_builder.model(&request.model).messages(messages.clone());
282
283        // Pass tools to the API if provided
284        if !tools.is_empty() {
285            req_builder.tools(tools);
286        }
287        if let Some(temp) = request.temperature {
288            req_builder.temperature(temp);
289        }
290        if let Some(top_p) = request.top_p {
291            req_builder.top_p(top_p);
292        }
293        if let Some(max) = request.max_tokens {
294            if self.provider_name == "openai" {
295                req_builder.max_completion_tokens(max as u32);
296            } else {
297                req_builder.max_tokens(max as u32);
298            }
299        }
300
301        let primary_request = req_builder.build()?;
302        let response = match self.client.chat().create(primary_request).await {
303            Ok(response) => response,
304            Err(err)
305                if self.provider_name == "minimax"
306                    && Self::is_minimax_chat_setting_error(&err.to_string()) =>
307            {
308                tracing::warn!(
309                    provider = "minimax",
310                    error = %err,
311                    "MiniMax rejected chat settings; retrying with conservative defaults"
312                );
313
314                let mut fallback_builder = CreateChatCompletionRequestArgs::default();
315                fallback_builder.model(&request.model).messages(messages);
316                self.client.chat().create(fallback_builder.build()?).await?
317            }
318            Err(err) => return Err(err.into()),
319        };
320
321        let choice = response
322            .choices
323            .first()
324            .ok_or_else(|| anyhow::anyhow!("No choices"))?;
325
326        let mut content = Vec::new();
327        let mut has_tool_calls = false;
328
329        if let Some(text) = &choice.message.content {
330            content.push(ContentPart::Text { text: text.clone() });
331        }
332        if let Some(tool_calls) = &choice.message.tool_calls {
333            has_tool_calls = !tool_calls.is_empty();
334            for tc in tool_calls {
335                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
336                    content.push(ContentPart::ToolCall {
337                        id: func_call.id.clone(),
338                        name: func_call.function.name.clone(),
339                        arguments: func_call.function.arguments.clone(),
340                        thought_signature: None,
341                    });
342                }
343            }
344        }
345
346        // Determine finish reason based on response
347        let finish_reason = if has_tool_calls {
348            FinishReason::ToolCalls
349        } else {
350            match choice.finish_reason {
351                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
352                Some(OpenAIFinishReason::Length) => FinishReason::Length,
353                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
354                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
355                _ => FinishReason::Stop,
356            }
357        };
358
359        Ok(CompletionResponse {
360            message: Message {
361                role: Role::Assistant,
362                content,
363            },
364            usage: Usage {
365                prompt_tokens: response
366                    .usage
367                    .as_ref()
368                    .map(|u| u.prompt_tokens as usize)
369                    .unwrap_or(0),
370                completion_tokens: response
371                    .usage
372                    .as_ref()
373                    .map(|u| u.completion_tokens as usize)
374                    .unwrap_or(0),
375                total_tokens: response
376                    .usage
377                    .as_ref()
378                    .map(|u| u.total_tokens as usize)
379                    .unwrap_or(0),
380                ..Default::default()
381            },
382            finish_reason,
383        })
384    }
385
386    async fn complete_stream(
387        &self,
388        request: CompletionRequest,
389    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
390        tracing::debug!(
391            provider = %self.provider_name,
392            model = %request.model,
393            message_count = request.messages.len(),
394            "Starting streaming completion request"
395        );
396
397        let messages = Self::convert_messages(&request.messages)?;
398
399        let mut req_builder = CreateChatCompletionRequestArgs::default();
400        req_builder
401            .model(&request.model)
402            .messages(messages)
403            .stream(true);
404
405        if let Some(temp) = request.temperature {
406            req_builder.temperature(temp);
407        }
408
409        let stream = self
410            .client
411            .chat()
412            .create_stream(req_builder.build()?)
413            .await?;
414
415        Ok(stream
416            .map(|result| match result {
417                Ok(response) => {
418                    if let Some(choice) = response.choices.first() {
419                        if let Some(content) = &choice.delta.content {
420                            return StreamChunk::Text(content.clone());
421                        }
422                    }
423                    StreamChunk::Text(String::new())
424                }
425                Err(e) => StreamChunk::Error(e.to_string()),
426            })
427            .boxed())
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::OpenAIProvider;

    #[test]
    fn detects_minimax_chat_setting_error_variants() {
        // Message text together with the parenthesised code.
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "bad_request_error: invalid params, invalid chat setting (2013)"
        ));
        // `code: N` form.
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "code: 2013 invalid params"
        ));
        // JSON-quoted code, as found in a raw error body.
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            r#"{"status_code":"2013","status_msg":"invalid params"}"#
        ));
        // Matching is case-insensitive.
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "Invalid Chat Setting"
        ));
    }

    #[test]
    fn ignores_unrelated_errors() {
        assert!(!OpenAIProvider::is_minimax_chat_setting_error(
            "rate limit exceeded"
        ));
        // A bare number without one of the recognised markers is not enough.
        assert!(!OpenAIProvider::is_minimax_chat_setting_error(
            "request 2013 timed out"
        ));
    }
}