codetether_agent/provider/openai.rs

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;

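/// Chat-completion provider backed by the `async-openai` client.
///
/// The same type is reused for OpenAI-compatible backends (e.g. Cerebras,
/// Novita, MiniMax) by pointing the client at a different base URL via
/// [`OpenAIProvider::with_base_url`].
///
/// Minimal usage sketch (illustrative only; assumes the key is supplied via
/// an environment variable and that an async runtime is already running):
///
/// ```ignore
/// let provider = OpenAIProvider::new(std::env::var("OPENAI_API_KEY")?)?;
/// let models = provider.list_models().await?;
/// ```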
pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
}

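// Manual `Debug` implementation that prints a placeholder for the inner
// client instead of its full configuration.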
impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
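    /// Creates a provider that talks to the official OpenAI API using the
    /// default `async-openai` configuration.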
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
        })
    }

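    /// Creates a provider for an OpenAI-compatible backend by overriding the
    /// API base URL; `provider_name` is reported by [`Provider::name`] and
    /// selects the default model list.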
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.len(),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
        })
    }

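    /// Hard-coded model catalog for known OpenAI-compatible providers;
    /// unknown providers get an empty list. The context window and output
    /// limits are uniform placeholders rather than per-model values.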
    fn provider_default_models(&self) -> Vec<ModelInfo> {
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3-32b", "Qwen 3 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],
            "novita" => vec![
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("qwen/qwen-2.5-72b-instruct", "Qwen 2.5 72B"),
            ],
            "minimax" => vec![
                ("MiniMax-M1-80k", "MiniMax M1 80k"),
                ("MiniMax-Text-01", "MiniMax Text 01"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

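    /// Converts internal messages into `async-openai` request messages.
    /// Text parts are joined with newlines; assistant tool calls and tool
    /// results are mapped onto their OpenAI counterparts, while any other
    /// content parts are dropped.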
    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

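    /// Converts internal tool definitions into OpenAI function-calling tools.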
    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }
}

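// `Provider` implementation: model discovery, one-shot completions, and
// streaming completions.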
#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

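    /// Returns a static model list: a small hard-coded OpenAI catalog, or the
    /// per-provider defaults for OpenAI-compatible backends.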
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        if self.provider_name != "openai" {
            return Ok(self.provider_default_models());
        }

        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

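    /// Sends a non-streaming chat completion. Tool calls in the response take
    /// precedence when mapping the finish reason, so callers can dispatch
    /// tools even if the backend reports `stop`.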
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages);

        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(max) = request.max_tokens {
            req_builder.max_completion_tokens(max as u32);
        }

        let response = self.client.chat().create(req_builder.build()?).await?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                    });
                }
            }
        }

        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

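    /// Streams a chat completion as `StreamChunk::Text` deltas; errors are
    /// surfaced in-band as `StreamChunk::Error`. Note that tools and
    /// `max_tokens` are not forwarded on the streaming path.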
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

        Ok(stream
            .map(|result| match result {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return StreamChunk::Text(content.clone());
                        }
                    }
                    StreamChunk::Text(String::new())
                }
                Err(e) => StreamChunk::Error(e.to_string()),
            })
            .boxed())
    }
}