codetether_agent/provider/openai.rs

//! OpenAI provider implementation

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;

pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
}

impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
        })
    }
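
    // Usage sketch (the `OPENAI_API_KEY` variable name is an assumption;
    // this module never reads the environment itself):
    //
    //     let provider = OpenAIProvider::new(std::env::var("OPENAI_API_KEY")?)?;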

    /// Create with custom base URL (for OpenAI-compatible providers like Moonshot)
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.len(),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
        })
    }
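
    // Usage sketch for an OpenAI-compatible endpoint; the base URL shown
    // here is illustrative rather than taken from this codebase:
    //
    //     let provider = OpenAIProvider::with_base_url(
    //         api_key,
    //         "https://api.moonshot.cn/v1".to_string(),
    //         "moonshot",
    //     )?;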

    /// Return known models for specific OpenAI-compatible providers
    fn provider_default_models(&self) -> Vec<ModelInfo> {
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3-32b", "Qwen 3 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],
            "minimax" => vec![
                ("MiniMax-M1-80k", "MiniMax M1 80k"),
                ("MiniMax-Text-01", "MiniMax Text 01"),
            ],
            "zhipuai" => vec![
                ("glm-4.7", "GLM-4.7"),
                ("glm-4-plus", "GLM-4 Plus"),
                ("glm-4-flash", "GLM-4 Flash"),
                ("glm-4-long", "GLM-4 Long"),
            ],
            "novita" => vec![
                ("qwen/qwen3-coder-next", "Qwen 3 Coder Next"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

    /// Convert internal messages into async-openai chat request messages.
    ///
    /// Text parts are flattened into a single newline-joined string per
    /// message; assistant tool calls and tool results map onto the
    /// corresponding OpenAI message variants.
    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            // Flatten all text parts; non-text parts are handled per role below.
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    // Each tool result becomes its own tool message, keyed by
                    // the originating tool_call_id.
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }
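
    // Conversion sketch with hypothetical values: a Tool message holding two
    // ToolResult parts fans out into two separate tool messages.
    //
    //     let msg = Message {
    //         role: Role::Tool,
    //         content: vec![
    //             ContentPart::ToolResult { tool_call_id: "call_1".into(), content: "42".into() },
    //             ContentPart::ToolResult { tool_call_id: "call_2".into(), content: "ok".into() },
    //         ],
    //     };
    //     let converted = OpenAIProvider::convert_messages(&[msg])?; // two tool messages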

    /// Convert tool definitions into OpenAI function tools.
    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }
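
    // Payload sketch for a single tool; the ToolDefinition field shape is
    // inferred from its use above, and the schema is illustrative:
    //
    //     let tool = ToolDefinition {
    //         name: "read_file".to_string(),
    //         description: "Read a file from disk".to_string(),
    //         parameters: serde_json::json!({
    //             "type": "object",
    //             "properties": { "path": { "type": "string" } },
    //             "required": ["path"]
    //         }),
    //     };
    //     let tools = OpenAIProvider::convert_tools(&[tool])?;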
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        // For non-OpenAI providers, return provider-specific model defaults.
        // Note: async-openai 0.32 does not expose a stable models list API across
        // all OpenAI-compatible endpoints.
        if self.provider_name != "openai" {
            return Ok(self.provider_default_models());
        }

        // OpenAI default models
        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages);

        // Pass tools to the API if provided
        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(max) = request.max_tokens {
            req_builder.max_completion_tokens(max as u32);
        }

        let response = self.client.chat().create(req_builder.build()?).await?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices in completion response"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                    });
                }
            }
        }

        // Determine finish reason based on response
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }
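
    // Call sketch for the non-streaming path; the agent-side tool loop shown
    // here is an assumption about how callers use finish_reason:
    //
    //     let resp = provider.complete(request).await?;
    //     if matches!(resp.finish_reason, FinishReason::ToolCalls) {
    //         // Run the requested tools, append the results as Role::Tool
    //         // messages, and call complete() again with the extended history.
    //     }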

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;

        // Note: unlike complete(), this streaming path currently does not
        // forward tools or max_tokens to the API.
        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

        Ok(stream
            .map(|result| match result {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return StreamChunk::Text(content.clone());
                        }
                    }
                    // Chunks without delta text (e.g. role-only or finish
                    // frames) are passed through as empty text chunks.
                    StreamChunk::Text(String::new())
                }
                Err(e) => StreamChunk::Error(e.to_string()),
            })
            .boxed())
    }
}
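
// Consumption sketch for the streaming path (variable names illustrative):
//
//     use futures::StreamExt;
//
//     let mut stream = provider.complete_stream(request).await?;
//     while let Some(chunk) = stream.next().await {
//         match chunk {
//             StreamChunk::Text(text) => print!("{text}"),
//             StreamChunk::Error(err) => eprintln!("stream error: {err}"),
//             _ => {} // other StreamChunk variants, if any, are ignored here
//         }
//     }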