1use std::collections::BTreeMap;
4use std::fmt;
5use std::sync::Arc;
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::error::{Error, ProviderCompatibilityError, Result};
11use crate::json_payload::JsonPayload;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum ProviderKind {
17 OpenAI,
19 Azure,
21 Kimi,
23 DeepSeek,
25 Zhipu,
27 MiniMax,
29 ZenMux,
31 Custom,
33}
34
35impl ProviderKind {
36 pub fn as_key(&self) -> &'static str {
38 match self {
39 Self::OpenAI => "openai",
40 Self::Azure => "azure",
41 Self::Kimi => "kimi",
42 Self::DeepSeek => "deepseek",
43 Self::Zhipu => "zhipu",
44 Self::MiniMax => "minimax",
45 Self::ZenMux => "zenmux",
46 Self::Custom => "custom",
47 }
48 }
49}
50
/// Where request credentials are placed.
///
/// NOTE(review): the code that applies each scheme lives outside this module; the
/// per-variant notes below follow the variant names — confirm against the transport.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthScheme {
    /// Bearer token.
    Bearer,
    /// API key sent in a request header.
    ApiKeyHeader,
    /// Token passed as a query-string parameter.
    QueryToken,
    /// Token carried in the WebSocket subprotocol list.
    WebSocketSubprotocol,
}
63
/// How strictly provider-specific request restrictions are enforced by
/// [`ProviderProfile::validate_request`].
///
/// The built-in profiles in this module only reject requests in `Strict` mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompatibilityMode {
    /// Send requests unchanged; no provider checks.
    Passthrough,
    /// NOTE(review): presumably reports violations without failing the request —
    /// no built-in profile here acts on it; confirm where it is handled.
    Warn,
    /// Reject requests that use features the provider profile flags as unsupported.
    Strict,
}
74
75#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
77#[serde(rename_all = "snake_case")]
78pub enum AzureAuthMode {
79 #[default]
81 ApiKey,
82 Bearer,
84}
85
86impl AzureAuthMode {
87 pub fn auth_scheme(self) -> AuthScheme {
89 match self {
90 Self::ApiKey => AuthScheme::ApiKeyHeader,
91 Self::Bearer => AuthScheme::Bearer,
92 }
93 }
94}
95
/// User-supplied configuration for the Azure OpenAI provider profile.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct AzureOptions {
    /// Value for the `api-version` query parameter. When unset or blank the profile
    /// falls back to its built-in default; an `api-version` already present on the
    /// request query is never overwritten.
    pub api_version: Option<String>,
    /// Deployment name. For deployment-scoped endpoints it is preferred over the
    /// request body's `model`; for realtime connections it is used when the query
    /// does not carry a `deployment` parameter.
    pub deployment: Option<String>,
    /// Credential style; defaults to [`AzureAuthMode::ApiKey`] (also when absent in
    /// serialized input).
    #[serde(default)]
    pub auth_mode: AzureAuthMode,
}
107
108impl AzureOptions {
109 pub fn new() -> Self {
111 Self::default()
112 }
113
114 pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
116 self.api_version = Some(api_version.into());
117 self
118 }
119
120 pub fn deployment(mut self, deployment: impl Into<String>) -> Self {
122 self.deployment = Some(deployment.into());
123 self
124 }
125
126 pub fn bearer_auth(mut self) -> Self {
128 self.auth_mode = AzureAuthMode::Bearer;
129 self
130 }
131
132 pub fn api_key_auth(mut self) -> Self {
134 self.auth_mode = AzureAuthMode::ApiKey;
135 self
136 }
137}
138
/// Which API families a provider profile advertises support for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CapabilitySet {
    /// Chat Completions API.
    pub chat_completions: bool,
    /// Responses API.
    pub responses: bool,
    /// Models listing API.
    pub models: bool,
    /// Streaming responses.
    pub streaming: bool,
    /// Tool / function calling.
    pub tools: bool,
    /// Webhooks.
    pub webhooks: bool,
}
155
/// Every capability enabled; used by the OpenAI, Azure and ZenMux profiles.
const FULL_CAPABILITIES: CapabilitySet = CapabilitySet {
    chat_completions: true,
    responses: true,
    models: true,
    streaming: true,
    tools: true,
    webhooks: true,
};

