codetether_agent/provider/openai.rs

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;

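/// Chat-completions provider backed by the OpenAI API, also reused for
/// OpenAI-compatible endpoints (Cerebras, MiniMax, ZhipuAI, Novita) via
/// [`OpenAIProvider::with_base_url`].
///
/// Illustrative construction only; the API keys and base URL below are
/// placeholder values, not part of this crate:
///
/// ```ignore
/// let openai = OpenAIProvider::new("sk-...".to_string())?;
/// let cerebras = OpenAIProvider::with_base_url(
///     "csk-...".to_string(),
///     "https://api.cerebras.ai/v1".to_string(),
///     "cerebras",
/// )?;
/// ```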
pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
}

impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
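    /// Creates a provider that talks to the default OpenAI API endpoint.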
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
        })
    }

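    /// Creates a provider for an OpenAI-compatible endpoint by overriding the
    /// API base URL; `provider_name` is used for logging and model listings.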
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.len(),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
        })
    }

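    /// Returns a hard-coded model list for OpenAI-compatible providers that do
    /// not expose a usable model-listing endpoint.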
    fn provider_default_models(&self) -> Vec<ModelInfo> {
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3-32b", "Qwen 3 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],
            "minimax" => vec![
                ("MiniMax-M1-80k", "MiniMax M1 80k"),
                ("MiniMax-Text-01", "MiniMax Text 01"),
            ],
            "zhipuai" => vec![
                ("glm-4.7", "GLM-4.7"),
                ("glm-4-plus", "GLM-4 Plus"),
                ("glm-4-flash", "GLM-4 Flash"),
                ("glm-4-long", "GLM-4 Long"),
            ],
            "novita" => vec![
                ("qwen/qwen3-coder-next", "Qwen 3 Coder Next"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

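    /// Converts provider-agnostic [`Message`]s into `async_openai` chat request
    /// messages, carrying assistant tool calls and tool results through.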
    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

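    /// Converts [`ToolDefinition`]s into `async_openai` function tools.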
    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

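    // OpenAI-compatible providers get the hard-coded defaults; the OpenAI
    // provider itself returns a curated list with pricing metadata.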
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        if self.provider_name != "openai" {
            return Ok(self.provider_default_models());
        }

        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

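    // Non-streaming completion: converts messages and tools, issues a single
    // chat request, and maps the first choice back into a CompletionResponse.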
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages);

        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(max) = request.max_tokens {
            req_builder.max_completion_tokens(max as u32);
        }

        let response = self.client.chat().create(req_builder.build()?).await?;

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                    });
                }
            }
        }

        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

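    // Streaming completion: only text deltas are forwarded as StreamChunk::Text;
    // tool calls and usage are not surfaced on this path, and stream errors are
    // mapped to StreamChunk::Error.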
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

        Ok(stream
            .map(|result| match result {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return StreamChunk::Text(content.clone());
                        }
                    }
                    StreamChunk::Text(String::new())
                }
                Err(e) => StreamChunk::Error(e.to_string()),
            })
            .boxed())
    }
}