codetether_agent/provider/openai.rs

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;

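/// Chat-completions provider backed by the `async_openai` client.
///
/// Besides the official OpenAI API, the same type serves OpenAI-compatible
/// endpoints (e.g. Cerebras, MiniMax, ZhipuAI, Novita) when constructed via
/// `with_base_url`, which only swaps the API base and the reported provider name.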
pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
}

impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
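    /// Builds a provider that talks to the official OpenAI API.
    ///
    /// A minimal usage sketch, assuming the key is read from the environment
    /// (the variable name below is illustrative, not required by this module):
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::new(std::env::var("OPENAI_API_KEY")?)?;
    /// ```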
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
        })
    }

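    /// Builds a provider for an OpenAI-compatible API hosted at `base_url`,
    /// reporting itself under `provider_name`.
    ///
    /// Sketch with a placeholder endpoint (the URL below is hypothetical):
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::with_base_url(
    ///     api_key,
    ///     "https://api.example.com/v1".to_string(),
    ///     "cerebras",
    /// )?;
    /// ```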
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.len(),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
        })
    }

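    /// Static fallback model lists for the OpenAI-compatible providers this
    /// module knows about; unknown providers get an empty list. The context
    /// window and output limits here are fixed defaults, not values reported
    /// by the upstream APIs.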
    fn provider_default_models(&self) -> Vec<ModelInfo> {
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3-32b", "Qwen 3 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],

            "minimax" => vec![
                ("MiniMax-M1-80k", "MiniMax M1 80k"),
                ("MiniMax-Text-01", "MiniMax Text 01"),
            ],
            "zhipuai" => vec![],
            "novita" => vec![
                ("qwen/qwen3-coder-next", "Qwen 3 Coder Next"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

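    /// Maps internal `Message`s onto `async_openai` request messages. Text
    /// parts are joined with newlines, assistant tool calls are forwarded,
    /// and tool results become tool messages; other content parts (e.g.
    /// images) are dropped.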
    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

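    /// Maps internal `ToolDefinition`s onto OpenAI function-calling tools.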
    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

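    /// For the plain "openai" provider this returns a hard-coded catalogue of
    /// known models; for compatible providers it falls back to
    /// `provider_default_models`. Neither path queries the remote API.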
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        if self.provider_name != "openai" {
            return Ok(self.provider_default_models());
        }

        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

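    /// Sends a single non-streaming chat completion. Tool definitions,
    /// temperature, and `max_tokens` (as `max_completion_tokens`) are only
    /// set when present on the request, and only the first choice of the
    /// response is used.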
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages);

        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(max) = request.max_tokens {
            req_builder.max_completion_tokens(max as u32);
        }

        let response = self.client.chat().create(req_builder.build()?).await?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                    });
                }
            }
        }

        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

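    /// Streams a chat completion, mapping each delta's text content to
    /// `StreamChunk::Text` and errors to `StreamChunk::Error`. Note that,
    /// unlike `complete`, this path currently forwards neither tools nor
    /// `max_tokens`.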
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

        Ok(stream
            .map(|result| match result {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return StreamChunk::Text(content.clone());
                        }
                    }
                    StreamChunk::Text(String::new())
                }
                Err(e) => StreamChunk::Error(e.to_string()),
            })
            .boxed())
    }
}