/// Chat-style providers: everything except the Responses API and webhooks.
/// Used by the Kimi, DeepSeek, Zhipu and MiniMax profiles.
const CHAT_ONLY_CAPABILITIES: CapabilitySet = CapabilitySet {
    chat_completions: true,
    responses: false,
    models: true,
    streaming: true,
    tools: true,
    webhooks: false,
};
173
/// Mutable description of an outgoing request, handed to
/// [`ProviderProfile::prepare_request`] so a profile can rewrite it.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Identifier of the API operation, e.g. `"chat.completions.create"` or
    /// `"realtime.ws.connect"`.
    pub endpoint_id: &'static str,
    /// Request path, e.g. `"/chat/completions"`.
    pub path: String,
    /// Query-string parameters.
    pub query: BTreeMap<String, String>,
    /// Request headers.
    pub headers: BTreeMap<String, String>,
    /// JSON body, if the request has one.
    pub body: Option<JsonPayload>,
}
188
/// Provider-specific behavior plugged into a [`Provider`].
pub trait ProviderProfile: Send + Sync {
    /// Provider family this profile represents.
    fn kind(&self) -> ProviderKind;
    /// Base URL for this provider. NOTE(review): presumably used when the caller does
    /// not configure one explicitly — confirm at the call site.
    fn default_base_url(&self) -> &str;
    /// How credentials are attached to requests.
    fn auth_scheme(&self) -> AuthScheme;
    /// Static description of the API families this provider supports.
    fn capabilities(&self) -> &'static CapabilitySet;
    /// Hook that may rewrite the request's path, query, headers or body (Azure uses it
    /// to inject `api-version` and the deployment).
    fn prepare_request(&self, ctx: &mut RequestContext) -> Result<()>;
    /// Converts an API error into the crate error type; by default wraps it unchanged
    /// in `Error::Api`.
    fn adapt_error(&self, error: crate::ApiError) -> Error {
        Error::Api(error)
    }
    /// Checks `body` for the operation `endpoint_id` against provider restrictions,
    /// honoring `mode`. Returns an error when the request must be rejected.
    fn validate_request(
        &self,
        endpoint_id: &'static str,
        body: Option<&Value>,
        mode: CompatibilityMode,
    ) -> Result<()>;
}
213
/// Cheaply clonable handle to a shared [`ProviderProfile`]; clones share the same
/// profile through an `Arc`.
#[derive(Clone)]
pub struct Provider {
    inner: Arc<dyn ProviderProfile>,
}
219
220impl fmt::Debug for Provider {
221 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222 f.debug_struct("Provider")
223 .field("kind", &self.kind())
224 .field("default_base_url", &self.default_base_url())
225 .finish()
226 }
227}
228
229impl Provider {
230 pub fn openai() -> Self {
232 Self {
233 inner: Arc::new(OpenAiProfile),
234 }
235 }
236
237 pub fn azure() -> Self {
239 Self::azure_with_options(AzureOptions::default())
240 }
241
242 pub fn azure_with_options(options: AzureOptions) -> Self {
244 Self {
245 inner: Arc::new(AzureProfile::new(options)),
246 }
247 }
248
249 pub fn kimi() -> Self {
251 Self {
252 inner: Arc::new(KimiProfile),
253 }
254 }
255
256 pub fn deepseek() -> Self {
258 Self {
259 inner: Arc::new(DeepSeekProfile),
260 }
261 }
262
263 pub fn zhipu() -> Self {
265 Self {
266 inner: Arc::new(ZhipuProfile),
267 }
268 }
269
270 pub fn minimax() -> Self {
272 Self {
273 inner: Arc::new(MiniMaxProfile),
274 }
275 }
276
277 pub fn zenmux() -> Self {
279 Self {
280 inner: Arc::new(ZenMuxProfile),
281 }
282 }
283
284 pub fn custom<T>(profile: T) -> Self
286 where
287 T: ProviderProfile + 'static,
288 {
289 Self {
290 inner: Arc::new(profile),
291 }
292 }
293
294 pub fn kind(&self) -> ProviderKind {
296 self.inner.kind()
297 }
298
299 pub fn default_base_url(&self) -> &str {
301 self.inner.default_base_url()
302 }
303
304 pub fn profile(&self) -> &(dyn ProviderProfile + Send + Sync) {
306 self.inner.as_ref()
307 }
308}
309
/// Minimal user-defined profile: a fixed base URL, auth scheme and capability set,
/// with no request rewriting or validation.
#[derive(Debug, Clone)]
pub struct CustomProfile {
    /// Human-readable name for this provider.
    pub name: String,
    /// Base URL returned by `default_base_url`.
    pub base_url: String,
    /// Credential placement for requests.
    pub auth_scheme: AuthScheme,
    /// Advertised capabilities.
    pub capabilities: CapabilitySet,
}
322
323impl ProviderProfile for CustomProfile {
324 fn kind(&self) -> ProviderKind {
325 ProviderKind::Custom
326 }
327
328 fn default_base_url(&self) -> &str {
329 &self.base_url
330 }
331
332 fn auth_scheme(&self) -> AuthScheme {
333 self.auth_scheme
334 }
335
336 fn capabilities(&self) -> &'static CapabilitySet {
337 Box::leak(Box::new(self.capabilities))
338 }
339
340 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
341 Ok(())
342 }
343
344 fn validate_request(
345 &self,
346 _endpoint_id: &'static str,
347 _body: Option<&Value>,
348 _mode: CompatibilityMode,
349 ) -> Result<()> {
350 Ok(())
351 }
352}
353
/// Azure OpenAI profile; all configurable behavior comes from [`AzureOptions`].
#[derive(Debug, Clone, Default)]
struct AzureProfile {
    options: AzureOptions,
}
358
359impl AzureProfile {
360 fn new(options: AzureOptions) -> Self {
361 Self { options }
362 }
363
364 fn api_version(&self) -> &str {
365 self.options
366 .api_version
367 .as_deref()
368 .filter(|value| !value.trim().is_empty())
369 .unwrap_or("2025-03-01-preview")
370 }
371
372 fn auth_scheme(&self) -> AuthScheme {
373 self.options.auth_mode.auth_scheme()
374 }
375
376 fn deployment_for(&self, ctx: &RequestContext) -> Option<String> {
377 if ctx.endpoint_id == "realtime.ws.connect" {
378 return ctx
379 .query
380 .get("deployment")
381 .cloned()
382 .or_else(|| self.options.deployment.clone())
383 .filter(|value| !value.trim().is_empty());
384 }
385
386 if !azure_deployment_path_required(&ctx.path) {
387 return None;
388 }
389
390 self.options
391 .deployment
392 .clone()
393 .or_else(|| {
394 ctx.body
395 .as_ref()
396 .and_then(|value| value.get("model"))
397 .and_then(Value::as_str)
398 .map(str::to_owned)
399 })
400 .filter(|value| !value.trim().is_empty())
401 }
402}
403
/// Built-in profile for OpenAI.
#[derive(Debug, Clone, Copy)]
struct OpenAiProfile;

