codetether_agent/provider/openai.rs
use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
    Client,
};
use async_trait::async_trait;
use futures::StreamExt;

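/// Chat-completion provider for the OpenAI API, or for any OpenAI-compatible
/// endpoint when constructed via [`Self::with_base_url`].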
pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
}

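// Manual Debug impl: the inner client (which holds the API key in its config)
// is rendered as an opaque placeholder rather than dumped into logs.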
impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
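    /// Builds a provider that talks to the official OpenAI endpoint with the
    /// supplied API key.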
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
        })
    }

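    /// Builds a provider pointed at a custom OpenAI-compatible base URL
    /// (a proxy or self-hosted server, for example), reporting
    /// `provider_name` from [`Provider::name`].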
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.len(),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
        })
    }

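    /// Maps provider-agnostic [`Message`]s onto the `async_openai` request
    /// message types. Only text parts are forwarded for system and user
    /// messages; assistant tool calls and tool results get their dedicated
    /// message kinds.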
    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            // Collapse all text parts into one string; non-text parts
            // (e.g. images) are not forwarded by this provider.
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    // Re-emit any tool calls recorded on the assistant turn.
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall { id, name, arguments } => Some(
                                ChatCompletionMessageToolCalls::Function(
                                    ChatCompletionMessageToolCall {
                                        id: id.clone(),
                                        function: FunctionCall {
                                            name: name.clone(),
                                            arguments: arguments.clone(),
                                        },
                                    },
                                ),
                            ),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    // Each tool result becomes its own tool message, keyed by
                    // the originating tool_call_id.
                    for part in &msg.content {
                        if let ContentPart::ToolResult { tool_call_id, content } = part {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

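    /// Maps [`ToolDefinition`]s onto OpenAI function-calling tool specs.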
    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

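    /// Returns a curated, hard-coded catalog; the models endpoint is not
    /// queried, so context windows and per-million-token prices are only as
    /// current as this list.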
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

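    /// Runs a single non-streaming chat completion and translates tool calls,
    /// finish reason, and usage counters back into provider-agnostic types.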
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages);

        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(max) = request.max_tokens {
            req_builder.max_completion_tokens(max as u32);
        }

        let response = self.client.chat().create(req_builder.build()?).await?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                    });
                }
            }
        }

        // If the message carries tool calls, report ToolCalls regardless of
        // the finish reason the server sent back.
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        let usage = response.usage.as_ref();
        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: usage.map_or(0, |u| u.prompt_tokens as usize),
                completion_tokens: usage.map_or(0, |u| u.completion_tokens as usize),
                total_tokens: usage.map_or(0, |u| u.total_tokens as usize),
                ..Default::default()
            },
            finish_reason,
        })
    }

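    /// Streams a completion as [`StreamChunk`]s. Unlike [`Self::complete`],
    /// this path forwards only the model, messages, and temperature; tool
    /// definitions and `max_tokens` are not sent.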
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages).stream(true);

        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }

        let stream = self.client.chat().create_stream(req_builder.build()?).await?;

        Ok(stream
            .map(|result| match result {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return StreamChunk::Text(content.clone());
                        }
                    }
                    // Deltas without text content (e.g. role-only or tool-call
                    // fragments) surface as empty text chunks.
                    StreamChunk::Text(String::new())
                }
                Err(e) => StreamChunk::Error(e.to_string()),
            })
            .boxed())
    }
}