use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;

/// Chat-completions provider backed by `async_openai`. The same type also serves
/// OpenAI-compatible endpoints (Cerebras, MiniMax, ZhipuAI, Novita) when built via
/// [`OpenAIProvider::with_base_url`].
pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
}

impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
        })
    }

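    /// Builds a provider against any OpenAI-compatible endpoint.
    ///
    /// Illustrative sketch only; the base URL, environment variable, and provider
    /// name below are placeholders, not values taken from this crate's configuration.
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::with_base_url(
    ///     std::env::var("PROVIDER_API_KEY")?,
    ///     "https://api.example.com/v1".to_string(),
    ///     "example-provider",
    /// )?;
    /// ```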
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.len(),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
        })
    }

    fn provider_default_models(&self) -> Vec<ModelInfo> {
        // Hard-coded default model lists per provider; unknown providers get an empty list.
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3-32b", "Qwen 3 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],
            "minimax" => vec![
                ("MiniMax-M2.5", "MiniMax M2.5"),
                ("MiniMax-M2.5-lightning", "MiniMax M2.5 Lightning"),
                ("MiniMax-M2.1", "MiniMax M2.1"),
                ("MiniMax-M2.1-lightning", "MiniMax M2.1 Lightning"),
                ("MiniMax-M2", "MiniMax M2"),
            ],
            "zhipuai" => vec![],
            "novita" => vec![
                ("qwen/qwen3-coder-next", "Qwen 3 Coder Next"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            // Collapse all text parts of the message into a single string.
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    // Each tool result becomes its own tool message tied to the originating call.
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }

    /// Returns true when a MiniMax error message indicates its "invalid chat
    /// setting" rejection (error code 2013), matching several formattings of
    /// that code.
    fn is_minimax_chat_setting_error(error: &str) -> bool {
        let normalized = error.to_ascii_lowercase();
        normalized.contains("invalid chat setting")
            || normalized.contains("(2013)")
            || normalized.contains("code: 2013")
            || normalized.contains("\"2013\"")
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        if self.provider_name != "openai" {
            // OpenAI-compatible providers don't query a live model listing here;
            // return the curated defaults instead.
            return Ok(self.provider_default_models());
        }

        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages.clone());

        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(top_p) = request.top_p {
            req_builder.top_p(top_p);
        }
        if let Some(max) = request.max_tokens {
            // OpenAI has deprecated `max_tokens` in favour of `max_completion_tokens`;
            // other OpenAI-compatible backends still expect the legacy field.
            if self.provider_name == "openai" {
                req_builder.max_completion_tokens(max as u32);
            } else {
                req_builder.max_tokens(max as u32);
            }
        }

        let primary_request = req_builder.build()?;
        let response = match self.client.chat().create(primary_request).await {
            Ok(response) => response,
            Err(err)
                if self.provider_name == "minimax"
                    && Self::is_minimax_chat_setting_error(&err.to_string()) =>
            {
                tracing::warn!(
                    provider = "minimax",
                    error = %err,
                    "MiniMax rejected chat settings; retrying with conservative defaults"
                );

                // Retry with only the model and messages, dropping the tools and
                // sampling parameters the API objected to.
                let mut fallback_builder = CreateChatCompletionRequestArgs::default();
                fallback_builder.model(&request.model).messages(messages);
                self.client.chat().create(fallback_builder.build()?).await?
            }
            Err(err) => return Err(err.into()),
        };

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                    });
                }
            }
        }

        let finish_reason = if has_tool_calls {
            // Tool calls take precedence over whatever finish_reason the API reported.
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

        // Forward only text deltas; chunks without text become empty Text chunks and
        // transport errors are surfaced as StreamChunk::Error.
        Ok(stream
            .map(|result| match result {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return StreamChunk::Text(content.clone());
                        }
                    }
                    StreamChunk::Text(String::new())
                }
                Err(e) => StreamChunk::Error(e.to_string()),
            })
            .boxed())
    }
}

#[cfg(test)]
mod tests {
    use super::OpenAIProvider;

    #[test]
    fn detects_minimax_chat_setting_error_variants() {
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "bad_request_error: invalid params, invalid chat setting (2013)"
        ));
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "code: 2013 invalid params"
        ));
        assert!(!OpenAIProvider::is_minimax_chat_setting_error(
            "rate limit exceeded"
        ));
    }
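
    // The tests below are illustrative additions: they exercise the pure
    // conversion/default-model helpers with made-up inputs and make no network calls.
    #[test]
    fn default_models_are_tagged_with_the_provider_name() {
        let provider = OpenAIProvider::with_base_url(
            "test-key".to_string(),
            "https://api.example.com/v1".to_string(),
            "cerebras",
        )
        .expect("provider should construct");
        let models = provider.provider_default_models();
        assert!(!models.is_empty());
        assert!(models.iter().all(|m| m.provider == "cerebras"));
    }

    #[test]
    fn convert_messages_maps_system_and_user_roles() {
        use super::{ContentPart, Message, Role};

        let messages = vec![
            Message {
                role: Role::System,
                content: vec![ContentPart::Text {
                    text: "You are helpful.".to_string(),
                }],
            },
            Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "Hello".to_string(),
                }],
            },
        ];
        let converted =
            OpenAIProvider::convert_messages(&messages).expect("conversion should succeed");
        assert_eq!(converted.len(), 2);
    }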
}