codetether_agent/provider/openai.rs

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;

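/// Provider for the OpenAI API and OpenAI-compatible endpoints, backed by
/// the `async_openai` client.
///
/// A minimal usage sketch (module paths are inferred from this file's
/// location and may need adjusting to the actual crate layout):
///
/// ```no_run
/// # async fn demo() -> anyhow::Result<()> {
/// use codetether_agent::provider::Provider;
/// use codetether_agent::provider::openai::OpenAIProvider;
///
/// let provider = OpenAIProvider::new(std::env::var("OPENAI_API_KEY")?)?;
/// let models = provider.list_models().await?;
/// # Ok(())
/// # }
/// ```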
pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
}

impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
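    /// Creates a provider that talks to the official OpenAI API. Only the
    /// key length is logged, never the key itself.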
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
        })
    }

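    /// Creates a provider against an OpenAI-compatible endpoint (e.g. a
    /// local or third-party server) by overriding the API base URL.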
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.len(),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
        })
    }

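    /// Maps internal `Message` values onto the typed request messages that
    /// `async_openai` expects, one conversion arm per role.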
    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
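            // Join all text parts into one newline-separated string; tool
            // calls and tool results are handled per-role below.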
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

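    /// Maps internal tool definitions onto OpenAI function-calling tools.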
    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

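    /// Returns a static, hardcoded catalog of known OpenAI models; the cost
    /// fields are per million tokens.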
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

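    /// Sends a non-streaming chat completion and converts the first choice
    /// back into the provider-agnostic response type.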
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages);

        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
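        // `max_completion_tokens` is the current OpenAI parameter; the older
        // `max_tokens` field is deprecated for chat completions.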
        if let Some(max) = request.max_tokens {
            req_builder.max_completion_tokens(max as u32);
        }

        let response = self.client.chat().create(req_builder.build()?).await?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                    });
                }
            }
        }

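        // Prefer the tool-call signal over the reported finish reason; some
        // OpenAI-compatible backends may report `stop` even when tool calls
        // are present.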
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

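    /// Streams a chat completion as text chunks. Note that, unlike
    /// `complete`, this path does not forward tool definitions or
    /// `max_tokens`.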
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

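        // Map each streamed event to a StreamChunk; events without a text
        // delta (role-only or terminal frames) become empty text chunks.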
        Ok(stream
            .map(|result| match result {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return StreamChunk::Text(content.clone());
                        }
                    }
                    StreamChunk::Text(String::new())
                }
                Err(e) => StreamChunk::Error(e.to_string()),
            })
            .boxed())
    }
}