Skip to main content

openai_core/providers/
mod.rs

1//! Provider 兼容层。
2
3use std::collections::BTreeMap;
4use std::fmt;
5use std::sync::Arc;
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::error::{Error, ProviderCompatibilityError, Result};
11use crate::json_payload::JsonPayload;
12
13/// 表示支持的 Provider 类型。
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum ProviderKind {
17    /// 官方 OpenAI Provider。
18    OpenAI,
19    /// Azure OpenAI Provider。
20    Azure,
21    /// Kimi / Moonshot 兼容 Provider。
22    Kimi,
23    /// DeepSeek 兼容 Provider。
24    DeepSeek,
25    /// 智谱兼容 Provider。
26    Zhipu,
27    /// MiniMax 兼容 Provider。
28    MiniMax,
29    /// ZenMux 兼容 Provider。
30    ZenMux,
31    /// 自定义 Provider。
32    Custom,
33}
34
35impl ProviderKind {
36    /// 返回 provider 对应的小写键。
37    pub fn as_key(&self) -> &'static str {
38        match self {
39            Self::OpenAI => "openai",
40            Self::Azure => "azure",
41            Self::Kimi => "kimi",
42            Self::DeepSeek => "deepseek",
43            Self::Zhipu => "zhipu",
44            Self::MiniMax => "minimax",
45            Self::ZenMux => "zenmux",
46            Self::Custom => "custom",
47        }
48    }
49}
50
/// How a provider expects the credential to be transmitted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthScheme {
    /// `Authorization: Bearer <token>` header.
    Bearer,
    /// `api-key: <token>` header.
    ApiKeyHeader,
    /// Token carried in a query parameter.
    QueryToken,
    /// Token carried in a WebSocket subprotocol.
    WebSocketSubprotocol,
}
63
/// How strictly requests are checked against a provider's known limitations.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompatibilityMode {
    /// Forward unknown fields as-is wherever possible.
    Passthrough,
    /// Emit warnings for known risks.
    Warn,
    /// Fail with an error on known-incompatible fields.
    Strict,
}
74
75/// 表示 Azure OpenAI 的认证模式。
76#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
77#[serde(rename_all = "snake_case")]
78pub enum AzureAuthMode {
79    /// 使用 `api-key` 请求头。
80    #[default]
81    ApiKey,
82    /// 使用 `Authorization: Bearer <token>`。
83    Bearer,
84}
85
86impl AzureAuthMode {
87    /// 转换为底层通用认证方案。
88    pub fn auth_scheme(self) -> AuthScheme {
89        match self {
90            Self::ApiKey => AuthScheme::ApiKeyHeader,
91            Self::Bearer => AuthScheme::Bearer,
92        }
93    }
94}
95
96/// 表示 Azure Provider 的可配置选项。
97#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
98pub struct AzureOptions {
99    /// Azure OpenAI `api-version`。
100    pub api_version: Option<String>,
101    /// 默认 deployment 名称。
102    pub deployment: Option<String>,
103    /// Azure 认证模式。
104    #[serde(default)]
105    pub auth_mode: AzureAuthMode,
106}
107
108impl AzureOptions {
109    /// 创建默认 Azure 选项。
110    pub fn new() -> Self {
111        Self::default()
112    }
113
114    /// 设置 `api-version`。
115    pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
116        self.api_version = Some(api_version.into());
117        self
118    }
119
120    /// 设置默认 deployment。
121    pub fn deployment(mut self, deployment: impl Into<String>) -> Self {
122        self.deployment = Some(deployment.into());
123        self
124    }
125
126    /// 切换为 Bearer Token 认证。
127    pub fn bearer_auth(mut self) -> Self {
128        self.auth_mode = AzureAuthMode::Bearer;
129        self
130    }
131
132    /// 切换为 `api-key` 认证。
133    pub fn api_key_auth(mut self) -> Self {
134        self.auth_mode = AzureAuthMode::ApiKey;
135        self
136    }
137}
138
/// Flags describing which API surfaces a provider supports.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CapabilitySet {
    /// Chat Completions API.
    pub chat_completions: bool,
    /// Responses API.
    pub responses: bool,
    /// Model listing.
    pub models: bool,
    /// Server-sent-event streaming.
    pub streaming: bool,
    /// Tool calling.
    pub tools: bool,
    /// Webhooks.
    pub webhooks: bool,
}
155
156const FULL_CAPABILITIES: CapabilitySet = CapabilitySet {
157    chat_completions: true,
158    responses: true,
159    models: true,
160    streaming: true,
161    tools: true,
162    webhooks: true,
163};
164
165const CHAT_ONLY_CAPABILITIES: CapabilitySet = CapabilitySet {
166    chat_completions: true,
167    responses: false,
168    models: true,
169    streaming: true,
170    tools: true,
171    webhooks: false,
172};
173
174/// 表示 Provider 在发送请求前可修改的上下文。
175#[derive(Debug, Clone)]
176pub struct RequestContext {
177    /// 逻辑端点 ID。
178    pub endpoint_id: &'static str,
179    /// HTTP 路径。
180    pub path: String,
181    /// 查询参数。
182    pub query: BTreeMap<String, String>,
183    /// 请求头。
184    pub headers: BTreeMap<String, String>,
185    /// JSON 请求体。
186    pub body: Option<JsonPayload>,
187}
188
189/// ProviderProfile 用于屏蔽不同兼容 Provider 的差异。
190pub trait ProviderProfile: Send + Sync {
191    /// 返回 Provider 类型。
192    fn kind(&self) -> ProviderKind;
193    /// 返回默认基础地址。
194    fn default_base_url(&self) -> &str;
195    /// 返回认证方案。
196    fn auth_scheme(&self) -> AuthScheme;
197    /// 返回能力集合。
198    fn capabilities(&self) -> &'static CapabilitySet;
199    /// 在请求真正构建前对请求做进一步调整。
200    fn prepare_request(&self, ctx: &mut RequestContext) -> Result<()>;
201    /// 根据 Provider 规则适配错误。
202    fn adapt_error(&self, error: crate::ApiError) -> Error {
203        Error::Api(error)
204    }
205    /// 在发送前校验请求是否符合当前 Provider 要求。
206    fn validate_request(
207        &self,
208        endpoint_id: &'static str,
209        body: Option<&Value>,
210        mode: CompatibilityMode,
211    ) -> Result<()>;
212}
213
214/// 对外暴露的 Provider 句柄。
215#[derive(Clone)]
216pub struct Provider {
217    inner: Arc<dyn ProviderProfile>,
218}
219
220impl fmt::Debug for Provider {
221    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222        f.debug_struct("Provider")
223            .field("kind", &self.kind())
224            .field("default_base_url", &self.default_base_url())
225            .finish()
226    }
227}
228
229impl Provider {
230    /// 创建 OpenAI Provider。
231    pub fn openai() -> Self {
232        Self {
233            inner: Arc::new(OpenAiProfile),
234        }
235    }
236
237    /// 创建 Azure Provider。
238    pub fn azure() -> Self {
239        Self::azure_with_options(AzureOptions::default())
240    }
241
242    /// 创建带自定义选项的 Azure Provider。
243    pub fn azure_with_options(options: AzureOptions) -> Self {
244        Self {
245            inner: Arc::new(AzureProfile::new(options)),
246        }
247    }
248
249    /// 创建 Kimi / Moonshot Provider。
250    pub fn kimi() -> Self {
251        Self {
252            inner: Arc::new(KimiProfile),
253        }
254    }
255
256    /// 创建 DeepSeek Provider。
257    pub fn deepseek() -> Self {
258        Self {
259            inner: Arc::new(DeepSeekProfile),
260        }
261    }
262
263    /// 创建智谱 Provider。
264    pub fn zhipu() -> Self {
265        Self {
266            inner: Arc::new(ZhipuProfile),
267        }
268    }
269
270    /// 创建 MiniMax Provider。
271    pub fn minimax() -> Self {
272        Self {
273            inner: Arc::new(MiniMaxProfile),
274        }
275    }
276
277    /// 创建 ZenMux Provider。
278    pub fn zenmux() -> Self {
279        Self {
280            inner: Arc::new(ZenMuxProfile),
281        }
282    }
283
284    /// 创建自定义 Provider。
285    pub fn custom<T>(profile: T) -> Self
286    where
287        T: ProviderProfile + 'static,
288    {
289        Self {
290            inner: Arc::new(profile),
291        }
292    }
293
294    /// 返回 Provider 类型。
295    pub fn kind(&self) -> ProviderKind {
296        self.inner.kind()
297    }
298
299    /// 返回默认基础地址。
300    pub fn default_base_url(&self) -> &str {
301        self.inner.default_base_url()
302    }
303
304    /// 返回 ProviderProfile 引用。
305    pub fn profile(&self) -> &(dyn ProviderProfile + Send + Sync) {
306        self.inner.as_ref()
307    }
308}
309
310/// 表示自定义 Provider 实现。
311#[derive(Debug, Clone)]
312pub struct CustomProfile {
313    /// Provider 的自定义名称。
314    pub name: String,
315    /// 默认基础地址。
316    pub base_url: String,
317    /// 认证方案。
318    pub auth_scheme: AuthScheme,
319    /// 能力集合。
320    pub capabilities: CapabilitySet,
321}
322
323impl ProviderProfile for CustomProfile {
324    fn kind(&self) -> ProviderKind {
325        ProviderKind::Custom
326    }
327
328    fn default_base_url(&self) -> &str {
329        &self.base_url
330    }
331
332    fn auth_scheme(&self) -> AuthScheme {
333        self.auth_scheme
334    }
335
336    fn capabilities(&self) -> &'static CapabilitySet {
337        Box::leak(Box::new(self.capabilities))
338    }
339
340    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
341        Ok(())
342    }
343
344    fn validate_request(
345        &self,
346        _endpoint_id: &'static str,
347        _body: Option<&Value>,
348        _mode: CompatibilityMode,
349    ) -> Result<()> {
350        Ok(())
351    }
352}
353
354#[derive(Debug, Clone, Default)]
355struct AzureProfile {
356    options: AzureOptions,
357}
358
359impl AzureProfile {
360    fn new(options: AzureOptions) -> Self {
361        Self { options }
362    }
363
364    fn api_version(&self) -> &str {
365        self.options
366            .api_version
367            .as_deref()
368            .filter(|value| !value.trim().is_empty())
369            .unwrap_or("2025-03-01-preview")
370    }
371
372    fn auth_scheme(&self) -> AuthScheme {
373        self.options.auth_mode.auth_scheme()
374    }
375
376    fn deployment_for(&self, ctx: &RequestContext) -> Option<String> {
377        if ctx.endpoint_id == "realtime.ws.connect" {
378            return ctx
379                .query
380                .get("deployment")
381                .cloned()
382                .or_else(|| self.options.deployment.clone())
383                .filter(|value| !value.trim().is_empty());
384        }
385
386        if !azure_deployment_path_required(&ctx.path) {
387            return None;
388        }
389
390        self.options
391            .deployment
392            .clone()
393            .or_else(|| {
394                ctx.body
395                    .as_ref()
396                    .and_then(|value| value.get("model"))
397                    .and_then(Value::as_str)
398                    .map(str::to_owned)
399            })
400            .filter(|value| !value.trim().is_empty())
401    }
402}
403
/// Stateless profile for the official OpenAI API.
#[derive(Clone, Copy, Debug)]
struct OpenAiProfile;

