Skip to main content

codetether_agent/provider/
openai.rs

1//! OpenAI provider implementation
2
3use super::{
4    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5    Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::Result;
8use async_openai::{
9    config::OpenAIConfig,
10    types::chat::{
11        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs,
12        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
13        ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
14        ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequestArgs,
15        FinishReason as OpenAIFinishReason, FunctionCall, FunctionObjectArgs,
16    },
17    Client,
18};
19use async_trait::async_trait;
20use futures::StreamExt;
21
22pub struct OpenAIProvider {
23    client: Client<OpenAIConfig>,
24    provider_name: String,
25}
26
27impl std::fmt::Debug for OpenAIProvider {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("OpenAIProvider")
30            .field("provider_name", &self.provider_name)
31            .field("client", &"<async_openai::Client>")
32            .finish()
33    }
34}
35
36impl OpenAIProvider {
37    pub fn new(api_key: String) -> Result<Self> {
38        tracing::debug!(provider = "openai", api_key_len = api_key.len(), "Creating OpenAI provider");
39        let config = OpenAIConfig::new().with_api_key(api_key);
40        Ok(Self {
41            client: Client::with_config(config),
42            provider_name: "openai".to_string(),
43        })
44    }
45
46    /// Create with custom base URL (for OpenAI-compatible providers like Moonshot)
47    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
48        tracing::debug!(
49            provider = provider_name,
50            base_url = %base_url,
51            api_key_len = api_key.len(),
52            "Creating OpenAI-compatible provider"
53        );
54        let config = OpenAIConfig::new()
55            .with_api_key(api_key)
56            .with_api_base(base_url);
57        Ok(Self {
58            client: Client::with_config(config),
59            provider_name: provider_name.to_string(),
60        })
61    }
62
63    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
64        let mut result = Vec::new();
65        
66        for msg in messages {
67            let content = msg
68                .content
69                .iter()
70                .filter_map(|p| match p {
71                    ContentPart::Text { text } => Some(text.clone()),
72                    _ => None,
73                })
74                .collect::<Vec<_>>()
75                .join("\n");
76
77            match msg.role {
78                Role::System => {
79                    result.push(
80                        ChatCompletionRequestSystemMessageArgs::default()
81                            .content(content)
82                            .build()?
83                            .into(),
84                    );
85                }
86                Role::User => {
87                    result.push(
88                        ChatCompletionRequestUserMessageArgs::default()
89                            .content(content)
90                            .build()?
91                            .into(),
92                    );
93                }
94                Role::Assistant => {
95                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
96                        .content
97                        .iter()
98                        .filter_map(|p| match p {
99                            ContentPart::ToolCall { id, name, arguments } => {
100                                Some(ChatCompletionMessageToolCalls::Function(ChatCompletionMessageToolCall {
101                                    id: id.clone(),
102                                    function: FunctionCall {
103                                        name: name.clone(),
104                                        arguments: arguments.clone(),
105                                    },
106                                }))
107                            }
108                            _ => None,
109                        })
110                        .collect();
111
112                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
113                    if !content.is_empty() {
114                        builder.content(content);
115                    }
116                    if !tool_calls.is_empty() {
117                        builder.tool_calls(tool_calls);
118                    }
119                    result.push(builder.build()?.into());
120                }
121                Role::Tool => {
122                    for part in &msg.content {
123                        if let ContentPart::ToolResult { tool_call_id, content } = part {
124                            result.push(
125                                ChatCompletionRequestToolMessageArgs::default()
126                                    .tool_call_id(tool_call_id.clone())
127                                    .content(content.clone())
128                                    .build()?
129                                    .into(),
130                            );
131                        }
132                    }
133                }
134            }
135        }
136        
137        Ok(result)
138    }
139
140    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
141        let mut result = Vec::new();
142        for tool in tools {
143            result.push(ChatCompletionTools::Function(ChatCompletionTool {
144                function: FunctionObjectArgs::default()
145                    .name(&tool.name)
146                    .description(&tool.description)
147                    .parameters(tool.parameters.clone())
148                    .build()?,
149            }));
150        }
151        Ok(result)
152    }
153}
154
155#[async_trait]
156impl Provider for OpenAIProvider {
157    fn name(&self) -> &str {
158        &self.provider_name
159    }
160
161    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
162        // Return commonly used models
163        Ok(vec![
164            ModelInfo {
165                id: "gpt-4o".to_string(),
166                name: "GPT-4o".to_string(),
167                provider: "openai".to_string(),
168                context_window: 128_000,
169                max_output_tokens: Some(16_384),
170                supports_vision: true,
171                supports_tools: true,
172                supports_streaming: true,
173                input_cost_per_million: Some(2.5),
174                output_cost_per_million: Some(10.0),
175            },
176            ModelInfo {
177                id: "gpt-4o-mini".to_string(),
178                name: "GPT-4o Mini".to_string(),
179                provider: "openai".to_string(),
180                context_window: 128_000,
181                max_output_tokens: Some(16_384),
182                supports_vision: true,
183                supports_tools: true,
184                supports_streaming: true,
185                input_cost_per_million: Some(0.15),
186                output_cost_per_million: Some(0.6),
187            },
188            ModelInfo {
189                id: "o1".to_string(),
190                name: "o1".to_string(),
191                provider: "openai".to_string(),
192                context_window: 200_000,
193                max_output_tokens: Some(100_000),
194                supports_vision: true,
195                supports_tools: true,
196                supports_streaming: true,
197                input_cost_per_million: Some(15.0),
198                output_cost_per_million: Some(60.0),
199            },
200        ])
201    }
202
203    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
204        let messages = Self::convert_messages(&request.messages)?;
205        let tools = Self::convert_tools(&request.tools)?;
206
207        let mut req_builder = CreateChatCompletionRequestArgs::default();
208        req_builder.model(&request.model).messages(messages);
209
210        // Pass tools to the API if provided
211        if !tools.is_empty() {
212            req_builder.tools(tools);
213        }
214        if let Some(temp) = request.temperature {
215            req_builder.temperature(temp);
216        }
217        if let Some(max) = request.max_tokens {
218            req_builder.max_completion_tokens(max as u32);
219        }
220
221        let response = self.client.chat().create(req_builder.build()?).await?;
222
223        let choice = response.choices.first().ok_or_else(|| anyhow::anyhow!("No choices"))?;
224
225        let mut content = Vec::new();
226        let mut has_tool_calls = false;
227        
228        if let Some(text) = &choice.message.content {
229            content.push(ContentPart::Text { text: text.clone() });
230        }
231        if let Some(tool_calls) = &choice.message.tool_calls {
232            has_tool_calls = !tool_calls.is_empty();
233            for tc in tool_calls {
234                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
235                    content.push(ContentPart::ToolCall {
236                        id: func_call.id.clone(),
237                        name: func_call.function.name.clone(),
238                        arguments: func_call.function.arguments.clone(),
239                    });
240                }
241            }
242        }
243
244        // Determine finish reason based on response
245        let finish_reason = if has_tool_calls {
246            FinishReason::ToolCalls
247        } else {
248            match choice.finish_reason {
249                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
250                Some(OpenAIFinishReason::Length) => FinishReason::Length,
251                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
252                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
253                _ => FinishReason::Stop,
254            }
255        };
256
257        Ok(CompletionResponse {
258            message: Message {
259                role: Role::Assistant,
260                content,
261            },
262            usage: Usage {
263                prompt_tokens: response.usage.as_ref().map(|u| u.prompt_tokens as usize).unwrap_or(0),
264                completion_tokens: response.usage.as_ref().map(|u| u.completion_tokens as usize).unwrap_or(0),
265                total_tokens: response.usage.as_ref().map(|u| u.total_tokens as usize).unwrap_or(0),
266                ..Default::default()
267            },
268            finish_reason,
269        })
270    }
271
272    async fn complete_stream(
273        &self,
274        request: CompletionRequest,
275    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
276        tracing::debug!(
277            provider = %self.provider_name,
278            model = %request.model,
279            message_count = request.messages.len(),
280            "Starting streaming completion request"
281        );
282        
283        let messages = Self::convert_messages(&request.messages)?;
284
285        let mut req_builder = CreateChatCompletionRequestArgs::default();
286        req_builder.model(&request.model).messages(messages).stream(true);
287
288        if let Some(temp) = request.temperature {
289            req_builder.temperature(temp);
290        }
291
292        let stream = self.client.chat().create_stream(req_builder.build()?).await?;
293
294        Ok(stream
295            .map(|result| match result {
296                Ok(response) => {
297                    if let Some(choice) = response.choices.first() {
298                        if let Some(content) = &choice.delta.content {
299                            return StreamChunk::Text(content.clone());
300                        }
301                    }
302                    StreamChunk::Text(String::new())
303                }
304                Err(e) => StreamChunk::Error(e.to_string()),
305            })
306            .boxed())
307    }
308}