Skip to main content

openai_core/
client.rs

1//! 客户端入口与构建器。
2
3use std::collections::BTreeMap;
4use std::env;
5use std::future::Future;
6use std::sync::Arc;
7use std::time::Duration;
8
9use secrecy::SecretString;
10use tracing::{debug, error, info, warn};
11
12use crate::auth::ApiKeySource;
13use crate::config::{ClientOptions, LogLevel, LogRecord, Logger, LoggerHandle};
14use crate::error::{Error, Result};
15use crate::pagination::{CursorPage, ListEnvelope};
16use crate::providers::{AzureOptions, CompatibilityMode, Provider, ProviderKind};
17use crate::resources::{
18    AdminResource, AudioResource, BatchesResource, BetaResource, ChatResource, CompletionsResource,
19    ContainersResource, ConversationsResource, EmbeddingsResource, EvalsResource, FilesResource,
20    FineTuningResource, GradersResource, ImagesResource, ModelsResource, ModerationsResource,
21    RealtimeResource, ResponsesResource, SkillsResource, UploadsResource, VectorStoresResource,
22    VideosResource, WebhooksResource,
23};
24use crate::transport::{
25    RequestSpec, execute_bytes, execute_json, execute_raw_http, execute_raw_sse, execute_sse,
26};
27use crate::{ApiResponse, RawSseStream, SseStream};
28
29/// `Client` 是对底层 HTTP 客户端的轻量封装。
30#[derive(Debug, Clone)]
31pub struct Client {
32    pub(crate) inner: Arc<ClientInner>,
33}
34
35/// 客户端内部共享状态。
36#[derive(Debug)]
37pub(crate) struct ClientInner {
38    pub(crate) http: reqwest::Client,
39    pub(crate) options: ClientOptions,
40    pub(crate) api_key_source: Option<ApiKeySource>,
41    pub(crate) admin_api_key_source: Option<ApiKeySource>,
42    pub(crate) provider: Provider,
43}
44
45/// 表示分页下一页请求所需的元信息。
46#[derive(Debug, Clone)]
47pub struct PageRequestSpec {
48    /// 发起下一页请求的客户端。
49    pub client: Client,
50    /// 端点 ID。
51    pub endpoint_id: &'static str,
52    /// HTTP 方法。
53    pub method: http::Method,
54    /// 请求路径。
55    pub path: String,
56    /// 查询参数。
57    pub query: BTreeMap<String, Option<String>>,
58}
59
60/// `Client` 的构建器。
61#[derive(Debug, Clone, Default)]
62pub struct ClientBuilder {
63    options: ClientOptions,
64    api_key_source: Option<ApiKeySource>,
65    admin_api_key_source: Option<ApiKeySource>,
66    azure_options: AzureOptions,
67    azure_endpoint: Option<String>,
68    azure_configured: bool,
69    http_client: Option<reqwest::Client>,
70}
71
72impl Client {
73    /// 创建客户端构建器。
74    pub fn builder() -> ClientBuilder {
75        ClientBuilder::from_env()
76    }
77
78    /// 返回当前客户端的 Provider。
79    pub fn provider(&self) -> &Provider {
80        &self.inner.provider
81    }
82
83    /// 返回当前客户端的基础地址。
84    pub fn base_url(&self) -> &str {
85        self.inner.base_url()
86    }
87
88    /// 使用闭包覆盖一部分客户端选项,并返回新客户端。
89    pub fn with_options<F>(&self, mutate: F) -> Self
90    where
91        F: FnOnce(&mut ClientOptions),
92    {
93        let mut options = self.inner.options.clone();
94        mutate(&mut options);
95        Self::from_parts(
96            self.inner.http.clone(),
97            options.provider.clone(),
98            self.inner.api_key_source.clone(),
99            self.inner.admin_api_key_source.clone(),
100            options,
101        )
102    }
103
104    pub(crate) fn from_parts(
105        http: reqwest::Client,
106        provider: Provider,
107        api_key_source: Option<ApiKeySource>,
108        admin_api_key_source: Option<ApiKeySource>,
109        mut options: ClientOptions,
110    ) -> Self {
111        options.provider = provider.clone();
112        Self {
113            inner: Arc::new(ClientInner {
114                http,
115                options,
116                api_key_source,
117                admin_api_key_source,
118                provider,
119            }),
120        }
121    }
122
123    pub(crate) async fn execute_json<T>(&self, spec: RequestSpec) -> Result<ApiResponse<T>>
124    where
125        T: serde::de::DeserializeOwned,
126    {
127        execute_json(&self.inner, spec).await
128    }
129
130    pub(crate) async fn execute_bytes(
131        &self,
132        spec: RequestSpec,
133    ) -> Result<ApiResponse<bytes::Bytes>> {
134        execute_bytes(&self.inner, spec).await
135    }
136
137    pub(crate) async fn execute_sse<T>(&self, spec: RequestSpec) -> Result<SseStream<T>>
138    where
139        T: serde::de::DeserializeOwned + Send + 'static,
140    {
141        execute_sse(&self.inner, spec).await
142    }
143
144    #[allow(dead_code)]
145    pub(crate) async fn execute_raw_sse(&self, spec: RequestSpec) -> Result<RawSseStream> {
146        execute_raw_sse(&self.inner, spec).await
147    }
148
149    pub(crate) async fn execute_raw_http(
150        &self,
151        spec: RequestSpec,
152    ) -> Result<http::Response<bytes::Bytes>> {
153        execute_raw_http(&self.inner, spec).await
154    }
155
156    pub(crate) async fn fetch_cursor_page<T>(&self, page: PageRequestSpec) -> Result<CursorPage<T>>
157    where
158        T: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
159    {
160        let method = page.method.clone();
161        let mut spec = RequestSpec::new(page.endpoint_id, method.clone(), page.path.clone());
162        spec.options.extra_query = page.query;
163
164        let response = self.execute_json::<ListEnvelope<T>>(spec).await?;
165        let ListEnvelope {
166            object,
167            data,
168            first_id,
169            last_id,
170            has_more,
171            extra,
172        } = response.data;
173        let next_query = last_id
174            .as_ref()
175            .map(|last_id| {
176                let mut query = BTreeMap::new();
177                query.insert("after".into(), Some(last_id.clone()));
178                query
179            })
180            .unwrap_or_default();
181        let page_value = CursorPage::from(ListEnvelope {
182            object,
183            data,
184            first_id,
185            last_id,
186            has_more,
187            extra,
188        });
189        Ok(page_value.with_next_request(if has_more {
190            Some(PageRequestSpec {
191                client: self.clone(),
192                endpoint_id: page.endpoint_id,
193                method,
194                path: page.path,
195                query: next_query,
196            })
197        } else {
198            None
199        }))
200    }
201
202    /// 返回顶层 completions 资源。
203    pub fn completions(&self) -> CompletionsResource {
204        CompletionsResource::new(self.clone())
205    }
206
207    /// 返回 admin 命名空间。
208    pub fn admin(&self) -> AdminResource {
209        AdminResource::new(self.clone())
210    }
211
212    /// 返回 chat 命名空间。
213    pub fn chat(&self) -> ChatResource {
214        ChatResource::new(self.clone())
215    }
216
217    /// 返回 embeddings 资源。
218    pub fn embeddings(&self) -> EmbeddingsResource {
219        EmbeddingsResource::new(self.clone())
220    }
221
222    /// 返回 files 资源。
223    pub fn files(&self) -> FilesResource {
224        FilesResource::new(self.clone())
225    }
226
227    /// 返回 images 资源。
228    pub fn images(&self) -> ImagesResource {
229        ImagesResource::new(self.clone())
230    }
231
232    /// 返回 audio 命名空间。
233    pub fn audio(&self) -> AudioResource {
234        AudioResource::new(self.clone())
235    }
236
237    /// 返回 moderations 资源。
238    pub fn moderations(&self) -> ModerationsResource {
239        ModerationsResource::new(self.clone())
240    }
241
242    /// 返回 models 资源。
243    pub fn models(&self) -> ModelsResource {
244        ModelsResource::new(self.clone())
245    }
246
247    /// 返回 fine_tuning 命名空间。
248    pub fn fine_tuning(&self) -> FineTuningResource {
249        FineTuningResource::new(self.clone())
250    }
251
252    /// 返回 graders 命名空间。
253    pub fn graders(&self) -> GradersResource {
254        GradersResource::new(self.clone())
255    }
256
257    /// 返回 vector_stores 资源。
258    pub fn vector_stores(&self) -> VectorStoresResource {
259        VectorStoresResource::new(self.clone())
260    }
261
262    /// 返回 webhooks 资源。
263    pub fn webhooks(&self) -> WebhooksResource {
264        WebhooksResource::new(self.clone())
265    }
266
267    /// 返回 batches 资源。
268    pub fn batches(&self) -> BatchesResource {
269        BatchesResource::new(self.clone())
270    }
271
272    /// 返回 uploads 资源。
273    pub fn uploads(&self) -> UploadsResource {
274        UploadsResource::new(self.clone())
275    }
276
277    /// 返回 responses 资源。
278    pub fn responses(&self) -> ResponsesResource {
279        ResponsesResource::new(self.clone())
280    }
281
282    /// 返回 realtime 资源。
283    pub fn realtime(&self) -> RealtimeResource {
284        RealtimeResource::new(self.clone())
285    }
286
287    /// 返回 conversations 资源。
288    pub fn conversations(&self) -> ConversationsResource {
289        ConversationsResource::new(self.clone())
290    }
291
292    /// 返回 evals 资源。
293    pub fn evals(&self) -> EvalsResource {
294        EvalsResource::new(self.clone())
295    }
296
297    /// 返回 containers 资源。
298    pub fn containers(&self) -> ContainersResource {
299        ContainersResource::new(self.clone())
300    }
301
302    /// 返回 skills 资源。
303    pub fn skills(&self) -> SkillsResource {
304        SkillsResource::new(self.clone())
305    }
306
307    /// 返回 videos 资源。
308    pub fn videos(&self) -> VideosResource {
309        VideosResource::new(self.clone())
310    }
311
312    /// 返回 beta 命名空间。
313    pub fn beta(&self) -> BetaResource {
314        BetaResource::new(self.clone())
315    }
316}
317
318impl ClientInner {
319    pub(crate) fn base_url(&self) -> &str {
320        self.options
321            .base_url
322            .as_deref()
323            .unwrap_or_else(|| self.provider.default_base_url())
324    }
325
326    pub(crate) fn log(
327        &self,
328        level: LogLevel,
329        target: &'static str,
330        message: impl Into<String>,
331        fields: BTreeMap<String, String>,
332    ) {
333        if !self.options.log_level.allows(level) {
334            return;
335        }
336
337        let record = LogRecord {
338            level,
339            target,
340            message: message.into(),
341            fields,
342        };
343
344        if let Some(logger) = &self.options.logger {
345            logger.log(&record);
346        }
347
348        let rendered_fields = if record.fields.is_empty() {
349            String::new()
350        } else {
351            format!(
352                " {}",
353                record
354                    .fields
355                    .iter()
356                    .map(|(key, value)| format!("{key}={value}"))
357                    .collect::<Vec<_>>()
358                    .join(" ")
359            )
360        };
361        let rendered = format!("[{}] {}{}", target, record.message, rendered_fields);
362        match level {
363            LogLevel::Off => {}
364            LogLevel::Error => error!("{rendered}"),
365            LogLevel::Warn => warn!("{rendered}"),
366            LogLevel::Info => info!("{rendered}"),
367            LogLevel::Debug => debug!("{rendered}"),
368        }
369    }
370}
371
372impl ClientBuilder {
373    /// 从环境变量构建默认配置。
374    pub fn from_env() -> Self {
375        let mut builder = Self::default();
376
377        if let Some(webhook_secret) = read_env("OPENAI_WEBHOOK_SECRET") {
378            builder.options.webhook_secret = Some(SecretString::new(webhook_secret.into()));
379        }
380        if let Some(log_level) =
381            read_env("OPENAI_LOG").and_then(|value| value.parse::<LogLevel>().ok())
382        {
383            builder.options.log_level = log_level;
384        }
385
386        if let Some(azure_endpoint) = read_env("AZURE_OPENAI_ENDPOINT") {
387            builder = builder.azure_endpoint(azure_endpoint);
388            if let Some(api_version) = read_env("OPENAI_API_VERSION") {
389                builder = builder.azure_api_version(api_version);
390            }
391            if let Some(api_key) = read_env("AZURE_OPENAI_API_KEY") {
392                builder = builder.api_key(api_key);
393            }
394            return builder;
395        }
396
397        if let Some(base_url) = read_env("OPENAI_BASE_URL") {
398            builder.options.base_url = Some(base_url);
399        }
400        if let Some(api_key) = read_env("OPENAI_API_KEY") {
401            builder.api_key_source = Some(ApiKeySource::from_static(api_key));
402        }
403        if let Some(admin_api_key) = read_env("OPENAI_ADMIN_KEY") {
404            builder.admin_api_key_source = Some(ApiKeySource::from_static(admin_api_key));
405        }
406
407        builder
408    }
409
410    /// 设置 Provider。
411    pub fn provider(mut self, provider: Provider) -> Self {
412        if provider.kind() != ProviderKind::Azure {
413            self.azure_options = AzureOptions::default();
414            self.azure_endpoint = None;
415            self.azure_configured = false;
416        }
417        self.options.provider = provider;
418        self
419    }
420
421    /// 注入一个自定义 `reqwest::Client`。
422    pub fn http_client(mut self, client: reqwest::Client) -> Self {
423        self.http_client = Some(client);
424        self
425    }
426
427    /// 设置 SDK 内部日志级别。
428    pub fn log_level(mut self, log_level: LogLevel) -> Self {
429        self.options.log_level = log_level;
430        self
431    }
432
433    /// 注入一个用户自定义日志器。
434    pub fn logger<L>(mut self, logger: L) -> Self
435    where
436        L: Logger + 'static,
437    {
438        self.options.logger = Some(LoggerHandle::new(logger));
439        self
440    }
441
442    /// 设置静态 API Key。
443    pub fn api_key<T>(mut self, api_key: T) -> Self
444    where
445        T: Into<String>,
446    {
447        self.api_key_source = Some(ApiKeySource::from_static(api_key));
448        self
449    }
450
451    /// 设置动态 API Key 回调。
452    pub fn api_key_provider<F>(mut self, provider: F) -> Self
453    where
454        F: Fn() -> Result<SecretString> + Send + Sync + 'static,
455    {
456        self.api_key_source = Some(ApiKeySource::from_provider(provider));
457        self
458    }
459
460    /// 设置异步 API Key 回调。
461    pub fn api_key_async_provider<F, Fut>(mut self, provider: F) -> Self
462    where
463        F: Fn() -> Fut + Send + Sync + 'static,
464        Fut: Future<Output = Result<SecretString>> + Send + 'static,
465    {
466        self.api_key_source = Some(ApiKeySource::from_async_provider(provider));
467        self
468    }
469
470    /// 设置静态 Admin API Key。
471    pub fn admin_api_key<T>(mut self, admin_api_key: T) -> Self
472    where
473        T: Into<String>,
474    {
475        self.admin_api_key_source = Some(ApiKeySource::from_static(admin_api_key));
476        self
477    }
478
479    /// 设置动态 Admin API Key 回调。
480    pub fn admin_api_key_provider<F>(mut self, provider: F) -> Self
481    where
482        F: Fn() -> Result<SecretString> + Send + Sync + 'static,
483    {
484        self.admin_api_key_source = Some(ApiKeySource::from_provider(provider));
485        self
486    }
487
488    /// 设置异步 Admin API Key 回调。
489    pub fn admin_api_key_async_provider<F, Fut>(mut self, provider: F) -> Self
490    where
491        F: Fn() -> Fut + Send + Sync + 'static,
492        Fut: Future<Output = Result<SecretString>> + Send + 'static,
493    {
494        self.admin_api_key_source = Some(ApiKeySource::from_async_provider(provider));
495        self
496    }
497
498    /// 覆盖基础地址。
499    pub fn base_url<T>(mut self, base_url: T) -> Self
500    where
501        T: Into<String>,
502    {
503        self.options.base_url = Some(base_url.into());
504        self
505    }
506
507    /// 控制当 `base_url` 指向本机地址时,是否显式关闭系统代理。
508    ///
509    /// 该开关默认关闭。
510    pub fn disable_proxy_for_local_base_url(mut self, disable: bool) -> Self {
511        self.options.disable_proxy_for_local_base_url = disable;
512        self
513    }
514
515    /// 设置 Azure 资源级 endpoint。
516    ///
517    /// 该值应类似 `https://example-resource.openai.azure.com`,
518    /// SDK 会在发送请求时自动补上 `/openai`。
519    pub fn azure_endpoint<T>(mut self, endpoint: T) -> Self
520    where
521        T: Into<String>,
522    {
523        self.azure_endpoint = Some(endpoint.into());
524        self.azure_configured = true;
525        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
526        self
527    }
528
529    /// 设置 Azure `api-version`。
530    pub fn azure_api_version<T>(mut self, api_version: T) -> Self
531    where
532        T: Into<String>,
533    {
534        self.azure_options.api_version = Some(api_version.into());
535        self.azure_configured = true;
536        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
537        self
538    }
539
540    /// 设置 Azure 默认 deployment。
541    pub fn azure_deployment<T>(mut self, deployment: T) -> Self
542    where
543        T: Into<String>,
544    {
545        self.azure_options.deployment = Some(deployment.into());
546        self.azure_configured = true;
547        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
548        self
549    }
550
551    /// 切换 Azure 为 Bearer Token 认证。
552    pub fn azure_bearer_auth(mut self) -> Self {
553        self.azure_options = self.azure_options.bearer_auth();
554        self.azure_configured = true;
555        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
556        self
557    }
558
559    /// 设置 Azure AD Bearer Token。
560    pub fn azure_ad_token<T>(mut self, token: T) -> Self
561    where
562        T: Into<String>,
563    {
564        self.azure_options = self.azure_options.bearer_auth();
565        self.azure_configured = true;
566        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
567        self.api_key_source = Some(ApiKeySource::from_static(token));
568        self
569    }
570
571    /// 设置 Azure AD Bearer Token 异步提供器。
572    pub fn azure_ad_token_provider<F, Fut>(mut self, provider: F) -> Self
573    where
574        F: Fn() -> Fut + Send + Sync + 'static,
575        Fut: Future<Output = Result<SecretString>> + Send + 'static,
576    {
577        self.azure_options = self.azure_options.bearer_auth();
578        self.azure_configured = true;
579        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
580        self.api_key_source = Some(ApiKeySource::from_async_provider(provider));
581        self
582    }
583
584    /// 覆盖默认超时时间。
585    pub fn timeout(mut self, timeout: Duration) -> Self {
586        self.options.timeout = timeout;
587        self
588    }
589
590    /// 覆盖默认最大重试次数。
591    pub fn max_retries(mut self, max_retries: u32) -> Self {
592        self.options.max_retries = max_retries;
593        self
594    }
595
596    /// 添加默认请求头。
597    pub fn default_header<T, U>(mut self, key: T, value: U) -> Self
598    where
599        T: Into<String>,
600        U: Into<String>,
601    {
602        self.options
603            .default_headers
604            .insert(key.into(), value.into());
605        self
606    }
607
608    /// 批量设置默认请求头。
609    pub fn default_headers(mut self, headers: BTreeMap<String, String>) -> Self {
610        self.options.default_headers = headers;
611        self
612    }
613
614    /// 添加默认查询参数。
615    pub fn default_query<T, U>(mut self, key: T, value: U) -> Self
616    where
617        T: Into<String>,
618        U: Into<String>,
619    {
620        self.options.default_query.insert(key.into(), value.into());
621        self
622    }
623
624    /// 批量设置默认查询参数。
625    pub fn default_query_map(mut self, query: BTreeMap<String, String>) -> Self {
626        self.options.default_query = query;
627        self
628    }
629
630    /// 设置 Webhook 密钥。
631    pub fn webhook_secret<T>(mut self, secret: T) -> Self
632    where
633        T: Into<String>,
634    {
635        self.options.webhook_secret = Some(SecretString::new(secret.into().into()));
636        self
637    }
638
639    /// 设置兼容性模式。
640    pub fn compatibility_mode(mut self, mode: CompatibilityMode) -> Self {
641        self.options.compatibility_mode = mode;
642        self
643    }
644
645    /// 构建客户端。
646    ///
647    /// # Errors
648    ///
649    /// 当基础地址非法或底层 `reqwest::Client` 初始化失败时返回错误。
650    pub fn build(self) -> Result<Client> {
651        let mut options = self.options;
652        if options.provider.kind() == ProviderKind::Azure
653            && (self.azure_configured || self.azure_endpoint.is_some())
654        {
655            options.provider = Provider::azure_with_options(self.azure_options.clone());
656            if let Some(endpoint) = self.azure_endpoint {
657                if options.base_url.is_some() {
658                    return Err(Error::InvalidConfig(
659                        "`base_url` 和 `azure_endpoint` 不能同时设置".into(),
660                    ));
661                }
662                options.base_url = Some(endpoint);
663            }
664        }
665
666        let http = if let Some(client) = self.http_client {
667            client
668        } else {
669            let mut default_headers = reqwest::header::HeaderMap::new();
670            default_headers.insert(
671                reqwest::header::USER_AGENT,
672                reqwest::header::HeaderValue::from_static(concat!(
673                    env!("CARGO_PKG_NAME"),
674                    "/",
675                    env!("CARGO_PKG_VERSION")
676                )),
677            );
678
679            let mut builder = reqwest::Client::builder().default_headers(default_headers);
680            if options.disable_proxy_for_local_base_url
681                && should_disable_proxy_for_base_url(options.base_url.as_deref())
682            {
683                builder = builder.no_proxy();
684            }
685
686            builder
687                .build()
688                .map_err(|error| Error::InvalidConfig(format!("创建 HTTP 客户端失败: {error}")))?
689        };
690        Ok(Client::from_parts(
691            http,
692            options.provider.clone(),
693            self.api_key_source,
694            self.admin_api_key_source,
695            options,
696        ))
697    }
698}
699
/// Reads environment variable `key` with surrounding whitespace trimmed.
///
/// Returns `None` when the variable is unset, is not valid Unicode, or is empty after trimming.
fn read_env(key: &str) -> Option<String> {
    let raw = env::var(key).ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
706
707fn should_disable_proxy_for_base_url(base_url: Option<&str>) -> bool {
708    let Some(base_url) = base_url else {
709        return false;
710    };
711
712    let Ok(url) = url::Url::parse(base_url) else {
713        return false;
714    };
715
716    matches!(
717        url.host_str(),
718        Some("localhost") | Some("127.0.0.1") | Some("[::1]") | Some("::1")
719    )
720}