use super::{
    CompletionRequest, CompletionResponse, ContentPart, EmbeddingRequest, EmbeddingResponse,
    FinishReason, Message, ModelInfo, Provider, Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::Client as HttpClient;
use serde_json::Value;

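/// Provider backed by the OpenAI API, also reused for OpenAI-compatible
/// endpoints (Cerebras, MiniMax, ZhipuAI, Novita, local servers, ...).
///
/// A minimal usage sketch (marked `ignore`; assumes the surrounding crate's
/// `Provider` trait is in scope and an async, fallible context):
///
/// ```ignore
/// let provider = OpenAIProvider::new(std::env::var("OPENAI_API_KEY")?)?;
/// let models = provider.list_models().await?;
/// ```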
pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
    api_key: Option<String>,
    api_base: String,
    http: HttpClient,
}

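// Manual `Debug` impl so the API key never appears in logs or panic output.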
impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("api_base", &self.api_base)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
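    /// Creates a provider for the hosted OpenAI API at `https://api.openai.com/v1`.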
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key.clone());
        let api_base = "https://api.openai.com/v1".to_string();
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
            api_key: Some(api_key),
            api_base,
            http: HttpClient::builder()
                .timeout(std::time::Duration::from_secs(45))
                .build()?,
        })
    }

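    /// Creates a provider for an OpenAI-compatible endpoint that requires an API key.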
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        Self::with_base_url_optional_key(Some(api_key), base_url, provider_name)
    }

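    /// Creates a provider for an OpenAI-compatible endpoint. The key is optional
    /// so unauthenticated local servers work; blank or whitespace-only keys are
    /// treated as absent.
    ///
    /// A minimal usage sketch (marked `ignore`; mirrors the unit test below):
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::with_base_url_optional_key(
    ///     None,
    ///     "http://localhost:8080/v1".to_string(),
    ///     "huggingface",
    /// )?;
    /// assert_eq!(provider.name(), "huggingface");
    /// ```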
    pub fn with_base_url_optional_key(
        api_key: Option<String>,
        base_url: String,
        provider_name: &str,
    ) -> Result<Self> {
        let api_key = api_key.filter(|key| !key.trim().is_empty());
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.as_ref().map(|key| key.len()).unwrap_or(0),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key.clone().unwrap_or_default())
            .with_api_base(base_url.clone());
        let api_base = base_url.trim_end_matches('/').to_string();
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
            api_key,
            api_base,
            http: HttpClient::builder()
                .timeout(std::time::Duration::from_secs(45))
                .build()?,
        })
    }

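    /// Static fallback model lists for known providers, used when `/models`
    /// discovery returns nothing.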
    fn provider_default_models(&self) -> Vec<ModelInfo> {
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3.5-32b", "Qwen 3.5 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],
            "minimax" => vec![
                ("MiniMax-M2.5", "MiniMax M2.5"),
                ("MiniMax-M2.5-highspeed", "MiniMax M2.5 Highspeed"),
                ("MiniMax-M2.1", "MiniMax M2.1"),
                ("MiniMax-M2.1-highspeed", "MiniMax M2.1 Highspeed"),
                ("MiniMax-M2", "MiniMax M2"),
            ],
            "zhipuai" => vec![],
            "novita" => vec![
                ("Qwen/Qwen3.5-35B-A3B", "Qwen 3.5 35B A3B"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

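    /// Queries the provider's `GET /models` endpoint. Any network, status, or
    /// parse failure is logged at debug level and yields an empty list so the
    /// caller can fall back to static defaults.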
    async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
        let url = format!("{}/models", self.api_base);
        let mut request = self.http.get(&url);
        if let Some(api_key) = &self.api_key {
            request = request.bearer_auth(api_key);
        }

        let response = match request.send().await {
            Ok(response) => response,
            Err(error) => {
                tracing::debug!(
                    provider = %self.provider_name,
                    url = %url,
                    error = %error,
                    "Failed to fetch OpenAI-compatible /models endpoint"
                );
                return Vec::new();
            }
        };

        let status = response.status();
        if !status.is_success() {
            tracing::debug!(
                provider = %self.provider_name,
                url = %url,
                status = %status,
                "OpenAI-compatible /models endpoint returned non-success"
            );
            return Vec::new();
        }

        let payload: Value = match response.json().await {
            Ok(payload) => payload,
            Err(error) => {
                tracing::debug!(
                    provider = %self.provider_name,
                    url = %url,
                    error = %error,
                    "Failed to parse OpenAI-compatible /models response"
                );
                return Vec::new();
            }
        };

        let models = Self::parse_models_payload(&payload, &self.provider_name);
        if models.is_empty() {
            tracing::debug!(
                provider = %self.provider_name,
                url = %url,
                "OpenAI-compatible /models payload did not contain any model ids"
            );
        }
        models
    }

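    /// Extracts model entries from the `data` array of a `/models` response.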
    fn parse_models_payload(payload: &Value, provider_name: &str) -> Vec<ModelInfo> {
        payload
            .get("data")
            .and_then(Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(|entry| Self::model_info_from_api_entry(entry, provider_name))
            .collect()
    }

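    /// Builds a `ModelInfo` from a single `/models` entry, which may be either
    /// a bare id string or an object carrying optional metadata (display name,
    /// limits, modalities, pricing). Missing fields get permissive defaults.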
    fn model_info_from_api_entry(entry: &Value, provider_name: &str) -> Option<ModelInfo> {
        let id = match entry {
            Value::String(id) => id.trim(),
            Value::Object(_) => entry.get("id").and_then(Value::as_str)?.trim(),
            _ => return None,
        };
        if id.is_empty() {
            return None;
        }

        let name = entry
            .get("name")
            .and_then(Value::as_str)
            .map(str::trim)
            .filter(|name| !name.is_empty())
            .unwrap_or(id);

        let supports_vision = entry
            .get("supports_vision")
            .and_then(Value::as_bool)
            .or_else(|| {
                entry
                    .get("input_modalities")
                    .and_then(Value::as_array)
                    .map(|modalities| {
                        modalities.iter().any(|modality| {
                            modality
                                .as_str()
                                .is_some_and(|modality| modality.eq_ignore_ascii_case("image"))
                        })
                    })
            })
            .unwrap_or(false);

        Some(ModelInfo {
            id: id.to_string(),
            name: name.to_string(),
            provider: provider_name.to_string(),
            context_window: value_to_usize(
                entry
                    .pointer("/limits/max_context_window_tokens")
                    .or_else(|| entry.get("context_window")),
            )
            .unwrap_or(128_000),
            max_output_tokens: value_to_usize(
                entry
                    .pointer("/limits/max_output_tokens")
                    .or_else(|| entry.get("max_output_tokens")),
            ),
            supports_vision,
            supports_tools: entry
                .get("supports_tools")
                .and_then(Value::as_bool)
                .unwrap_or(true),
            supports_streaming: entry
                .get("supports_streaming")
                .and_then(Value::as_bool)
                .unwrap_or(true),
            input_cost_per_million: entry
                .pointer("/pricing/input_cost_per_million")
                .and_then(Value::as_f64),
            output_cost_per_million: entry
                .pointer("/pricing/output_cost_per_million")
                .and_then(Value::as_f64),
        })
    }

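    /// Converts internal messages into `async_openai` chat request messages.
    /// Text parts are joined with newlines; assistant tool calls and tool
    /// results are mapped onto their dedicated message kinds.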
    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                                ..
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

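    /// Maps internal tool definitions onto OpenAI function-calling tools.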
    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }

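    /// Heuristic match for MiniMax "invalid chat setting" errors (code 2013),
    /// which trigger a retry without sampling parameters in `complete`.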
    fn is_minimax_chat_setting_error(error: &str) -> bool {
        let normalized = error.to_ascii_lowercase();
        normalized.contains("invalid chat setting")
            || normalized.contains("(2013)")
            || normalized.contains("code: 2013")
            || normalized.contains("\"2013\"")
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
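        // OpenAI-compatible providers: prefer live discovery via /models,
        // then fall back to the static per-provider defaults.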
        if self.provider_name != "openai" {
            let discovered = self.discover_models_from_api().await;
            if !discovered.is_empty() {
                return Ok(discovered);
            }
            return Ok(self.provider_default_models());
        }

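        // Hosted OpenAI: return a curated static list.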
        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages.clone());

        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(top_p) = request.top_p {
            req_builder.top_p(top_p);
        }
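        // Hosted OpenAI expects `max_completion_tokens`; most compatible
        // servers still use the older `max_tokens` field.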
        if let Some(max) = request.max_tokens {
            if self.provider_name == "openai" {
                req_builder.max_completion_tokens(max as u32);
            } else {
                req_builder.max_tokens(max as u32);
            }
        }

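        // MiniMax can reject requests over their sampling parameters (error
        // 2013); on that specific error, retry once with only model + messages.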
        let primary_request = req_builder.build()?;
        let response = match self.client.chat().create(primary_request).await {
            Ok(response) => response,
            Err(err)
                if self.provider_name == "minimax"
                    && Self::is_minimax_chat_setting_error(&err.to_string()) =>
            {
                tracing::warn!(
                    provider = "minimax",
                    error = %err,
                    "MiniMax rejected chat settings; retrying with conservative defaults"
                );

                let mut fallback_builder = CreateChatCompletionRequestArgs::default();
                fallback_builder.model(&request.model).messages(messages);
                self.client.chat().create(fallback_builder.build()?).await?
            }
            Err(err) => return Err(err.into()),
        };

487
488 let choice = response
489 .choices
490 .first()
491 .ok_or_else(|| anyhow::anyhow!("No choices"))?;
492
493 let mut content = Vec::new();
494 let mut has_tool_calls = false;
495
496 if let Some(text) = &choice.message.content {
497 content.push(ContentPart::Text { text: text.clone() });
498 }
499 if let Some(tool_calls) = &choice.message.tool_calls {
500 has_tool_calls = !tool_calls.is_empty();
501 for tc in tool_calls {
502 if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
503 content.push(ContentPart::ToolCall {
504 id: func_call.id.clone(),
505 name: func_call.function.name.clone(),
506 arguments: func_call.function.arguments.clone(),
507 thought_signature: None,
508 });
509 }
510 }
511 }
512
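        // If tool calls are present, surface ToolCalls regardless of the
        // finish reason the server reported.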
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(max) = request.max_tokens {
            if self.provider_name == "openai" {
                req_builder.max_completion_tokens(max as u32);
            } else {
                req_builder.max_tokens(max as u32);
            }
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

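        // Flatten each SSE chunk into zero or more StreamChunk events:
        // text deltas, tool-call starts, and tool-call argument deltas.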
        Ok(stream
            .flat_map(|result| {
                let chunks: Vec<StreamChunk> = match result {
                    Ok(response) => {
                        let mut out = Vec::new();
                        if let Some(choice) = response.choices.first() {
                            if let Some(content) = &choice.delta.content {
                                if !content.is_empty() {
                                    out.push(StreamChunk::Text(content.clone()));
                                }
                            }
                            if let Some(tool_calls) = &choice.delta.tool_calls {
                                for tc in tool_calls {
                                    if let Some(func) = &tc.function {
                                        if let Some(id) = &tc.id {
                                            out.push(StreamChunk::ToolCallStart {
                                                id: id.clone(),
                                                name: func.name.clone().unwrap_or_default(),
                                            });
                                        }
                                        if let Some(args) = &func.arguments {
                                            if !args.is_empty() {
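                                                // Later deltas often omit the
                                                // call id; fall back to an
                                                // index-based key.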
                                                let id = tc.id.clone().unwrap_or_else(|| {
                                                    format!("tool_{}", tc.index)
                                                });
                                                out.push(StreamChunk::ToolCallDelta {
                                                    id,
                                                    arguments_delta: args.clone(),
                                                });
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        out
                    }
                    Err(e) => vec![StreamChunk::Error(e.to_string())],
                };
                futures::stream::iter(chunks)
            })
            .boxed())
    }

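    // Embeddings go through a raw HTTP POST to `{api_base}/embeddings` rather
    // than the async_openai client.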
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        if request.inputs.is_empty() {
            return Ok(EmbeddingResponse {
                embeddings: Vec::new(),
                usage: Usage::default(),
            });
        }

        let url = format!("{}/embeddings", self.api_base.trim_end_matches('/'));
        let body = OpenAIEmbeddingRequest {
            model: request.model,
            input: request.inputs,
        };

        let mut request_builder = self.http.post(url);
        if let Some(api_key) = self.api_key.as_deref().filter(|key| !key.is_empty()) {
            request_builder = request_builder.bearer_auth(api_key);
        }
        let response = request_builder.json(&body).send().await?;

        let status = response.status();
        let text = response.text().await?;
        if !status.is_success() {
            anyhow::bail!(
                "embedding request failed ({status}): {}",
                safe_char_prefix(&text, 500)
            );
        }

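        // Restore input order by index in case the server returns items out of order.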
        let mut payload: OpenAIEmbeddingResponse = serde_json::from_str(&text)?;
        payload.data.sort_by_key(|item| item.index);
        let embeddings: Vec<Vec<f32>> = payload
            .data
            .into_iter()
            .map(|item| item.embedding)
            .collect();

        if embeddings.len() != body.input.len() {
            anyhow::bail!(
                "embedding response length mismatch: expected {}, got {}",
                body.input.len(),
                embeddings.len()
            );
        }

        let prompt_tokens = payload.usage.prompt_tokens.unwrap_or(0) as usize;
        let total_tokens = payload
            .usage
            .total_tokens
            .unwrap_or(payload.usage.prompt_tokens.unwrap_or(0))
            as usize;

        Ok(EmbeddingResponse {
            embeddings,
            usage: Usage {
                prompt_tokens,
                completion_tokens: 0,
                total_tokens,
                ..Default::default()
            },
        })
    }
}

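/// Converts an optional JSON number to `usize`, if present and non-negative.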
fn value_to_usize(value: Option<&Value>) -> Option<usize> {
    value
        .and_then(Value::as_u64)
        .and_then(|value| usize::try_from(value).ok())
}

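/// Takes at most `max_chars` characters, never splitting multi-byte UTF-8.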
fn safe_char_prefix(input: &str, max_chars: usize) -> String {
    input.chars().take(max_chars).collect()
}

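// Minimal wire types for the OpenAI-compatible `/embeddings` endpoint.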
#[derive(Debug, serde::Serialize)]
struct OpenAIEmbeddingRequest {
    model: String,
    input: Vec<String>,
}

#[derive(Debug, serde::Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
    #[serde(default)]
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, serde::Deserialize)]
struct OpenAIEmbeddingData {
    index: usize,
    embedding: Vec<f32>,
}

#[derive(Debug, Default, serde::Deserialize)]
struct OpenAIEmbeddingUsage {
    #[serde(default)]
    prompt_tokens: Option<u32>,
    #[serde(default)]
    total_tokens: Option<u32>,
}

#[cfg(test)]
mod tests {
    use super::{OpenAIProvider, Provider};
    use serde_json::json;

    #[test]
    fn detects_minimax_chat_setting_error_variants() {
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "bad_request_error: invalid params, invalid chat setting (2013)"
        ));
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "code: 2013 invalid params"
        ));
        assert!(!OpenAIProvider::is_minimax_chat_setting_error(
            "rate limit exceeded"
        ));
    }

    #[test]
    fn supports_openai_compatible_provider_without_api_key() {
        let provider = OpenAIProvider::with_base_url_optional_key(
            None,
            "http://localhost:8080/v1".to_string(),
            "huggingface",
        )
        .expect("provider should initialize without API key");

        assert_eq!(provider.name(), "huggingface");
    }

    #[test]
    fn parses_openai_compatible_models_payload() {
        let payload = json!({
            "object": "list",
            "data": [
                {
                    "id": "GLM-5-Turbo",
                    "name": "GLM-5-Turbo",
                    "limits": {
                        "max_context_window_tokens": 200000,
                        "max_output_tokens": 16384
                    },
                    "input_modalities": ["text"]
                }
            ]
        });

        let models = OpenAIProvider::parse_models_payload(&payload, "custom-openapi");

        assert_eq!(models.len(), 1);
        assert_eq!(models[0].id, "GLM-5-Turbo");
        assert_eq!(models[0].name, "GLM-5-Turbo");
        assert_eq!(models[0].provider, "custom-openapi");
        assert_eq!(models[0].context_window, 200_000);
        assert_eq!(models[0].max_output_tokens, Some(16_384));
    }

    #[test]
    fn parses_string_only_models_payload() {
        let payload = json!({
            "data": ["glm-5", "glm-5-turbo"]
        });

        let models = OpenAIProvider::parse_models_payload(&payload, "custom-openapi");

        assert_eq!(models.len(), 2);
        assert_eq!(models[0].id, "glm-5");
        assert_eq!(models[1].id, "glm-5-turbo");
    }
}
812}