/// Stateless profile for Kimi / Moonshot.
#[derive(Clone, Copy, Debug)]
struct KimiProfile;

/// Stateless profile for DeepSeek.
#[derive(Clone, Copy, Debug)]
struct DeepSeekProfile;

/// Stateless profile for Zhipu.
#[derive(Clone, Copy, Debug)]
struct ZhipuProfile;

/// Stateless profile for MiniMax.
#[derive(Clone, Copy, Debug)]
struct MiniMaxProfile;

/// Stateless profile for ZenMux.
#[derive(Clone, Copy, Debug)]
struct ZenMuxProfile;
421
422impl ProviderProfile for OpenAiProfile {
423    fn kind(&self) -> ProviderKind {
424        ProviderKind::OpenAI
425    }
426
427    fn default_base_url(&self) -> &str {
428        "https://api.openai.com/v1"
429    }
430
431    fn auth_scheme(&self) -> AuthScheme {
432        AuthScheme::Bearer
433    }
434
435    fn capabilities(&self) -> &'static CapabilitySet {
436        &FULL_CAPABILITIES
437    }
438
439    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
440        Ok(())
441    }
442
443    fn validate_request(
444        &self,
445        _endpoint_id: &'static str,
446        _body: Option<&Value>,
447        _mode: CompatibilityMode,
448    ) -> Result<()> {
449        Ok(())
450    }
451}
452
453impl ProviderProfile for AzureProfile {
454    fn kind(&self) -> ProviderKind {
455        ProviderKind::Azure
456    }
457
458    fn default_base_url(&self) -> &str {
459        "https://example-resource.openai.azure.com"
460    }
461
462    fn auth_scheme(&self) -> AuthScheme {
463        self.auth_scheme()
464    }
465
466    fn capabilities(&self) -> &'static CapabilitySet {
467        &FULL_CAPABILITIES
468    }
469
470    fn prepare_request(&self, ctx: &mut RequestContext) -> Result<()> {
471        ctx.query
472            .entry("api-version".into())
473            .or_insert_with(|| self.api_version().into());
474
475        if !ctx.path.starts_with("/openai") {
476            ctx.path = format!("/openai{}", ctx.path);
477        }
478
479        if let Some(deployment) = self.deployment_for(ctx)
480            && ctx.endpoint_id == "realtime.ws.connect"
481        {
482            ctx.query.insert("deployment".into(), deployment);
483        } else if let Some(deployment) = self.deployment_for(ctx)
484            && !ctx.path.contains("/deployments/")
485        {
486            ctx.path =
487                ctx.path
488                    .replacen("/openai/", &format!("/openai/deployments/{deployment}/"), 1);
489        }
490
491        Ok(())
492    }
493
494    fn validate_request(
495        &self,
496        _endpoint_id: &'static str,
497        _body: Option<&Value>,
498        _mode: CompatibilityMode,
499    ) -> Result<()> {
500        Ok(())
501    }
502}
503
504impl ProviderProfile for KimiProfile {
505    fn kind(&self) -> ProviderKind {
506        ProviderKind::Kimi
507    }
508
509    fn default_base_url(&self) -> &str {
510        "https://api.moonshot.ai/v1"
511    }
512
513    fn auth_scheme(&self) -> AuthScheme {
514        AuthScheme::Bearer
515    }
516
517    fn capabilities(&self) -> &'static CapabilitySet {
518        &CHAT_ONLY_CAPABILITIES
519    }
520
521    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
522        Ok(())
523    }
524
525    fn validate_request(
526        &self,
527        _endpoint_id: &'static str,
528        _body: Option<&Value>,
529        _mode: CompatibilityMode,
530    ) -> Result<()> {
531        Ok(())
532    }
533}
534
535impl ProviderProfile for DeepSeekProfile {
536    fn kind(&self) -> ProviderKind {
537        ProviderKind::DeepSeek
538    }
539
540    fn default_base_url(&self) -> &str {
541        "https://api.deepseek.com"
542    }
543
544    fn auth_scheme(&self) -> AuthScheme {
545        AuthScheme::Bearer
546    }
547
548    fn capabilities(&self) -> &'static CapabilitySet {
549        &CHAT_ONLY_CAPABILITIES
550    }
551
552    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
553        Ok(())
554    }
555
556    fn validate_request(
557        &self,
558        _endpoint_id: &'static str,
559        _body: Option<&Value>,
560        _mode: CompatibilityMode,
561    ) -> Result<()> {
562        Ok(())
563    }
564}
565
566impl ProviderProfile for ZhipuProfile {
567    fn kind(&self) -> ProviderKind {
568        ProviderKind::Zhipu
569    }
570
571    fn default_base_url(&self) -> &str {
572        "https://open.bigmodel.cn/api/paas/v4"
573    }
574
575    fn auth_scheme(&self) -> AuthScheme {
576        AuthScheme::Bearer
577    }
578
579    fn capabilities(&self) -> &'static CapabilitySet {
580        &CHAT_ONLY_CAPABILITIES
581    }
582
583    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
584        Ok(())
585    }
586
587    fn validate_request(
588        &self,
589        _endpoint_id: &'static str,
590        _body: Option<&Value>,
591        _mode: CompatibilityMode,
592    ) -> Result<()> {
593        Ok(())
594    }
595}
596
597impl ProviderProfile for MiniMaxProfile {
598    fn kind(&self) -> ProviderKind {
599        ProviderKind::MiniMax
600    }
601
602    fn default_base_url(&self) -> &str {
603        "https://api.minimaxi.com/v1"
604    }
605
606    fn auth_scheme(&self) -> AuthScheme {
607        AuthScheme::Bearer
608    }
609
610    fn capabilities(&self) -> &'static CapabilitySet {
611        &CHAT_ONLY_CAPABILITIES
612    }
613
614    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
615        Ok(())
616    }
617
618    fn validate_request(
619        &self,
620        _endpoint_id: &'static str,
621        body: Option<&Value>,
622        mode: CompatibilityMode,
623    ) -> Result<()> {
624        if mode != CompatibilityMode::Strict {
625            return Ok(());
626        }
627
628        let Some(body) = body else {
629            return Ok(());
630        };
631
632        if let Some(value) = body.get("n").and_then(Value::as_i64)
633            && value != 1
634        {
635            return Err(ProviderCompatibilityError::new(
636                ProviderKind::MiniMax,
637                "MiniMax 在严格模式下仅支持 n = 1",
638            )
639            .into());
640        }
641
642        if contains_key(body, "function_call") {
643            return Err(ProviderCompatibilityError::new(
644                ProviderKind::MiniMax,
645                "MiniMax 在严格模式下不再支持旧版 function_call 字段,请改用 tools",
646            )
647            .into());
648        }
649
650        if contains_any_type(body, &["input_image", "image", "input_audio", "audio"]) {
651            return Err(ProviderCompatibilityError::new(
652                ProviderKind::MiniMax,
653                "MiniMax 在严格模式下不支持图像或音频输入",
654            )
655            .into());
656        }
657
658        Ok(())
659    }
660}
661
662impl ProviderProfile for ZenMuxProfile {
663    fn kind(&self) -> ProviderKind {
664        ProviderKind::ZenMux
665    }
666
667    fn default_base_url(&self) -> &str {
668        "https://zenmux.ai/api/v1"
669    }
670
671    fn auth_scheme(&self) -> AuthScheme {
672        AuthScheme::Bearer
673    }
674
675    fn capabilities(&self) -> &'static CapabilitySet {
676        &FULL_CAPABILITIES
677    }
678
679    fn prepare_request(&self, _ctx: &mut RequestContext) -> Result<()> {
680        Ok(())
681    }
682
683    fn validate_request(
684        &self,
685        _endpoint_id: &'static str,
686        body: Option<&Value>,
687        mode: CompatibilityMode,
688    ) -> Result<()> {
689        if mode != CompatibilityMode::Strict {
690            return Ok(());
691        }
692
693        let Some(model) = body
694            .and_then(|value| value.get("model"))
695            .and_then(Value::as_str)
696        else {
697            return Ok(());
698        };
699
700        if !model.contains('/') || model.starts_with('/') || model.ends_with('/') {
701            return Err(ProviderCompatibilityError::new(
702                ProviderKind::ZenMux,
703                "ZenMux 在严格模式下要求 model 采用 <provider>/<model_name> 形式",
704            )
705            .into());
706        }
707
708        Ok(())
709    }
710}
711
712fn contains_key(value: &Value, target: &str) -> bool {
713    match value {
714        Value::Object(map) => {
715            map.contains_key(target) || map.values().any(|value| contains_key(value, target))
716        }
717        Value::Array(values) => values.iter().any(|value| contains_key(value, target)),
718        _ => false,
719    }
720}
721
722fn contains_any_type(value: &Value, targets: &[&str]) -> bool {
723    match value {
724        Value::Object(map) => map.iter().any(|(key, nested)| {
725            (key == "type"
726                && nested
727                    .as_str()
728                    .is_some_and(|value| targets.contains(&value)))
729                || contains_any_type(nested, targets)
730        }),
731        Value::Array(values) => values.iter().any(|value| contains_any_type(value, targets)),
732        _ => false,
733    }
734}
735
/// Reports whether `path` addresses an Azure OpenAI endpoint that is served under a deployment
/// (`/openai/deployments/{deployment}/...`).
///
/// The path may be given with or without the `/openai` prefix and with trailing slashes.
/// Batch routes are intentionally absent. Azure serves them at `/openai/batches`, not under a
/// deployment, so adding a deployment segment there would misroute the request.
fn azure_deployment_path_required(path: &str) -> bool {
    let path = path.trim_end_matches('/');
    // Only a literal `/openai` prefix is stripped; `/openaichat/...` leaves no leading `/`
    // and so matches nothing below.
    let path = path.strip_prefix("/openai").unwrap_or(path);

    matches!(
        path,
        "/completions"
            | "/chat/completions"
            | "/embeddings"
            | "/audio/transcriptions"
            | "/audio/translations"
            | "/audio/speech"
            | "/images/generations"
            | "/images/edits"
    )
}
759
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty request context for `endpoint_id` at `path`.
    fn context(endpoint_id: &'static str, path: &str) -> RequestContext {
        RequestContext {
            endpoint_id,
            path: path.to_owned(),
            query: BTreeMap::new(),
            headers: BTreeMap::new(),
            body: None,
        }
    }

    /// Runs strict-mode validation, which is expected to fail, and returns the error.
    fn strict_error(provider: &Provider, endpoint_id: &'static str, body: &Value) -> Error {
        provider
            .profile()
            .validate_request(endpoint_id, Some(body), CompatibilityMode::Strict)
            .unwrap_err()
    }

    #[test]
    fn test_should_use_kimi_default_base_url() {
        assert_eq!(Provider::kimi().default_base_url(), "https://api.moonshot.ai/v1");
    }

    #[test]
    fn test_should_use_deepseek_default_base_url() {
        assert_eq!(Provider::deepseek().default_base_url(), "https://api.deepseek.com");
    }

    #[test]
    fn test_should_use_zhipu_default_base_url() {
        assert_eq!(
            Provider::zhipu().default_base_url(),
            "https://open.bigmodel.cn/api/paas/v4"
        );
    }

    #[test]
    fn test_should_use_minimax_default_base_url() {
        assert_eq!(Provider::minimax().default_base_url(), "https://api.minimaxi.com/v1");
    }

    #[test]
    fn test_should_use_zenmux_default_base_url() {
        assert_eq!(Provider::zenmux().default_base_url(), "https://zenmux.ai/api/v1");
    }

    #[test]
    fn test_should_validate_minimax_n_equals_one_in_strict_mode() {
        let body = serde_json::json!({
            "model": "MiniMax-M2.7",
            "messages": [{"role": "user", "content": "hello"}],
            "n": 2
        });
        let error = strict_error(&Provider::minimax(), "chat.completions.create", &body);
        assert!(matches!(error, Error::ProviderCompatibility(_)));
    }

    #[test]
    fn test_should_validate_zenmux_model_id_format_in_strict_mode() {
        let body = serde_json::json!({
            "model": "gpt-5",
            "input": "hello"
        });
        let error = strict_error(&Provider::zenmux(), "responses.create", &body);
        assert!(matches!(error, Error::ProviderCompatibility(_)));
    }

    #[test]
    fn test_should_preserve_passthrough_mode_for_minimax() {
        let body = serde_json::json!({
            "model": "MiniMax-M2.7",
            "messages": [{"role": "user", "content": "hello"}],
            "n": 3
        });
        Provider::minimax()
            .profile()
            .validate_request(
                "chat.completions.create",
                Some(&body),
                CompatibilityMode::Passthrough,
            )
            .unwrap();
    }

    #[test]
    fn test_should_inject_azure_api_version_and_prefix_path() {
        let provider =
            Provider::azure_with_options(AzureOptions::new().api_version("2024-02-15-preview"));
        let mut ctx = context("responses.create", "/responses");

        provider.profile().prepare_request(&mut ctx).unwrap();

        assert_eq!(ctx.path, "/openai/responses");
        assert_eq!(
            ctx.query.get("api-version").map(String::as_str),
            Some("2024-02-15-preview")
        );
    }

    #[test]
    fn test_should_preserve_existing_azure_api_version_query() {
        let provider = Provider::azure();
        let mut ctx = context("responses.create", "/responses");
        ctx.query
            .insert("api-version".into(), "custom-version".into());

        provider.profile().prepare_request(&mut ctx).unwrap();

        assert_eq!(
            ctx.query.get("api-version").map(String::as_str),
            Some("custom-version")
        );
    }

    #[test]
    fn test_should_inject_azure_deployment_from_body_model() {
        let provider = Provider::azure();
        let mut ctx = context("chat.completions.create", "/chat/completions");
        ctx.body = Some(
            serde_json::json!({
                "model": "gpt-4o-deployment"
            })
            .into(),
        );

        provider.profile().prepare_request(&mut ctx).unwrap();

        assert_eq!(
            ctx.path,
            "/openai/deployments/gpt-4o-deployment/chat/completions"
        );
    }

    #[test]
    fn test_should_inject_azure_realtime_deployment_query() {
        let provider =
            Provider::azure_with_options(AzureOptions::new().deployment("rt-deployment"));
        let mut ctx = context("realtime.ws.connect", "/realtime");

        provider.profile().prepare_request(&mut ctx).unwrap();

        assert_eq!(ctx.path, "/openai/realtime");
        assert_eq!(
            ctx.query.get("deployment").map(String::as_str),
            Some("rt-deployment")
        );
    }

    #[test]
    fn test_should_switch_azure_auth_scheme_to_bearer() {
        let provider = Provider::azure_with_options(AzureOptions::new().bearer_auth());
        assert_eq!(provider.profile().auth_scheme(), AuthScheme::Bearer);
    }
}