1use std::collections::BTreeMap;
4use std::fmt;
5use std::sync::Arc;
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::error::{Error, ProviderCompatibilityError, Result};
11use crate::json_payload::JsonPayload;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum ProviderKind {
17 OpenAI,
19 Azure,
21 Zhipu,
23 MiniMax,
25 ZenMux,
27 Custom,
29}
30
31impl ProviderKind {
32 pub fn as_key(&self) -> &'static str {
34 match self {
35 Self::OpenAI => "openai",
36 Self::Azure => "azure",
37 Self::Zhipu => "zhipu",
38 Self::MiniMax => "minimax",
39 Self::ZenMux => "zenmux",
40 Self::Custom => "custom",
41 }
42 }
43}
44
/// How a profile expects credentials to be attached to outgoing requests.
///
/// This module only selects a scheme; applying it is the transport layer's job
/// (not visible here).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthScheme {
    /// Bearer token — presumably `Authorization: Bearer <token>`; confirm in the transport.
    Bearer,
    /// API key carried in a request header (header name chosen by the transport).
    ApiKeyHeader,
    /// Token carried as a URL query parameter (presumably; confirm in the transport).
    QueryToken,
    /// Token carried through WebSocket subprotocol negotiation (presumably; confirm).
    WebSocketSubprotocol,
}
57
/// How strictly provider-specific request restrictions are enforced by
/// [`ProviderProfile::validate_request`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompatibilityMode {
    /// No checking: the built-in profiles accept every request in this mode.
    Passthrough,
    /// The built-in profiles in this module accept every request in this mode as well;
    /// NOTE(review): any warning output presumably happens in the caller — confirm.
    Warn,
    /// Requests the provider is known not to support are rejected with an error.
    Strict,
}
68
/// Credential style used by the Azure OpenAI profile.
///
/// Serialized in snake_case (`api_key` / `bearer`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum AzureAuthMode {
    /// API key in a header ([`AuthScheme::ApiKeyHeader`]); this is the default.
    #[default]
    ApiKey,
    /// Bearer token ([`AuthScheme::Bearer`]).
    Bearer,
}
79
80impl AzureAuthMode {
81 pub fn auth_scheme(self) -> AuthScheme {
83 match self {
84 Self::ApiKey => AuthScheme::ApiKeyHeader,
85 Self::Bearer => AuthScheme::Bearer,
86 }
87 }
88}
89
/// Configuration for the built-in Azure OpenAI profile.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AzureOptions {
    /// Value injected as the `api-version` query parameter when the request does not
    /// already carry one; when unset or blank the profile uses `2025-03-01-preview`.
    pub api_version: Option<String>,
    /// Deployment name used for deployment-scoped routes and for the realtime
    /// `deployment` query parameter.
    pub deployment: Option<String>,
    /// Credential style; falls back to [`AzureAuthMode::ApiKey`] when absent from
    /// serialized input.
    #[serde(default)]
    pub auth_mode: AzureAuthMode,
}
101
102impl AzureOptions {
103 pub fn new() -> Self {
105 Self::default()
106 }
107
108 pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
110 self.api_version = Some(api_version.into());
111 self
112 }
113
114 pub fn deployment(mut self, deployment: impl Into<String>) -> Self {
116 self.deployment = Some(deployment.into());
117 self
118 }
119
120 pub fn bearer_auth(mut self) -> Self {
122 self.auth_mode = AzureAuthMode::Bearer;
123 self
124 }
125
126 pub fn api_key_auth(mut self) -> Self {
128 self.auth_mode = AzureAuthMode::ApiKey;
129 self
130 }
131}
132
/// Feature flags a provider profile advertises through
/// [`ProviderProfile::capabilities`].
///
/// How callers act on these flags is not visible in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CapabilitySet {
    /// Chat Completions API is available.
    pub chat_completions: bool,
    /// Responses API is available.
    pub responses: bool,
    /// Model listing is available.
    pub models: bool,
    /// Streaming responses are available.
    pub streaming: bool,
    /// Tool calling is available.
    pub tools: bool,
    /// Webhooks are available.
    pub webhooks: bool,
}
149
150const FULL_CAPABILITIES: CapabilitySet = CapabilitySet {
151 chat_completions: true,
152 responses: true,
153 models: true,
154 streaming: true,
155 tools: true,
156 webhooks: true,
157};
158
159const CHAT_ONLY_CAPABILITIES: CapabilitySet = CapabilitySet {
160 chat_completions: true,
161 responses: false,
162 models: true,
163 streaming: true,
164 tools: true,
165 webhooks: false,
166};
167
/// One outgoing request, handed to [`ProviderProfile::prepare_request`] so a profile
/// can rewrite it in place.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Operation identifier such as `"chat.completions.create"` or `"realtime.ws.connect"`.
    pub endpoint_id: &'static str,
    /// Request path such as `/chat/completions`; presumably relative to the provider's
    /// base URL — confirm with the caller.
    pub path: String,
    /// Query parameters, ordered by key.
    pub query: BTreeMap<String, String>,
    /// Request headers, ordered by name.
    pub headers: BTreeMap<String, String>,
    /// Request body payload, if any.
    pub body: Option<JsonPayload>,
}
182
183pub trait ProviderProfile: Send + Sync {
185 fn kind(&self) -> ProviderKind;
187 fn default_base_url(&self) -> &str;
189 fn auth_scheme(&self) -> AuthScheme;
191 fn capabilities(&self) -> &'static CapabilitySet;
193 fn prepare_request(&self, ctx: &mut RequestContext) -> Result<()>;
195 fn adapt_error(&self, error: crate::ApiError) -> Error {
197 Error::Api(error)
198 }
199 fn validate_request(
201 &self,
202 endpoint_id: &'static str,
203 body: Option<&Value>,
204 mode: CompatibilityMode,
205 ) -> Result<()>;
206}
207
/// Cheaply clonable handle to a [`ProviderProfile`].
///
/// Clones share a single profile instance through an `Arc`.
#[derive(Clone)]
pub struct Provider {
    inner: Arc<dyn ProviderProfile>,
}
213
214impl fmt::Debug for Provider {
215 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216 f.debug_struct("Provider")
217 .field("kind", &self.kind())
218 .field("default_base_url", &self.default_base_url())
219 .finish()
220 }
221}
222
223impl Provider {
224 pub fn openai() -> Self {
226 Self {
227 inner: Arc::new(OpenAiProfile),
228 }
229 }
230
231 pub fn azure() -> Self {
233 Self::azure_with_options(AzureOptions::default())
234 }
235
236 pub fn azure_with_options(options: AzureOptions) -> Self {
238 Self {
239 inner: Arc::new(AzureProfile::new(options)),
240 }
241 }
242
243 pub fn zhipu() -> Self {
245 Self {
246 inner: Arc::new(ZhipuProfile),
247 }
248 }
249
250 pub fn minimax() -> Self {
252 Self {
253 inner: Arc::new(MiniMaxProfile),
254 }
255 }
256
257 pub fn zenmux() -> Self {
259 Self {
260 inner: Arc::new(ZenMuxProfile),
261 }
262 }
263
264 pub fn custom<T>(profile: T) -> Self
266 where
267 T: ProviderProfile + 'static,
268 {
269 Self {
270 inner: Arc::new(profile),
271 }
272 }
273
274 pub fn kind(&self) -> ProviderKind {
276 self.inner.kind()
277 }
278
279 pub fn default_base_url(&self) -> &str {
281 self.inner.default_base_url()
282 }
283
284 pub fn profile(&self) -> &(dyn ProviderProfile + Send + Sync) {
286 self.inner.as_ref()
287 }
288}
289
/// User-defined provider profile whose identity is supplied entirely by its fields.
///
/// Its `ProviderProfile` impl performs no request rewriting and no validation.
#[derive(Debug, Clone)]
pub struct CustomProfile {
    /// Caller-chosen label for this provider.
    pub name: String,
    /// Value returned from `default_base_url`.
    pub base_url: String,
    /// Value returned from `auth_scheme`.
    pub auth_scheme: AuthScheme,
    /// Flags reported from `capabilities`.
    pub capabilities: CapabilitySet,
}
302
303impl ProviderProfile for CustomProfile {
304 fn kind(&self) -> ProviderKind {
305 ProviderKind::Custom
306 }
307
308 fn default_base_url(&self) -> &str {
309 &self.base_url
310 }
311
312 fn auth_scheme(&self) -> AuthScheme {
313 self.auth_scheme
314 }
315
316 fn capabilities(&self) -> &'static CapabilitySet {
317 Box::leak(Box::new(self.capabilities))
318 }
319
320 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
321 Ok(())
322 }
323
324 fn validate_request(
325 &self,
326 _endpoint_id: &'static str,
327 _body: Option<&Value>,
328 _mode: CompatibilityMode,
329 ) -> Result<()> {
330 Ok(())
331 }
332}
333
/// Built-in Azure OpenAI profile; all routing and auth choices come from `options`.
#[derive(Debug, Clone, Default)]
struct AzureProfile {
    options: AzureOptions,
}
338
339impl AzureProfile {
340 fn new(options: AzureOptions) -> Self {
341 Self { options }
342 }
343
344 fn api_version(&self) -> &str {
345 self.options
346 .api_version
347 .as_deref()
348 .filter(|value| !value.trim().is_empty())
349 .unwrap_or("2025-03-01-preview")
350 }
351
352 fn auth_scheme(&self) -> AuthScheme {
353 self.options.auth_mode.auth_scheme()
354 }
355
356 fn deployment_for(&self, ctx: &RequestContext) -> Option<String> {
357 if ctx.endpoint_id == "realtime.ws.connect" {
358 return ctx
359 .query
360 .get("deployment")
361 .cloned()
362 .or_else(|| self.options.deployment.clone())
363 .filter(|value| !value.trim().is_empty());
364 }
365
366 if !azure_deployment_path_required(&ctx.path) {
367 return None;
368 }
369
370 self.options
371 .deployment
372 .clone()
373 .or_else(|| {
374 ctx.body
375 .as_ref()
376 .and_then(|value| value.get("model"))
377 .and_then(Value::as_str)
378 .map(str::to_owned)
379 })
380 .filter(|value| !value.trim().is_empty())
381 }
382}
383
/// Built-in, stateless profile for the OpenAI API.
#[derive(Debug, Clone, Copy)]
struct OpenAiProfile;
386
/// Built-in, stateless profile for Zhipu's OpenAI-compatible endpoint.
#[derive(Debug, Clone, Copy)]
struct ZhipuProfile;
389
/// Built-in, stateless profile for MiniMax's OpenAI-compatible endpoint.
#[derive(Debug, Clone, Copy)]
struct MiniMaxProfile;
392
/// Built-in, stateless profile for the ZenMux gateway.
#[derive(Debug, Clone, Copy)]
struct ZenMuxProfile;
395
396impl ProviderProfile for OpenAiProfile {
397 fn kind(&self) -> ProviderKind {
398 ProviderKind::OpenAI
399 }
400
401 fn default_base_url(&self) -> &str {
402 "https://api.openai.com/v1"
403 }
404
405 fn auth_scheme(&self) -> AuthScheme {
406 AuthScheme::Bearer
407 }
408
409 fn capabilities(&self) -> &'static CapabilitySet {
410 &FULL_CAPABILITIES
411 }
412
413 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
414 Ok(())
415 }
416
417 fn validate_request(
418 &self,
419 _endpoint_id: &'static str,
420 _body: Option<&Value>,
421 _mode: CompatibilityMode,
422 ) -> Result<()> {
423 Ok(())
424 }
425}
426
427impl ProviderProfile for AzureProfile {
428 fn kind(&self) -> ProviderKind {
429 ProviderKind::Azure
430 }
431
432 fn default_base_url(&self) -> &str {
433 "https://example-resource.openai.azure.com"
434 }
435
436 fn auth_scheme(&self) -> AuthScheme {
437 self.auth_scheme()
438 }
439
440 fn capabilities(&self) -> &'static CapabilitySet {
441 &FULL_CAPABILITIES
442 }
443
444 fn prepare_request(&self, ctx: &mut RequestContext) -> Result<()> {
445 ctx.query
446 .entry("api-version".into())
447 .or_insert_with(|| self.api_version().into());
448
449 if !ctx.path.starts_with("/openai") {
450 ctx.path = format!("/openai{}", ctx.path);
451 }
452
453 if let Some(deployment) = self.deployment_for(ctx)
454 && ctx.endpoint_id == "realtime.ws.connect"
455 {
456 ctx.query.insert("deployment".into(), deployment);
457 } else if let Some(deployment) = self.deployment_for(ctx)
458 && !ctx.path.contains("/deployments/")
459 {
460 ctx.path =
461 ctx.path
462 .replacen("/openai/", &format!("/openai/deployments/{deployment}/"), 1);
463 }
464
465 Ok(())
466 }
467
468 fn validate_request(
469 &self,
470 _endpoint_id: &'static str,
471 _body: Option<&Value>,
472 _mode: CompatibilityMode,
473 ) -> Result<()> {
474 Ok(())
475 }
476}
477
478impl ProviderProfile for ZhipuProfile {
479 fn kind(&self) -> ProviderKind {
480 ProviderKind::Zhipu
481 }
482
483 fn default_base_url(&self) -> &str {
484 "https://open.bigmodel.cn/api/paas/v4"
485 }
486
487 fn auth_scheme(&self) -> AuthScheme {
488 AuthScheme::Bearer
489 }
490
491 fn capabilities(&self) -> &'static CapabilitySet {
492 &CHAT_ONLY_CAPABILITIES
493 }
494
495 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
496 Ok(())
497 }
498
499 fn validate_request(
500 &self,
501 _endpoint_id: &'static str,
502 _body: Option<&Value>,
503 _mode: CompatibilityMode,
504 ) -> Result<()> {
505 Ok(())
506 }
507}
508
509impl ProviderProfile for MiniMaxProfile {
510 fn kind(&self) -> ProviderKind {
511 ProviderKind::MiniMax
512 }
513
514 fn default_base_url(&self) -> &str {
515 "https://api.minimaxi.com/v1"
516 }
517
518 fn auth_scheme(&self) -> AuthScheme {
519 AuthScheme::Bearer
520 }
521
522 fn capabilities(&self) -> &'static CapabilitySet {
523 &CHAT_ONLY_CAPABILITIES
524 }
525
526 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
527 Ok(())
528 }
529
530 fn validate_request(
531 &self,
532 _endpoint_id: &'static str,
533 body: Option<&Value>,
534 mode: CompatibilityMode,
535 ) -> Result<()> {
536 if mode != CompatibilityMode::Strict {
537 return Ok(());
538 }
539
540 let Some(body) = body else {
541 return Ok(());
542 };
543
544 if let Some(value) = body.get("n").and_then(Value::as_i64)
545 && value != 1
546 {
547 return Err(ProviderCompatibilityError::new(
548 ProviderKind::MiniMax,
549 "MiniMax 在严格模式下仅支持 n = 1",
550 )
551 .into());
552 }
553
554 if contains_key(body, "function_call") {
555 return Err(ProviderCompatibilityError::new(
556 ProviderKind::MiniMax,
557 "MiniMax 在严格模式下不再支持旧版 function_call 字段,请改用 tools",
558 )
559 .into());
560 }
561
562 if contains_any_type(body, &["input_image", "image", "input_audio", "audio"]) {
563 return Err(ProviderCompatibilityError::new(
564 ProviderKind::MiniMax,
565 "MiniMax 在严格模式下不支持图像或音频输入",
566 )
567 .into());
568 }
569
570 Ok(())
571 }
572}
573
574impl ProviderProfile for ZenMuxProfile {
575 fn kind(&self) -> ProviderKind {
576 ProviderKind::ZenMux
577 }
578
579 fn default_base_url(&self) -> &str {
580 "https://zenmux.ai/api/v1"
581 }
582
583 fn auth_scheme(&self) -> AuthScheme {
584 AuthScheme::Bearer
585 }
586
587 fn capabilities(&self) -> &'static CapabilitySet {
588 &FULL_CAPABILITIES
589 }
590
591 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
592 Ok(())
593 }
594
595 fn validate_request(
596 &self,
597 _endpoint_id: &'static str,
598 body: Option<&Value>,
599 mode: CompatibilityMode,
600 ) -> Result<()> {
601 if mode != CompatibilityMode::Strict {
602 return Ok(());
603 }
604
605 let Some(model) = body
606 .and_then(|value| value.get("model"))
607 .and_then(Value::as_str)
608 else {
609 return Ok(());
610 };
611
612 if !model.contains('/') || model.starts_with('/') || model.ends_with('/') {
613 return Err(ProviderCompatibilityError::new(
614 ProviderKind::ZenMux,
615 "ZenMux 在严格模式下要求 model 采用 <provider>/<model_name> 形式",
616 )
617 .into());
618 }
619
620 Ok(())
621 }
622}
623
624fn contains_key(value: &Value, target: &str) -> bool {
625 match value {
626 Value::Object(map) => {
627 map.contains_key(target) || map.values().any(|value| contains_key(value, target))
628 }
629 Value::Array(values) => values.iter().any(|value| contains_key(value, target)),
630 _ => false,
631 }
632}
633
634fn contains_any_type(value: &Value, targets: &[&str]) -> bool {
635 match value {
636 Value::Object(map) => map.iter().any(|(key, nested)| {
637 (key == "type"
638 && nested
639 .as_str()
640 .is_some_and(|value| targets.contains(&value)))
641 || contains_any_type(nested, targets)
642 }),
643 Value::Array(values) => values.iter().any(|value| contains_any_type(value, targets)),
644 _ => false,
645 }
646}
647
/// Reports whether `path` is an Azure OpenAI operation scoped to a deployment, i.e. one
/// that must be served from `/openai/deployments/{deployment}/...`.
///
/// Trailing slashes and an optional leading `/openai` prefix are ignored. Resource-level
/// operations are deliberately absent. For example, Azure serves batches at
/// `/openai/batches`, and there is no `/openai/deployments/{d}/batches` route, so
/// `/batches` must not be rewritten.
fn azure_deployment_path_required(path: &str) -> bool {
    let trimmed = path.trim_end_matches('/');
    // None of the operations below begin with `/openai`, so stripping it once is unambiguous.
    let operation = trimmed.strip_prefix("/openai").unwrap_or(trimmed);
    matches!(
        operation,
        "/completions"
            | "/chat/completions"
            | "/embeddings"
            | "/audio/transcriptions"
            | "/audio/translations"
            | "/audio/speech"
            | "/images/generations"
            | "/images/edits"
    )
}
671
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request context with no query parameters, headers, or body.
    fn bare_context(endpoint_id: &'static str, path: &str) -> RequestContext {
        RequestContext {
            endpoint_id,
            path: path.to_owned(),
            query: BTreeMap::new(),
            headers: BTreeMap::new(),
            body: None,
        }
    }

    /// Runs strict-mode validation and returns the error it is expected to produce.
    fn strict_rejection(provider: &Provider, endpoint_id: &'static str, body: &Value) -> Error {
        provider
            .profile()
            .validate_request(endpoint_id, Some(body), CompatibilityMode::Strict)
            .unwrap_err()
    }

    #[test]
    fn test_should_use_zhipu_default_base_url() {
        assert_eq!(
            Provider::zhipu().default_base_url(),
            "https://open.bigmodel.cn/api/paas/v4"
        );
    }

    #[test]
    fn test_should_use_minimax_default_base_url() {
        assert_eq!(
            Provider::minimax().default_base_url(),
            "https://api.minimaxi.com/v1"
        );
    }

    #[test]
    fn test_should_use_zenmux_default_base_url() {
        assert_eq!(Provider::zenmux().default_base_url(), "https://zenmux.ai/api/v1");
    }

    #[test]
    fn test_should_validate_minimax_n_equals_one_in_strict_mode() {
        let body = serde_json::json!({
            "model": "MiniMax-M2.7",
            "messages": [{"role": "user", "content": "hello"}],
            "n": 2
        });
        let error = strict_rejection(&Provider::minimax(), "chat.completions.create", &body);
        assert!(matches!(error, Error::ProviderCompatibility(_)));
    }

    #[test]
    fn test_should_validate_zenmux_model_id_format_in_strict_mode() {
        let body = serde_json::json!({
            "model": "gpt-5",
            "input": "hello"
        });
        let error = strict_rejection(&Provider::zenmux(), "responses.create", &body);
        assert!(matches!(error, Error::ProviderCompatibility(_)));
    }

    #[test]
    fn test_should_preserve_passthrough_mode_for_minimax() {
        let body = serde_json::json!({
            "model": "MiniMax-M2.7",
            "messages": [{"role": "user", "content": "hello"}],
            "n": 3
        });
        Provider::minimax()
            .profile()
            .validate_request(
                "chat.completions.create",
                Some(&body),
                CompatibilityMode::Passthrough,
            )
            .unwrap();
    }

    #[test]
    fn test_should_inject_azure_api_version_and_prefix_path() {
        let options = AzureOptions::new().api_version("2024-02-15-preview");
        let provider = Provider::azure_with_options(options);
        let mut context = bare_context("responses.create", "/responses");

        provider.profile().prepare_request(&mut context).unwrap();

        assert_eq!(context.path, "/openai/responses");
        assert_eq!(
            context.query.get("api-version").map(String::as_str),
            Some("2024-02-15-preview")
        );
    }

    #[test]
    fn test_should_preserve_existing_azure_api_version_query() {
        let provider = Provider::azure();
        let mut context = bare_context("responses.create", "/responses");
        context
            .query
            .insert("api-version".into(), "custom-version".into());

        provider.profile().prepare_request(&mut context).unwrap();

        assert_eq!(
            context.query.get("api-version").map(String::as_str),
            Some("custom-version")
        );
    }

    #[test]
    fn test_should_inject_azure_deployment_from_body_model() {
        let provider = Provider::azure();
        let mut context = bare_context("chat.completions.create", "/chat/completions");
        context.body = Some(
            serde_json::json!({
                "model": "gpt-4o-deployment"
            })
            .into(),
        );

        provider.profile().prepare_request(&mut context).unwrap();

        assert_eq!(
            context.path,
            "/openai/deployments/gpt-4o-deployment/chat/completions"
        );
    }

    #[test]
    fn test_should_inject_azure_realtime_deployment_query() {
        let options = AzureOptions::new().deployment("rt-deployment");
        let provider = Provider::azure_with_options(options);
        let mut context = bare_context("realtime.ws.connect", "/realtime");

        provider.profile().prepare_request(&mut context).unwrap();

        assert_eq!(context.path, "/openai/realtime");
        assert_eq!(
            context.query.get("deployment").map(String::as_str),
            Some("rt-deployment")
        );
    }

    #[test]
    fn test_should_switch_azure_auth_scheme_to_bearer() {
        let provider = Provider::azure_with_options(AzureOptions::new().bearer_auth());
        assert_eq!(provider.profile().auth_scheme(), AuthScheme::Bearer);
    }
}