/// Built-in profile for Moonshot AI (Kimi).
#[derive(Debug, Clone, Copy)]
struct KimiProfile;

/// Built-in profile for DeepSeek.
#[derive(Debug, Clone, Copy)]
struct DeepSeekProfile;

/// Built-in profile for Zhipu.
#[derive(Debug, Clone, Copy)]
struct ZhipuProfile;

/// Built-in profile for MiniMax; enforces extra request checks in strict mode.
#[derive(Debug, Clone, Copy)]
struct MiniMaxProfile;

/// Built-in profile for ZenMux; enforces the `<provider>/<model_name>` id format in
/// strict mode.
#[derive(Debug, Clone, Copy)]
struct ZenMuxProfile;
421
422impl ProviderProfile for OpenAiProfile {
423 fn kind(&self) -> ProviderKind {
424 ProviderKind::OpenAI
425 }
426
427 fn default_base_url(&self) -> &str {
428 "https://api.openai.com/v1"
429 }
430
431 fn auth_scheme(&self) -> AuthScheme {
432 AuthScheme::Bearer
433 }
434
435 fn capabilities(&self) -> &'static CapabilitySet {
436 &FULL_CAPABILITIES
437 }
438
439 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
440 Ok(())
441 }
442
443 fn validate_request(
444 &self,
445 _endpoint_id: &'static str,
446 _body: Option<&Value>,
447 _mode: CompatibilityMode,
448 ) -> Result<()> {
449 Ok(())
450 }
451}
452
453impl ProviderProfile for AzureProfile {
454 fn kind(&self) -> ProviderKind {
455 ProviderKind::Azure
456 }
457
458 fn default_base_url(&self) -> &str {
459 "https://example-resource.openai.azure.com"
460 }
461
462 fn auth_scheme(&self) -> AuthScheme {
463 self.auth_scheme()
464 }
465
466 fn capabilities(&self) -> &'static CapabilitySet {
467 &FULL_CAPABILITIES
468 }
469
470 fn prepare_request(&self, ctx: &mut RequestContext) -> Result<()> {
471 ctx.query
472 .entry("api-version".into())
473 .or_insert_with(|| self.api_version().into());
474
475 if !ctx.path.starts_with("/openai") {
476 ctx.path = format!("/openai{}", ctx.path);
477 }
478
479 if let Some(deployment) = self.deployment_for(ctx)
480 && ctx.endpoint_id == "realtime.ws.connect"
481 {
482 ctx.query.insert("deployment".into(), deployment);
483 } else if let Some(deployment) = self.deployment_for(ctx)
484 && !ctx.path.contains("/deployments/")
485 {
486 ctx.path =
487 ctx.path
488 .replacen("/openai/", &format!("/openai/deployments/{deployment}/"), 1);
489 }
490
491 Ok(())
492 }
493
494 fn validate_request(
495 &self,
496 _endpoint_id: &'static str,
497 _body: Option<&Value>,
498 _mode: CompatibilityMode,
499 ) -> Result<()> {
500 Ok(())
501 }
502}
503
504impl ProviderProfile for KimiProfile {
505 fn kind(&self) -> ProviderKind {
506 ProviderKind::Kimi
507 }
508
509 fn default_base_url(&self) -> &str {
510 "https://api.moonshot.ai/v1"
511 }
512
513 fn auth_scheme(&self) -> AuthScheme {
514 AuthScheme::Bearer
515 }
516
517 fn capabilities(&self) -> &'static CapabilitySet {
518 &CHAT_ONLY_CAPABILITIES
519 }
520
521 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
522 Ok(())
523 }
524
525 fn validate_request(
526 &self,
527 _endpoint_id: &'static str,
528 _body: Option<&Value>,
529 _mode: CompatibilityMode,
530 ) -> Result<()> {
531 Ok(())
532 }
533}
534
535impl ProviderProfile for DeepSeekProfile {
536 fn kind(&self) -> ProviderKind {
537 ProviderKind::DeepSeek
538 }
539
540 fn default_base_url(&self) -> &str {
541 "https://api.deepseek.com"
542 }
543
544 fn auth_scheme(&self) -> AuthScheme {
545 AuthScheme::Bearer
546 }
547
548 fn capabilities(&self) -> &'static CapabilitySet {
549 &CHAT_ONLY_CAPABILITIES
550 }
551
552 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
553 Ok(())
554 }
555
556 fn validate_request(
557 &self,
558 _endpoint_id: &'static str,
559 _body: Option<&Value>,
560 _mode: CompatibilityMode,
561 ) -> Result<()> {
562 Ok(())
563 }
564}
565
566impl ProviderProfile for ZhipuProfile {
567 fn kind(&self) -> ProviderKind {
568 ProviderKind::Zhipu
569 }
570
571 fn default_base_url(&self) -> &str {
572 "https://open.bigmodel.cn/api/paas/v4"
573 }
574
575 fn auth_scheme(&self) -> AuthScheme {
576 AuthScheme::Bearer
577 }
578
579 fn capabilities(&self) -> &'static CapabilitySet {
580 &CHAT_ONLY_CAPABILITIES
581 }
582
583 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
584 Ok(())
585 }
586
587 fn validate_request(
588 &self,
589 _endpoint_id: &'static str,
590 _body: Option<&Value>,
591 _mode: CompatibilityMode,
592 ) -> Result<()> {
593 Ok(())
594 }
595}
596
597impl ProviderProfile for MiniMaxProfile {
598 fn kind(&self) -> ProviderKind {
599 ProviderKind::MiniMax
600 }
601
602 fn default_base_url(&self) -> &str {
603 "https://api.minimaxi.com/v1"
604 }
605
606 fn auth_scheme(&self) -> AuthScheme {
607 AuthScheme::Bearer
608 }
609
610 fn capabilities(&self) -> &'static CapabilitySet {
611 &CHAT_ONLY_CAPABILITIES
612 }
613
614 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
615 Ok(())
616 }
617
618 fn validate_request(
619 &self,
620 _endpoint_id: &'static str,
621 body: Option<&Value>,
622 mode: CompatibilityMode,
623 ) -> Result<()> {
624 if mode != CompatibilityMode::Strict {
625 return Ok(());
626 }
627
628 let Some(body) = body else {
629 return Ok(());
630 };
631
632 if let Some(value) = body.get("n").and_then(Value::as_i64)
633 && value != 1
634 {
635 return Err(ProviderCompatibilityError::new(
636 ProviderKind::MiniMax,
637 "MiniMax 在严格模式下仅支持 n = 1",
638 )
639 .into());
640 }
641
642 if contains_key(body, "function_call") {
643 return Err(ProviderCompatibilityError::new(
644 ProviderKind::MiniMax,
645 "MiniMax 在严格模式下不再支持旧版 function_call 字段,请改用 tools",
646 )
647 .into());
648 }
649
650 if contains_any_type(body, &["input_image", "image", "input_audio", "audio"]) {
651 return Err(ProviderCompatibilityError::new(
652 ProviderKind::MiniMax,
653 "MiniMax 在严格模式下不支持图像或音频输入",
654 )
655 .into());
656 }
657
658 Ok(())
659 }
660}
661
662impl ProviderProfile for ZenMuxProfile {
663 fn kind(&self) -> ProviderKind {
664 ProviderKind::ZenMux
665 }
666
667 fn default_base_url(&self) -> &str {
668 "https://zenmux.ai/api/v1"
669 }
670
671 fn auth_scheme(&self) -> AuthScheme {
672 AuthScheme::Bearer
673 }
674
675 fn capabilities(&self) -> &'static CapabilitySet {
676 &FULL_CAPABILITIES
677 }
678
679 fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
680 Ok(())
681 }
682
683 fn validate_request(
684 &self,
685 _endpoint_id: &'static str,
686 body: Option<&Value>,
687 mode: CompatibilityMode,
688 ) -> Result<()> {
689 if mode != CompatibilityMode::Strict {
690 return Ok(());
691 }
692
693 let Some(model) = body
694 .and_then(|value| value.get("model"))
695 .and_then(Value::as_str)
696 else {
697 return Ok(());
698 };
699
700 if !model.contains('/') || model.starts_with('/') || model.ends_with('/') {
701 return Err(ProviderCompatibilityError::new(
702 ProviderKind::ZenMux,
703 "ZenMux 在严格模式下要求 model 采用 <provider>/<model_name> 形式",
704 )
705 .into());
706 }
707
708 Ok(())
709 }
710}
711
712fn contains_key(value: &Value, target: &str) -> bool {
713 match value {
714 Value::Object(map) => {
715 map.contains_key(target) || map.values().any(|value| contains_key(value, target))
716 }
717 Value::Array(values) => values.iter().any(|value| contains_key(value, target)),
718 _ => false,
719 }
720}
721
722fn contains_any_type(value: &Value, targets: &[&str]) -> bool {
723 match value {
724 Value::Object(map) => map.iter().any(|(key, nested)| {
725 (key == "type"
726 && nested
727 .as_str()
728 .is_some_and(|value| targets.contains(&value)))
729 || contains_any_type(nested, targets)
730 }),
731 Value::Array(values) => values.iter().any(|value| contains_any_type(value, targets)),
732 _ => false,
733 }
734}
735
/// Whether `path` is an endpoint that Azure scopes to a deployment
/// (`/openai/deployments/{name}/...`).
///
/// Trailing slashes are ignored, and the path may or may not already carry the
/// `/openai` prefix.
fn azure_deployment_path_required(path: &str) -> bool {
    const DEPLOYMENT_SCOPED: &[&str] = &[
        "/completions",
        "/chat/completions",
        "/embeddings",
        "/audio/transcriptions",
        "/audio/translations",
        "/audio/speech",
        "/images/generations",
        "/images/edits",
        "/batches",
    ];

    let trimmed = path.trim_end_matches('/');
    let route = trimmed.strip_prefix("/openai").unwrap_or(trimmed);
    DEPLOYMENT_SCOPED.contains(&route)
}
759
#[cfg(test)]
mod tests {
    use super::*;

