Skip to main content

openai_core/providers/
mod.rs

1//! Provider 兼容层。
2
3use std::collections::BTreeMap;
4use std::fmt;
5use std::sync::Arc;
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::error::{Error, ProviderCompatibilityError, Result};
11use crate::json_payload::JsonPayload;
12
13/// 表示支持的 Provider 类型。
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum ProviderKind {
17    /// 官方 OpenAI Provider。
18    OpenAI,
19    /// Azure OpenAI Provider。
20    Azure,
21    /// 智谱兼容 Provider。
22    Zhipu,
23    /// MiniMax 兼容 Provider。
24    MiniMax,
25    /// ZenMux 兼容 Provider。
26    ZenMux,
27    /// 自定义 Provider。
28    Custom,
29}
30
31impl ProviderKind {
32    /// 返回 provider 对应的小写键。
33    pub fn as_key(&self) -> &'static str {
34        match self {
35            Self::OpenAI => "openai",
36            Self::Azure => "azure",
37            Self::Zhipu => "zhipu",
38            Self::MiniMax => "minimax",
39            Self::ZenMux => "zenmux",
40            Self::Custom => "custom",
41        }
42    }
43}
44
/// How a provider expects the credential to be attached to a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthScheme {
    /// Sent as the `Authorization: Bearer <token>` header.
    Bearer,
    /// Sent as the `api-key: <token>` header.
    ApiKeyHeader,
    /// Sent as a query-string parameter.
    QueryToken,
    /// Sent through a WebSocket subprotocol.
    WebSocketSubprotocol,
}
57
/// How strictly a request is checked against a provider's known limitations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompatibilityMode {
    /// Forward unknown fields unchanged wherever possible.
    Passthrough,
    /// Emit warnings for known risks.
    Warn,
    /// Fail outright on fields known to be incompatible.
    Strict,
}
68
69/// 表示 Azure OpenAI 的认证模式。
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
71#[serde(rename_all = "snake_case")]
72pub enum AzureAuthMode {
73    /// 使用 `api-key` 请求头。
74    #[default]
75    ApiKey,
76    /// 使用 `Authorization: Bearer <token>`。
77    Bearer,
78}
79
80impl AzureAuthMode {
81    /// 转换为底层通用认证方案。
82    pub fn auth_scheme(self) -> AuthScheme {
83        match self {
84            Self::ApiKey => AuthScheme::ApiKeyHeader,
85            Self::Bearer => AuthScheme::Bearer,
86        }
87    }
88}
89
90/// 表示 Azure Provider 的可配置选项。
91#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
92pub struct AzureOptions {
93    /// Azure OpenAI `api-version`。
94    pub api_version: Option<String>,
95    /// 默认 deployment 名称。
96    pub deployment: Option<String>,
97    /// Azure 认证模式。
98    #[serde(default)]
99    pub auth_mode: AzureAuthMode,
100}
101
102impl AzureOptions {
103    /// 创建默认 Azure 选项。
104    pub fn new() -> Self {
105        Self::default()
106    }
107
108    /// 设置 `api-version`。
109    pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
110        self.api_version = Some(api_version.into());
111        self
112    }
113
114    /// 设置默认 deployment。
115    pub fn deployment(mut self, deployment: impl Into<String>) -> Self {
116        self.deployment = Some(deployment.into());
117        self
118    }
119
120    /// 切换为 Bearer Token 认证。
121    pub fn bearer_auth(mut self) -> Self {
122        self.auth_mode = AzureAuthMode::Bearer;
123        self
124    }
125
126    /// 切换为 `api-key` 认证。
127    pub fn api_key_auth(mut self) -> Self {
128        self.auth_mode = AzureAuthMode::ApiKey;
129        self
130    }
131}
132
/// The API features a provider supports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CapabilitySet {
    /// Chat Completions API.
    pub chat_completions: bool,
    /// Responses API.
    pub responses: bool,
    /// Model listing.
    pub models: bool,
    /// Server-sent-event streaming.
    pub streaming: bool,
    /// Tool / function calling.
    pub tools: bool,
    /// Webhooks.
    pub webhooks: bool,
}
149
150const FULL_CAPABILITIES: CapabilitySet = CapabilitySet {
151    chat_completions: true,
152    responses: true,
153    models: true,
154    streaming: true,
155    tools: true,
156    webhooks: true,
157};
158
159const CHAT_ONLY_CAPABILITIES: CapabilitySet = CapabilitySet {
160    chat_completions: true,
161    responses: false,
162    models: true,
163    streaming: true,
164    tools: true,
165    webhooks: false,
166};
167
168/// 表示 Provider 在发送请求前可修改的上下文。
169#[derive(Debug, Clone)]
170pub struct RequestContext {
171    /// 逻辑端点 ID。
172    pub endpoint_id: &'static str,
173    /// HTTP 路径。
174    pub path: String,
175    /// 查询参数。
176    pub query: BTreeMap<String, String>,
177    /// 请求头。
178    pub headers: BTreeMap<String, String>,
179    /// JSON 请求体。
180    pub body: Option<JsonPayload>,
181}
182
183/// ProviderProfile 用于屏蔽不同兼容 Provider 的差异。
184pub trait ProviderProfile: Send + Sync {
185    /// 返回 Provider 类型。
186    fn kind(&self) -> ProviderKind;
187    /// 返回默认基础地址。
188    fn default_base_url(&self) -> &str;
189    /// 返回认证方案。
190    fn auth_scheme(&self) -> AuthScheme;
191    /// 返回能力集合。
192    fn capabilities(&self) -> &'static CapabilitySet;
193    /// 在请求真正构建前对请求做进一步调整。
194    fn prepare_request(&self, ctx: &mut RequestContext) -> Result<()>;
195    /// 根据 Provider 规则适配错误。
196    fn adapt_error(&self, error: crate::ApiError) -> Error {
197        Error::Api(error)
198    }
199    /// 在发送前校验请求是否符合当前 Provider 要求。
200    fn validate_request(
201        &self,
202        endpoint_id: &'static str,
203        body: Option<&Value>,
204        mode: CompatibilityMode,
205    ) -> Result<()>;
206}
207
208/// 对外暴露的 Provider 句柄。
209#[derive(Clone)]
210pub struct Provider {
211    inner: Arc<dyn ProviderProfile>,
212}
213
214impl fmt::Debug for Provider {
215    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216        f.debug_struct("Provider")
217            .field("kind", &self.kind())
218            .field("default_base_url", &self.default_base_url())
219            .finish()
220    }
221}
222
223impl Provider {
224    /// 创建 OpenAI Provider。
225    pub fn openai() -> Self {
226        Self {
227            inner: Arc::new(OpenAiProfile),
228        }
229    }
230
231    /// 创建 Azure Provider。
232    pub fn azure() -> Self {
233        Self::azure_with_options(AzureOptions::default())
234    }
235
236    /// 创建带自定义选项的 Azure Provider。
237    pub fn azure_with_options(options: AzureOptions) -> Self {
238        Self {
239            inner: Arc::new(AzureProfile::new(options)),
240        }
241    }
242
243    /// 创建智谱 Provider。
244    pub fn zhipu() -> Self {
245        Self {
246            inner: Arc::new(ZhipuProfile),
247        }
248    }
249
250    /// 创建 MiniMax Provider。
251    pub fn minimax() -> Self {
252        Self {
253            inner: Arc::new(MiniMaxProfile),
254        }
255    }
256
257    /// 创建 ZenMux Provider。
258    pub fn zenmux() -> Self {
259        Self {
260            inner: Arc::new(ZenMuxProfile),
261        }
262    }
263
264    /// 创建自定义 Provider。
265    pub fn custom<T>(profile: T) -> Self
266    where
267        T: ProviderProfile + 'static,
268    {
269        Self {
270            inner: Arc::new(profile),
271        }
272    }
273
274    /// 返回 Provider 类型。
275    pub fn kind(&self) -> ProviderKind {
276        self.inner.kind()
277    }
278
279    /// 返回默认基础地址。
280    pub fn default_base_url(&self) -> &str {
281        self.inner.default_base_url()
282    }
283
284    /// 返回 ProviderProfile 引用。
285    pub fn profile(&self) -> &(dyn ProviderProfile + Send + Sync) {
286        self.inner.as_ref()
287    }
288}
289
290/// 表示自定义 Provider 实现。
291#[derive(Debug, Clone)]
292pub struct CustomProfile {
293    /// Provider 的自定义名称。
294    pub name: String,
295    /// 默认基础地址。
296    pub base_url: String,
297    /// 认证方案。
298    pub auth_scheme: AuthScheme,
299    /// 能力集合。
300    pub capabilities: CapabilitySet,
301}
302
303impl ProviderProfile for CustomProfile {
304    fn kind(&self) -> ProviderKind {
305        ProviderKind::Custom
306    }
307
308    fn default_base_url(&self) -> &str {
309        &self.base_url
310    }
311
312    fn auth_scheme(&self) -> AuthScheme {
313        self.auth_scheme
314    }
315
316    fn capabilities(&self) -> &'static CapabilitySet {
317        Box::leak(Box::new(self.capabilities))
318    }
319
320    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
321        Ok(())
322    }
323
324    fn validate_request(
325        &self,
326        _endpoint_id: &'static str,
327        _body: Option<&Value>,
328        _mode: CompatibilityMode,
329    ) -> Result<()> {
330        Ok(())
331    }
332}
333
334#[derive(Debug, Clone, Default)]
335struct AzureProfile {
336    options: AzureOptions,
337}
338
339impl AzureProfile {
340    fn new(options: AzureOptions) -> Self {
341        Self { options }
342    }
343
344    fn api_version(&self) -> &str {
345        self.options
346            .api_version
347            .as_deref()
348            .filter(|value| !value.trim().is_empty())
349            .unwrap_or("2025-03-01-preview")
350    }
351
352    fn auth_scheme(&self) -> AuthScheme {
353        self.options.auth_mode.auth_scheme()
354    }
355
356    fn deployment_for(&self, ctx: &RequestContext) -> Option<String> {
357        if ctx.endpoint_id == "realtime.ws.connect" {
358            return ctx
359                .query
360                .get("deployment")
361                .cloned()
362                .or_else(|| self.options.deployment.clone())
363                .filter(|value| !value.trim().is_empty());
364        }
365
366        if !azure_deployment_path_required(&ctx.path) {
367            return None;
368        }
369
370        self.options
371            .deployment
372            .clone()
373            .or_else(|| {
374                ctx.body
375                    .as_ref()
376                    .and_then(|value| value.get("model"))
377                    .and_then(Value::as_str)
378                    .map(str::to_owned)
379            })
380            .filter(|value| !value.trim().is_empty())
381    }
382}
383
/// Profile for the official OpenAI API.
#[derive(Clone, Copy, Debug)]
struct OpenAiProfile;

