codetether_agent/provider/openai.rs

//! OpenAI provider implementation

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;

pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
}

impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
    /// Create a provider that talks to the default OpenAI API endpoint
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
        })
    }

    /// Create with custom base URL (for OpenAI-compatible providers like Moonshot)
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.len(),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
        })
    }
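
    // Illustrative usage of the two constructors above (a sketch, not part of
    // this module's API or tests). The environment variable names and the
    // Moonshot base URL are assumptions for an OpenAI-compatible endpoint:
    //
    //     let openai = OpenAIProvider::new(std::env::var("OPENAI_API_KEY")?)?;
    //     let moonshot = OpenAIProvider::with_base_url(
    //         std::env::var("MOONSHOT_API_KEY")?,
    //         "https://api.moonshot.ai/v1".to_string(),
    //         "moonshot",
    //     )?;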

    /// Return known models for specific OpenAI-compatible providers
    fn provider_default_models(&self) -> Vec<ModelInfo> {
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3-32b", "Qwen 3 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],
            "novita" => vec![
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("qwen/qwen-2.5-72b-instruct", "Qwen 2.5 72B"),
            ],
            "minimax" => vec![
                ("MiniMax-M1-80k", "MiniMax M1 80k"),
                ("MiniMax-Text-01", "MiniMax Text 01"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

    /// Convert internal messages into async-openai chat request messages
    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

    /// Convert internal tool definitions into OpenAI function-calling tools
    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }
}
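
// Sketch of what `convert_tools` consumes, assuming a `ToolDefinition` can be
// built from just `name`, `description`, and JSON Schema `parameters` held as a
// `serde_json::Value` (the example tool and schema below are illustrative only):
//
//     let read_file = ToolDefinition {
//         name: "read_file".to_string(),
//         description: "Read a UTF-8 text file from disk".to_string(),
//         parameters: serde_json::json!({
//             "type": "object",
//             "properties": { "path": { "type": "string" } },
//             "required": ["path"]
//         }),
//     };
//     let tools = OpenAIProvider::convert_tools(&[read_file])?;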

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        // For non-OpenAI providers, return provider-specific model defaults.
        // Note: async-openai 0.32 does not expose a stable models list API across
        // all OpenAI-compatible endpoints.
        if self.provider_name != "openai" {
            return Ok(self.provider_default_models());
        }

        // OpenAI default models
        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages);

        // Pass tools to the API if provided
        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(max) = request.max_tokens {
            req_builder.max_completion_tokens(max as u32);
        }

        let response = self.client.chat().create(req_builder.build()?).await?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                    });
                }
            }
        }

        // Determine finish reason based on response
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

        Ok(stream
            .map(|result| match result {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return StreamChunk::Text(content.clone());
                        }
                    }
                    StreamChunk::Text(String::new())
                }
                Err(e) => StreamChunk::Error(e.to_string()),
            })
            .boxed())
    }
}