use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;

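/// Chat-completions provider backed by the OpenAI API. The same type serves
/// OpenAI-compatible endpoints (Cerebras, MiniMax, ZhipuAI, Novita) through
/// [`Self::with_base_url`].
///
/// A minimal usage sketch (key handling and error plumbing elided):
///
/// ```ignore
/// let provider = OpenAIProvider::new(std::env::var("OPENAI_API_KEY")?)?;
/// let models = provider.list_models().await?;
/// ```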
pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
}

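// Manual `Debug` impl so debug output shows a placeholder for the client
// rather than its configuration (which holds the API key).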
impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
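    /// Creates a provider for the official OpenAI endpoint. Only the key
    /// length is logged, never the key itself.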
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
        })
    }

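    /// Creates a provider for an OpenAI-compatible endpoint at `base_url`.
    ///
    /// A sketch of wiring up a compatible provider (the URL shown is
    /// illustrative, not normative):
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::with_base_url(
    ///     api_key,
    ///     "https://api.cerebras.ai/v1".to_string(),
    ///     "cerebras",
    /// )?;
    /// ```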
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.len(),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key)
            .with_api_base(base_url);
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
        })
    }

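    /// Compiled-in default model lists for known OpenAI-compatible providers;
    /// unknown providers get an empty list. All entries share conservative
    /// assumptions: 128k context, tool and streaming support, no vision.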
    fn provider_default_models(&self) -> Vec<ModelInfo> {
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3-32b", "Qwen 3 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],
            "minimax" => vec![
                ("MiniMax-M2.5", "MiniMax M2.5"),
                ("MiniMax-M2.5-highspeed", "MiniMax M2.5 Highspeed"),
                ("MiniMax-M2.1", "MiniMax M2.1"),
                ("MiniMax-M2.1-highspeed", "MiniMax M2.1 Highspeed"),
                ("MiniMax-M2", "MiniMax M2"),
            ],
            "zhipuai" => vec![],
            "novita" => vec![
                ("qwen/qwen3-coder-next", "Qwen 3 Coder Next"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

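    /// Maps internal `Message`s onto OpenAI request message types. Text parts
    /// are newline-joined per message; assistant tool calls and tool results
    /// are carried over, and any other content parts are dropped.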
    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            // Join all text parts into one string; tool calls and tool results
            // are handled per-role below, and any other part kinds are dropped.
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                                ..
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    // Each tool result becomes its own tool message, keyed by
                    // the originating tool_call_id.
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

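    /// Maps internal tool definitions onto OpenAI function-calling tools.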
    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }

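    /// Matches the several shapes in which MiniMax reports its 2013
    /// "invalid chat setting" error.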
    fn is_minimax_chat_setting_error(error: &str) -> bool {
        let normalized = error.to_ascii_lowercase();
        normalized.contains("invalid chat setting")
            || normalized.contains("(2013)")
            || normalized.contains("code: 2013")
            || normalized.contains("\"2013\"")
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

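    /// Returns a hand-maintained model catalog. Non-OpenAI providers get
    /// their compiled-in defaults; nothing is fetched from the remote
    /// `/models` endpoint.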
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        if self.provider_name != "openai" {
            return Ok(self.provider_default_models());
        }

        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

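    /// Sends a non-streaming chat completion. A MiniMax request rejected with
    /// an "invalid chat setting" (2013) error is retried once with only the
    /// model and messages.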
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        // Clone `messages` so the MiniMax fallback below can reuse the originals.
        req_builder.model(&request.model).messages(messages.clone());

        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(top_p) = request.top_p {
            req_builder.top_p(top_p);
        }
        if let Some(max) = request.max_tokens {
            // OpenAI has deprecated `max_tokens` in favor of
            // `max_completion_tokens`; compatible providers still expect the
            // old field.
            if self.provider_name == "openai" {
                req_builder.max_completion_tokens(max as u32);
            } else {
                req_builder.max_tokens(max as u32);
            }
        }

        let primary_request = req_builder.build()?;
        let response = match self.client.chat().create(primary_request).await {
            Ok(response) => response,
            Err(err)
                if self.provider_name == "minimax"
                    && Self::is_minimax_chat_setting_error(&err.to_string()) =>
            {
                tracing::warn!(
                    provider = "minimax",
                    error = %err,
                    "MiniMax rejected chat settings; retrying with conservative defaults"
                );

                // Retry with only the model and messages: no sampling
                // parameters, tools, or token limits.
                let mut fallback_builder = CreateChatCompletionRequestArgs::default();
                fallback_builder.model(&request.model).messages(messages);
                self.client.chat().create(fallback_builder.build()?).await?
            }
            Err(err) => return Err(err.into()),
        };

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                        thought_signature: None,
                    });
                }
            }
        }

        // Prefer the presence of tool calls over the reported finish reason.
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

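    /// Streams a completion, emitting text deltas as `StreamChunk::Text`.
    /// Tools are not forwarded and tool-call deltas are not surfaced; stream
    /// errors map to `StreamChunk::Error`.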
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(top_p) = request.top_p {
            req_builder.top_p(top_p);
        }
        if let Some(max) = request.max_tokens {
            // Mirror `complete`: OpenAI expects `max_completion_tokens`,
            // while compatible providers still use `max_tokens`.
            if self.provider_name == "openai" {
                req_builder.max_completion_tokens(max as u32);
            } else {
                req_builder.max_tokens(max as u32);
            }
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

        Ok(stream
            .map(|result| match result {
                Ok(response) => {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            return StreamChunk::Text(content.clone());
                        }
                    }
                    // Deltas without text content surface as empty text chunks.
                    StreamChunk::Text(String::new())
                }
                Err(e) => StreamChunk::Error(e.to_string()),
            })
            .boxed())
    }
}

#[cfg(test)]
mod tests {
    use super::OpenAIProvider;

    #[test]
    fn detects_minimax_chat_setting_error_variants() {
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "bad_request_error: invalid params, invalid chat setting (2013)"
        ));
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "code: 2013 invalid params"
        ));
        assert!(!OpenAIProvider::is_minimax_chat_setting_error(
            "rate limit exceeded"
        ));
    }
}