/// Profile for Zhipu's OpenAI-compatible API.
#[derive(Clone, Copy, Debug)]
struct ZhipuProfile;

/// Profile for MiniMax's OpenAI-compatible API.
#[derive(Clone, Copy, Debug)]
struct MiniMaxProfile;

/// Profile for ZenMux's OpenAI-compatible API.
#[derive(Clone, Copy, Debug)]
struct ZenMuxProfile;
395
396impl ProviderProfile for OpenAiProfile {
397    fn kind(&self) -> ProviderKind {
398        ProviderKind::OpenAI
399    }
400
401    fn default_base_url(&self) -> &str {
402        "https://api.openai.com/v1"
403    }
404
405    fn auth_scheme(&self) -> AuthScheme {
406        AuthScheme::Bearer
407    }
408
409    fn capabilities(&self) -> &'static CapabilitySet {
410        &FULL_CAPABILITIES
411    }
412
413    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
414        Ok(())
415    }
416
417    fn validate_request(
418        &self,
419        _endpoint_id: &'static str,
420        _body: Option<&Value>,
421        _mode: CompatibilityMode,
422    ) -> Result<()> {
423        Ok(())
424    }
425}
426
427impl ProviderProfile for AzureProfile {
428    fn kind(&self) -> ProviderKind {
429        ProviderKind::Azure
430    }
431
432    fn default_base_url(&self) -> &str {
433        "https://example-resource.openai.azure.com"
434    }
435
436    fn auth_scheme(&self) -> AuthScheme {
437        self.auth_scheme()
438    }
439
440    fn capabilities(&self) -> &'static CapabilitySet {
441        &FULL_CAPABILITIES
442    }
443
444    fn prepare_request(&self, ctx: &mut RequestContext) -> Result<()> {
445        ctx.query
446            .entry("api-version".into())
447            .or_insert_with(|| self.api_version().into());
448
449        if !ctx.path.starts_with("/openai") {
450            ctx.path = format!("/openai{}", ctx.path);
451        }
452
453        if let Some(deployment) = self.deployment_for(ctx)
454            && ctx.endpoint_id == "realtime.ws.connect"
455        {
456            ctx.query.insert("deployment".into(), deployment);
457        } else if let Some(deployment) = self.deployment_for(ctx)
458            && !ctx.path.contains("/deployments/")
459        {
460            ctx.path =
461                ctx.path
462                    .replacen("/openai/", &format!("/openai/deployments/{deployment}/"), 1);
463        }
464
465        Ok(())
466    }
467
468    fn validate_request(
469        &self,
470        _endpoint_id: &'static str,
471        _body: Option<&Value>,
472        _mode: CompatibilityMode,
473    ) -> Result<()> {
474        Ok(())
475    }
476}
477
478impl ProviderProfile for ZhipuProfile {
479    fn kind(&self) -> ProviderKind {
480        ProviderKind::Zhipu
481    }
482
483    fn default_base_url(&self) -> &str {
484        "https://open.bigmodel.cn/api/paas/v4"
485    }
486
487    fn auth_scheme(&self) -> AuthScheme {
488        AuthScheme::Bearer
489    }
490
491    fn capabilities(&self) -> &'static CapabilitySet {
492        &CHAT_ONLY_CAPABILITIES
493    }
494
495    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
496        Ok(())
497    }
498
499    fn validate_request(
500        &self,
501        _endpoint_id: &'static str,
502        _body: Option<&Value>,
503        _mode: CompatibilityMode,
504    ) -> Result<()> {
505        Ok(())
506    }
507}
508
509impl ProviderProfile for MiniMaxProfile {
510    fn kind(&self) -> ProviderKind {
511        ProviderKind::MiniMax
512    }
513
514    fn default_base_url(&self) -> &str {
515        "https://api.minimaxi.com/v1"
516    }
517
518    fn auth_scheme(&self) -> AuthScheme {
519        AuthScheme::Bearer
520    }
521
522    fn capabilities(&self) -> &'static CapabilitySet {
523        &CHAT_ONLY_CAPABILITIES
524    }
525
526    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
527        Ok(())
528    }
529
530    fn validate_request(
531        &self,
532        _endpoint_id: &'static str,
533        body: Option<&Value>,
534        mode: CompatibilityMode,
535    ) -> Result<()> {
536        if mode != CompatibilityMode::Strict {
537            return Ok(());
538        }
539
540        let Some(body) = body else {
541            return Ok(());
542        };
543
544        if let Some(value) = body.get("n").and_then(Value::as_i64)
545            && value != 1
546        {
547            return Err(ProviderCompatibilityError::new(
548                ProviderKind::MiniMax,
549                "MiniMax 在严格模式下仅支持 n = 1",
550            )
551            .into());
552        }
553
554        if contains_key(body, "function_call") {
555            return Err(ProviderCompatibilityError::new(
556                ProviderKind::MiniMax,
557                "MiniMax 在严格模式下不再支持旧版 function_call 字段,请改用 tools",
558            )
559            .into());
560        }
561
562        if contains_any_type(body, &["input_image", "image", "input_audio", "audio"]) {
563            return Err(ProviderCompatibilityError::new(
564                ProviderKind::MiniMax,
565                "MiniMax 在严格模式下不支持图像或音频输入",
566            )
567            .into());
568        }
569
570        Ok(())
571    }
572}
573
574impl ProviderProfile for ZenMuxProfile {
575    fn kind(&self) -> ProviderKind {
576        ProviderKind::ZenMux
577    }
578
579    fn default_base_url(&self) -> &str {
580        "https://zenmux.ai/api/v1"
581    }
582
583    fn auth_scheme(&self) -> AuthScheme {
584        AuthScheme::Bearer
585    }
586
587    fn capabilities(&self) -> &'static CapabilitySet {
588        &FULL_CAPABILITIES
589    }
590
591    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
592        Ok(())
593    }
594
595    fn validate_request(
596        &self,
597        _endpoint_id: &'static str,
598        body: Option<&Value>,
599        mode: CompatibilityMode,
600    ) -> Result<()> {
601        if mode != CompatibilityMode::Strict {
602            return Ok(());
603        }
604
605        let Some(model) = body
606            .and_then(|value| value.get("model"))
607            .and_then(Value::as_str)
608        else {
609            return Ok(());
610        };
611
612        if !model.contains('/') || model.starts_with('/') || model.ends_with('/') {
613            return Err(ProviderCompatibilityError::new(
614                ProviderKind::ZenMux,
615                "ZenMux 在严格模式下要求 model 采用 <provider>/<model_name> 形式",
616            )
617            .into());
618        }
619
620        Ok(())
621    }
622}
623
624fn contains_key(value: &Value, target: &str) -> bool {
625    match value {
626        Value::Object(map) => {
627            map.contains_key(target) || map.values().any(|value| contains_key(value, target))
628        }
629        Value::Array(values) => values.iter().any(|value| contains_key(value, target)),
630        _ => false,
631    }
632}
633
634fn contains_any_type(value: &Value, targets: &[&str]) -> bool {
635    match value {
636        Value::Object(map) => map.iter().any(|(key, nested)| {
637            (key == "type"
638                && nested
639                    .as_str()
640                    .is_some_and(|value| targets.contains(&value)))
641                || contains_any_type(nested, targets)
642        }),
643        Value::Array(values) => values.iter().any(|value| contains_any_type(value, targets)),
644        _ => false,
645    }
646}
647
/// Returns `true` if `path` names an Azure endpoint that is addressed per deployment.
///
/// Trailing slashes are ignored. The path may or may not already carry the `/openai` prefix.
fn azure_deployment_path_required(path: &str) -> bool {
    const DEPLOYMENT_SCOPED: [&str; 9] = [
        "/completions",
        "/chat/completions",
        "/embeddings",
        "/audio/transcriptions",
        "/audio/translations",
        "/audio/speech",
        "/images/generations",
        "/images/edits",
        "/batches",
    ];

    let normalized = path.trim_end_matches('/');
    let relative = normalized.strip_prefix("/openai").unwrap_or(normalized);
    DEPLOYMENT_SCOPED.contains(&relative)
}
671
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request context with an empty query, no headers and no body.
    fn context(endpoint_id: &'static str, path: &str) -> RequestContext {
        RequestContext {
            endpoint_id,
            path: path.to_owned(),
            query: BTreeMap::new(),
            headers: BTreeMap::new(),
            body: None,
        }
    }

    /// Runs strict-mode validation and returns the error it must produce.
    fn strict_error(provider: &Provider, endpoint_id: &'static str, body: &Value) -> Error {
        provider
            .profile()
            .validate_request(endpoint_id, Some(body), CompatibilityMode::Strict)
            .unwrap_err()
    }

    #[test]
    fn test_should_use_zhipu_default_base_url() {
        assert_eq!(
            Provider::zhipu().default_base_url(),
            "https://open.bigmodel.cn/api/paas/v4"
        );
    }

    #[test]
    fn test_should_use_minimax_default_base_url() {
        assert_eq!(
            Provider::minimax().default_base_url(),
            "https://api.minimaxi.com/v1"
        );
    }

    #[test]
    fn test_should_use_zenmux_default_base_url() {
        assert_eq!(
            Provider::zenmux().default_base_url(),
            "https://zenmux.ai/api/v1"
        );
    }

    #[test]
    fn test_should_validate_minimax_n_equals_one_in_strict_mode() {
        let body = serde_json::json!({
            "model": "MiniMax-M2.7",
            "messages": [{"role": "user", "content": "hello"}],
            "n": 2
        });
        let error = strict_error(&Provider::minimax(), "chat.completions.create", &body);
        assert!(matches!(error, Error::ProviderCompatibility(_)));
    }

    #[test]
    fn test_should_validate_zenmux_model_id_format_in_strict_mode() {
        let body = serde_json::json!({
            "model": "gpt-5",
            "input": "hello"
        });
        let error = strict_error(&Provider::zenmux(), "responses.create", &body);
        assert!(matches!(error, Error::ProviderCompatibility(_)));
    }

    #[test]
    fn test_should_preserve_passthrough_mode_for_minimax() {
        let body = serde_json::json!({
            "model": "MiniMax-M2.7",
            "messages": [{"role": "user", "content": "hello"}],
            "n": 3
        });
        Provider::minimax()
            .profile()
            .validate_request(
                "chat.completions.create",
                Some(&body),
                CompatibilityMode::Passthrough,
            )
            .unwrap();
    }

    #[test]
    fn test_should_inject_azure_api_version_and_prefix_path() {
        let provider =
            Provider::azure_with_options(AzureOptions::new().api_version("2024-02-15-preview"));
        let mut ctx = context("responses.create", "/responses");

        provider.profile().prepare_request(&mut ctx).unwrap();

        assert_eq!(ctx.path, "/openai/responses");
        assert_eq!(
            ctx.query.get("api-version").map(String::as_str),
            Some("2024-02-15-preview")
        );
    }

    #[test]
    fn test_should_preserve_existing_azure_api_version_query() {
        let provider = Provider::azure();
        let mut ctx = context("responses.create", "/responses");
        ctx.query
            .insert("api-version".into(), "custom-version".into());

        provider.profile().prepare_request(&mut ctx).unwrap();

        assert_eq!(
            ctx.query.get("api-version").map(String::as_str),
            Some("custom-version")
        );
    }

    #[test]
    fn test_should_inject_azure_deployment_from_body_model() {
        let provider = Provider::azure();
        let mut ctx = RequestContext {
            body: Some(serde_json::json!({ "model": "gpt-4o-deployment" }).into()),
            ..context("chat.completions.create", "/chat/completions")
        };

        provider.profile().prepare_request(&mut ctx).unwrap();

        assert_eq!(
            ctx.path,
            "/openai/deployments/gpt-4o-deployment/chat/completions"
        );
    }

    #[test]
    fn test_should_inject_azure_realtime_deployment_query() {
        let provider =
            Provider::azure_with_options(AzureOptions::new().deployment("rt-deployment"));
        let mut ctx = context("realtime.ws.connect", "/realtime");

        provider.profile().prepare_request(&mut ctx).unwrap();

        assert_eq!(ctx.path, "/openai/realtime");
        assert_eq!(
            ctx.query.get("deployment").map(String::as_str),
            Some("rt-deployment")
        );
    }

    #[test]
    fn test_should_switch_azure_auth_scheme_to_bearer() {
        let provider = Provider::azure_with_options(AzureOptions::new().bearer_auth());
        assert_eq!(provider.profile().auth_scheme(), AuthScheme::Bearer);
    }
}