    /// Request context for `endpoint_id`/`path` with empty query, headers and body.
    fn context(endpoint_id: &'static str, path: &str) -> RequestContext {
        RequestContext {
            endpoint_id,
            path: path.to_owned(),
            query: BTreeMap::new(),
            headers: BTreeMap::new(),
            body: None,
        }
    }

    /// Runs the provider's `prepare_request` hook, failing the test on error.
    fn prepare(provider: &Provider, mut ctx: RequestContext) -> RequestContext {
        provider
            .profile()
            .prepare_request(&mut ctx)
            .expect("prepare_request should succeed");
        ctx
    }

    /// Looks up a query parameter as `&str`.
    fn query<'a>(ctx: &'a RequestContext, key: &str) -> Option<&'a str> {
        ctx.query.get(key).map(String::as_str)
    }

    /// Validates `body` against `provider` in strict mode.
    fn validate_strict(
        provider: &Provider,
        endpoint_id: &'static str,
        body: &Value,
    ) -> Result<()> {
        provider
            .profile()
            .validate_request(endpoint_id, Some(body), CompatibilityMode::Strict)
    }

    #[test]
    fn test_should_use_kimi_default_base_url() {
        assert_eq!(
            Provider::kimi().default_base_url(),
            "https://api.moonshot.ai/v1"
        );
    }

    #[test]
    fn test_should_use_deepseek_default_base_url() {
        assert_eq!(
            Provider::deepseek().default_base_url(),
            "https://api.deepseek.com"
        );
    }

    #[test]
    fn test_should_use_zhipu_default_base_url() {
        assert_eq!(
            Provider::zhipu().default_base_url(),
            "https://open.bigmodel.cn/api/paas/v4"
        );
    }

    #[test]
    fn test_should_use_minimax_default_base_url() {
        assert_eq!(
            Provider::minimax().default_base_url(),
            "https://api.minimaxi.com/v1"
        );
    }

    #[test]
    fn test_should_use_zenmux_default_base_url() {
        assert_eq!(
            Provider::zenmux().default_base_url(),
            "https://zenmux.ai/api/v1"
        );
    }

    #[test]
    fn test_should_validate_minimax_n_equals_one_in_strict_mode() {
        let body = serde_json::json!({
            "model": "MiniMax-M2.7",
            "messages": [{"role": "user", "content": "hello"}],
            "n": 2
        });
        let outcome = validate_strict(&Provider::minimax(), "chat.completions.create", &body);
        assert!(matches!(outcome, Err(Error::ProviderCompatibility(_))));
    }

    #[test]
    fn test_should_validate_zenmux_model_id_format_in_strict_mode() {
        let body = serde_json::json!({
            "model": "gpt-5",
            "input": "hello"
        });
        let outcome = validate_strict(&Provider::zenmux(), "responses.create", &body);
        assert!(matches!(outcome, Err(Error::ProviderCompatibility(_))));
    }

    #[test]
    fn test_should_preserve_passthrough_mode_for_minimax() {
        let body = serde_json::json!({
            "model": "MiniMax-M2.7",
            "messages": [{"role": "user", "content": "hello"}],
            "n": 3
        });
        Provider::minimax()
            .profile()
            .validate_request(
                "chat.completions.create",
                Some(&body),
                CompatibilityMode::Passthrough,
            )
            .expect("passthrough mode must not reject requests");
    }

    #[test]
    fn test_should_inject_azure_api_version_and_prefix_path() {
        let provider =
            Provider::azure_with_options(AzureOptions::new().api_version("2024-02-15-preview"));
        let ctx = prepare(&provider, context("responses.create", "/responses"));

        assert_eq!(ctx.path, "/openai/responses");
        assert_eq!(query(&ctx, "api-version"), Some("2024-02-15-preview"));
    }

    #[test]
    fn test_should_preserve_existing_azure_api_version_query() {
        let mut request = context("responses.create", "/responses");
        request
            .query
            .insert("api-version".into(), "custom-version".into());

        let ctx = prepare(&Provider::azure(), request);

        assert_eq!(query(&ctx, "api-version"), Some("custom-version"));
    }

    #[test]
    fn test_should_inject_azure_deployment_from_body_model() {
        let mut request = context("chat.completions.create", "/chat/completions");
        request.body = Some(
            serde_json::json!({
                "model": "gpt-4o-deployment"
            })
            .into(),
        );

        let ctx = prepare(&Provider::azure(), request);

        assert_eq!(
            ctx.path,
            "/openai/deployments/gpt-4o-deployment/chat/completions"
        );
    }

    #[test]
    fn test_should_inject_azure_realtime_deployment_query() {
        let provider =
            Provider::azure_with_options(AzureOptions::new().deployment("rt-deployment"));
        let ctx = prepare(&provider, context("realtime.ws.connect", "/realtime"));

        assert_eq!(ctx.path, "/openai/realtime");
        assert_eq!(query(&ctx, "deployment"), Some("rt-deployment"));
    }

    #[test]
    fn test_should_switch_azure_auth_scheme_to_bearer() {
        let provider = Provider::azure_with_options(AzureOptions::new().bearer_auth());
        assert_eq!(provider.profile().auth_scheme(), AuthScheme::Bearer);
    }
}