Skip to main content

openai_core/
client.rs

1//! 客户端入口与构建器。
2
3use std::collections::BTreeMap;
4use std::env;
5use std::future::Future;
6use std::sync::Arc;
7use std::time::Duration;
8
9use secrecy::SecretString;
10use tracing::{debug, error, info, warn};
11
12use crate::auth::ApiKeySource;
13use crate::config::{ClientOptions, LogLevel, LogRecord, Logger, LoggerHandle};
14use crate::error::{Error, Result};
15use crate::pagination::{CursorPage, ListEnvelope};
16use crate::providers::{AzureOptions, CompatibilityMode, Provider, ProviderKind};
17use crate::resources::{
18    AudioResource, BatchesResource, BetaResource, ChatResource, CompletionsResource,
19    ContainersResource, ConversationsResource, EmbeddingsResource, EvalsResource, FilesResource,
20    FineTuningResource, GradersResource, ImagesResource, ModelsResource, ModerationsResource,
21    RealtimeResource, ResponsesResource, SkillsResource, UploadsResource, VectorStoresResource,
22    VideosResource, WebhooksResource,
23};
24use crate::transport::{
25    RequestSpec, execute_bytes, execute_json, execute_raw_http, execute_raw_sse, execute_sse,
26};
27use crate::{ApiResponse, RawSseStream, SseStream};
28
/// Entry point of the SDK: a lightweight wrapper around the underlying HTTP client.
///
/// All state lives behind an `Arc<ClientInner>`, so cloning a `Client` is cheap and
/// every clone shares the same HTTP client, options, credential source and provider.
#[derive(Debug, Clone)]
pub struct Client {
    /// Shared state; cloned handles point at the same allocation.
    pub(crate) inner: Arc<ClientInner>,
}
34
/// Internal state shared by every clone of a [`Client`].
#[derive(Debug)]
pub(crate) struct ClientInner {
    /// HTTP client used to send requests.
    pub(crate) http: reqwest::Client,
    /// Effective client options; `Client::from_parts` keeps `options.provider`
    /// equal to `provider` below.
    pub(crate) options: ClientOptions,
    /// Source of the API key / bearer token; `None` when no credential was configured.
    pub(crate) api_key_source: Option<ApiKeySource>,
    /// Provider the client targets; also supplies the default base URL when
    /// `options.base_url` is not set.
    pub(crate) provider: Provider,
}
43
/// Everything needed to issue the request for the next page of a cursor-paginated list.
#[derive(Debug, Clone)]
pub struct PageRequestSpec {
    /// Client that issues the next-page request.
    pub client: Client,
    /// Endpoint identifier passed through to the request spec.
    pub endpoint_id: &'static str,
    /// HTTP method.
    pub method: http::Method,
    /// Request path.
    pub path: String,
    /// Query parameters sent with the request (for next-page requests this carries
    /// the `after` cursor).
    // NOTE(review): a `None` value presumably means "omit this key" — confirm in `transport`.
    pub query: BTreeMap<String, Option<String>>,
}
58
59/// `Client` 的构建器。
60#[derive(Debug, Clone, Default)]
61pub struct ClientBuilder {
62    options: ClientOptions,
63    api_key_source: Option<ApiKeySource>,
64    azure_options: AzureOptions,
65    azure_endpoint: Option<String>,
66    azure_configured: bool,
67    http_client: Option<reqwest::Client>,
68}
69
70impl Client {
71    /// 创建客户端构建器。
72    pub fn builder() -> ClientBuilder {
73        ClientBuilder::from_env()
74    }
75
76    /// 返回当前客户端的 Provider。
77    pub fn provider(&self) -> &Provider {
78        &self.inner.provider
79    }
80
81    /// 返回当前客户端的基础地址。
82    pub fn base_url(&self) -> &str {
83        self.inner.base_url()
84    }
85
86    /// 使用闭包覆盖一部分客户端选项,并返回新客户端。
87    pub fn with_options<F>(&self, mutate: F) -> Self
88    where
89        F: FnOnce(&mut ClientOptions),
90    {
91        let mut options = self.inner.options.clone();
92        mutate(&mut options);
93        Self::from_parts(
94            self.inner.http.clone(),
95            options.provider.clone(),
96            self.inner.api_key_source.clone(),
97            options,
98        )
99    }
100
101    pub(crate) fn from_parts(
102        http: reqwest::Client,
103        provider: Provider,
104        api_key_source: Option<ApiKeySource>,
105        mut options: ClientOptions,
106    ) -> Self {
107        options.provider = provider.clone();
108        Self {
109            inner: Arc::new(ClientInner {
110                http,
111                options,
112                api_key_source,
113                provider,
114            }),
115        }
116    }
117
118    pub(crate) async fn execute_json<T>(&self, spec: RequestSpec) -> Result<ApiResponse<T>>
119    where
120        T: serde::de::DeserializeOwned,
121    {
122        execute_json(&self.inner, spec).await
123    }
124
125    pub(crate) async fn execute_bytes(
126        &self,
127        spec: RequestSpec,
128    ) -> Result<ApiResponse<bytes::Bytes>> {
129        execute_bytes(&self.inner, spec).await
130    }
131
132    pub(crate) async fn execute_sse<T>(&self, spec: RequestSpec) -> Result<SseStream<T>>
133    where
134        T: serde::de::DeserializeOwned + Send + 'static,
135    {
136        execute_sse(&self.inner, spec).await
137    }
138
139    #[allow(dead_code)]
140    pub(crate) async fn execute_raw_sse(&self, spec: RequestSpec) -> Result<RawSseStream> {
141        execute_raw_sse(&self.inner, spec).await
142    }
143
144    pub(crate) async fn execute_raw_http(
145        &self,
146        spec: RequestSpec,
147    ) -> Result<http::Response<bytes::Bytes>> {
148        execute_raw_http(&self.inner, spec).await
149    }
150
151    pub(crate) async fn fetch_cursor_page<T>(&self, page: PageRequestSpec) -> Result<CursorPage<T>>
152    where
153        T: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
154    {
155        let method = page.method.clone();
156        let mut spec = RequestSpec::new(page.endpoint_id, method.clone(), page.path.clone());
157        spec.options.extra_query = page.query;
158
159        let response = self.execute_json::<ListEnvelope<T>>(spec).await?;
160        let ListEnvelope {
161            object,
162            data,
163            first_id,
164            last_id,
165            has_more,
166            extra,
167        } = response.data;
168        let next_query = last_id
169            .as_ref()
170            .map(|last_id| {
171                let mut query = BTreeMap::new();
172                query.insert("after".into(), Some(last_id.clone()));
173                query
174            })
175            .unwrap_or_default();
176        let page_value = CursorPage::from(ListEnvelope {
177            object,
178            data,
179            first_id,
180            last_id,
181            has_more,
182            extra,
183        });
184        Ok(page_value.with_next_request(if has_more {
185            Some(PageRequestSpec {
186                client: self.clone(),
187                endpoint_id: page.endpoint_id,
188                method,
189                path: page.path,
190                query: next_query,
191            })
192        } else {
193            None
194        }))
195    }
196
197    /// 返回顶层 completions 资源。
198    pub fn completions(&self) -> CompletionsResource {
199        CompletionsResource::new(self.clone())
200    }
201
202    /// 返回 chat 命名空间。
203    pub fn chat(&self) -> ChatResource {
204        ChatResource::new(self.clone())
205    }
206
207    /// 返回 embeddings 资源。
208    pub fn embeddings(&self) -> EmbeddingsResource {
209        EmbeddingsResource::new(self.clone())
210    }
211
212    /// 返回 files 资源。
213    pub fn files(&self) -> FilesResource {
214        FilesResource::new(self.clone())
215    }
216
217    /// 返回 images 资源。
218    pub fn images(&self) -> ImagesResource {
219        ImagesResource::new(self.clone())
220    }
221
222    /// 返回 audio 命名空间。
223    pub fn audio(&self) -> AudioResource {
224        AudioResource::new(self.clone())
225    }
226
227    /// 返回 moderations 资源。
228    pub fn moderations(&self) -> ModerationsResource {
229        ModerationsResource::new(self.clone())
230    }
231
232    /// 返回 models 资源。
233    pub fn models(&self) -> ModelsResource {
234        ModelsResource::new(self.clone())
235    }
236
237    /// 返回 fine_tuning 命名空间。
238    pub fn fine_tuning(&self) -> FineTuningResource {
239        FineTuningResource::new(self.clone())
240    }
241
242    /// 返回 graders 命名空间。
243    pub fn graders(&self) -> GradersResource {
244        GradersResource::new(self.clone())
245    }
246
247    /// 返回 vector_stores 资源。
248    pub fn vector_stores(&self) -> VectorStoresResource {
249        VectorStoresResource::new(self.clone())
250    }
251
252    /// 返回 webhooks 资源。
253    pub fn webhooks(&self) -> WebhooksResource {
254        WebhooksResource::new(self.clone())
255    }
256
257    /// 返回 batches 资源。
258    pub fn batches(&self) -> BatchesResource {
259        BatchesResource::new(self.clone())
260    }
261
262    /// 返回 uploads 资源。
263    pub fn uploads(&self) -> UploadsResource {
264        UploadsResource::new(self.clone())
265    }
266
267    /// 返回 responses 资源。
268    pub fn responses(&self) -> ResponsesResource {
269        ResponsesResource::new(self.clone())
270    }
271
272    /// 返回 realtime 资源。
273    pub fn realtime(&self) -> RealtimeResource {
274        RealtimeResource::new(self.clone())
275    }
276
277    /// 返回 conversations 资源。
278    pub fn conversations(&self) -> ConversationsResource {
279        ConversationsResource::new(self.clone())
280    }
281
282    /// 返回 evals 资源。
283    pub fn evals(&self) -> EvalsResource {
284        EvalsResource::new(self.clone())
285    }
286
287    /// 返回 containers 资源。
288    pub fn containers(&self) -> ContainersResource {
289        ContainersResource::new(self.clone())
290    }
291
292    /// 返回 skills 资源。
293    pub fn skills(&self) -> SkillsResource {
294        SkillsResource::new(self.clone())
295    }
296
297    /// 返回 videos 资源。
298    pub fn videos(&self) -> VideosResource {
299        VideosResource::new(self.clone())
300    }
301
302    /// 返回 beta 命名空间。
303    pub fn beta(&self) -> BetaResource {
304        BetaResource::new(self.clone())
305    }
306}
307
308impl ClientInner {
309    pub(crate) fn base_url(&self) -> &str {
310        self.options
311            .base_url
312            .as_deref()
313            .unwrap_or_else(|| self.provider.default_base_url())
314    }
315
316    pub(crate) fn log(
317        &self,
318        level: LogLevel,
319        target: &'static str,
320        message: impl Into<String>,
321        fields: BTreeMap<String, String>,
322    ) {
323        if !self.options.log_level.allows(level) {
324            return;
325        }
326
327        let record = LogRecord {
328            level,
329            target,
330            message: message.into(),
331            fields,
332        };
333
334        if let Some(logger) = &self.options.logger {
335            logger.log(&record);
336        }
337
338        let rendered_fields = if record.fields.is_empty() {
339            String::new()
340        } else {
341            format!(
342                " {}",
343                record
344                    .fields
345                    .iter()
346                    .map(|(key, value)| format!("{key}={value}"))
347                    .collect::<Vec<_>>()
348                    .join(" ")
349            )
350        };
351        let rendered = format!("[{}] {}{}", target, record.message, rendered_fields);
352        match level {
353            LogLevel::Off => {}
354            LogLevel::Error => error!("{rendered}"),
355            LogLevel::Warn => warn!("{rendered}"),
356            LogLevel::Info => info!("{rendered}"),
357            LogLevel::Debug => debug!("{rendered}"),
358        }
359    }
360}
361
362impl ClientBuilder {
363    /// 从环境变量构建默认配置。
364    pub fn from_env() -> Self {
365        let mut builder = Self::default();
366
367        if let Some(webhook_secret) = read_env("OPENAI_WEBHOOK_SECRET") {
368            builder.options.webhook_secret = Some(SecretString::new(webhook_secret.into()));
369        }
370        if let Some(log_level) =
371            read_env("OPENAI_LOG").and_then(|value| value.parse::<LogLevel>().ok())
372        {
373            builder.options.log_level = log_level;
374        }
375
376        if let Some(azure_endpoint) = read_env("AZURE_OPENAI_ENDPOINT") {
377            builder = builder.azure_endpoint(azure_endpoint);
378            if let Some(api_version) = read_env("OPENAI_API_VERSION") {
379                builder = builder.azure_api_version(api_version);
380            }
381            if let Some(api_key) = read_env("AZURE_OPENAI_API_KEY") {
382                builder = builder.api_key(api_key);
383            }
384            return builder;
385        }
386
387        if let Some(base_url) = read_env("OPENAI_BASE_URL") {
388            builder.options.base_url = Some(base_url);
389        }
390        if let Some(api_key) = read_env("OPENAI_API_KEY") {
391            builder.api_key_source = Some(ApiKeySource::from_static(api_key));
392        }
393
394        builder
395    }
396
397    /// 设置 Provider。
398    pub fn provider(mut self, provider: Provider) -> Self {
399        if provider.kind() != ProviderKind::Azure {
400            self.azure_options = AzureOptions::default();
401            self.azure_endpoint = None;
402            self.azure_configured = false;
403        }
404        self.options.provider = provider;
405        self
406    }
407
408    /// 注入一个自定义 `reqwest::Client`。
409    pub fn http_client(mut self, client: reqwest::Client) -> Self {
410        self.http_client = Some(client);
411        self
412    }
413
414    /// 设置 SDK 内部日志级别。
415    pub fn log_level(mut self, log_level: LogLevel) -> Self {
416        self.options.log_level = log_level;
417        self
418    }
419
420    /// 注入一个用户自定义日志器。
421    pub fn logger<L>(mut self, logger: L) -> Self
422    where
423        L: Logger + 'static,
424    {
425        self.options.logger = Some(LoggerHandle::new(logger));
426        self
427    }
428
429    /// 设置静态 API Key。
430    pub fn api_key<T>(mut self, api_key: T) -> Self
431    where
432        T: Into<String>,
433    {
434        self.api_key_source = Some(ApiKeySource::from_static(api_key));
435        self
436    }
437
438    /// 设置动态 API Key 回调。
439    pub fn api_key_provider<F>(mut self, provider: F) -> Self
440    where
441        F: Fn() -> Result<SecretString> + Send + Sync + 'static,
442    {
443        self.api_key_source = Some(ApiKeySource::from_provider(provider));
444        self
445    }
446
447    /// 设置异步 API Key 回调。
448    pub fn api_key_async_provider<F, Fut>(mut self, provider: F) -> Self
449    where
450        F: Fn() -> Fut + Send + Sync + 'static,
451        Fut: Future<Output = Result<SecretString>> + Send + 'static,
452    {
453        self.api_key_source = Some(ApiKeySource::from_async_provider(provider));
454        self
455    }
456
457    /// 覆盖基础地址。
458    pub fn base_url<T>(mut self, base_url: T) -> Self
459    where
460        T: Into<String>,
461    {
462        self.options.base_url = Some(base_url.into());
463        self
464    }
465
466    /// 控制当 `base_url` 指向本机地址时,是否显式关闭系统代理。
467    ///
468    /// 该开关默认关闭。
469    pub fn disable_proxy_for_local_base_url(mut self, disable: bool) -> Self {
470        self.options.disable_proxy_for_local_base_url = disable;
471        self
472    }
473
474    /// 设置 Azure 资源级 endpoint。
475    ///
476    /// 该值应类似 `https://example-resource.openai.azure.com`,
477    /// SDK 会在发送请求时自动补上 `/openai`。
478    pub fn azure_endpoint<T>(mut self, endpoint: T) -> Self
479    where
480        T: Into<String>,
481    {
482        self.azure_endpoint = Some(endpoint.into());
483        self.azure_configured = true;
484        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
485        self
486    }
487
488    /// 设置 Azure `api-version`。
489    pub fn azure_api_version<T>(mut self, api_version: T) -> Self
490    where
491        T: Into<String>,
492    {
493        self.azure_options.api_version = Some(api_version.into());
494        self.azure_configured = true;
495        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
496        self
497    }
498
499    /// 设置 Azure 默认 deployment。
500    pub fn azure_deployment<T>(mut self, deployment: T) -> Self
501    where
502        T: Into<String>,
503    {
504        self.azure_options.deployment = Some(deployment.into());
505        self.azure_configured = true;
506        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
507        self
508    }
509
510    /// 切换 Azure 为 Bearer Token 认证。
511    pub fn azure_bearer_auth(mut self) -> Self {
512        self.azure_options = self.azure_options.bearer_auth();
513        self.azure_configured = true;
514        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
515        self
516    }
517
518    /// 设置 Azure AD Bearer Token。
519    pub fn azure_ad_token<T>(mut self, token: T) -> Self
520    where
521        T: Into<String>,
522    {
523        self.azure_options = self.azure_options.bearer_auth();
524        self.azure_configured = true;
525        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
526        self.api_key_source = Some(ApiKeySource::from_static(token));
527        self
528    }
529
530    /// 设置 Azure AD Bearer Token 异步提供器。
531    pub fn azure_ad_token_provider<F, Fut>(mut self, provider: F) -> Self
532    where
533        F: Fn() -> Fut + Send + Sync + 'static,
534        Fut: Future<Output = Result<SecretString>> + Send + 'static,
535    {
536        self.azure_options = self.azure_options.bearer_auth();
537        self.azure_configured = true;
538        self.options.provider = Provider::azure_with_options(self.azure_options.clone());
539        self.api_key_source = Some(ApiKeySource::from_async_provider(provider));
540        self
541    }
542
543    /// 覆盖默认超时时间。
544    pub fn timeout(mut self, timeout: Duration) -> Self {
545        self.options.timeout = timeout;
546        self
547    }
548
549    /// 覆盖默认最大重试次数。
550    pub fn max_retries(mut self, max_retries: u32) -> Self {
551        self.options.max_retries = max_retries;
552        self
553    }
554
555    /// 添加默认请求头。
556    pub fn default_header<T, U>(mut self, key: T, value: U) -> Self
557    where
558        T: Into<String>,
559        U: Into<String>,
560    {
561        self.options
562            .default_headers
563            .insert(key.into(), value.into());
564        self
565    }
566
567    /// 批量设置默认请求头。
568    pub fn default_headers(mut self, headers: BTreeMap<String, String>) -> Self {
569        self.options.default_headers = headers;
570        self
571    }
572
573    /// 添加默认查询参数。
574    pub fn default_query<T, U>(mut self, key: T, value: U) -> Self
575    where
576        T: Into<String>,
577        U: Into<String>,
578    {
579        self.options.default_query.insert(key.into(), value.into());
580        self
581    }
582
583    /// 批量设置默认查询参数。
584    pub fn default_query_map(mut self, query: BTreeMap<String, String>) -> Self {
585        self.options.default_query = query;
586        self
587    }
588
589    /// 设置 Webhook 密钥。
590    pub fn webhook_secret<T>(mut self, secret: T) -> Self
591    where
592        T: Into<String>,
593    {
594        self.options.webhook_secret = Some(SecretString::new(secret.into().into()));
595        self
596    }
597
598    /// 设置兼容性模式。
599    pub fn compatibility_mode(mut self, mode: CompatibilityMode) -> Self {
600        self.options.compatibility_mode = mode;
601        self
602    }
603
604    /// 构建客户端。
605    ///
606    /// # Errors
607    ///
608    /// 当基础地址非法或底层 `reqwest::Client` 初始化失败时返回错误。
609    pub fn build(self) -> Result<Client> {
610        let mut options = self.options;
611        if options.provider.kind() == ProviderKind::Azure
612            && (self.azure_configured || self.azure_endpoint.is_some())
613        {
614            options.provider = Provider::azure_with_options(self.azure_options.clone());
615            if let Some(endpoint) = self.azure_endpoint {
616                if options.base_url.is_some() {
617                    return Err(Error::InvalidConfig(
618                        "`base_url` 和 `azure_endpoint` 不能同时设置".into(),
619                    ));
620                }
621                options.base_url = Some(endpoint);
622            }
623        }
624
625        let http = if let Some(client) = self.http_client {
626            client
627        } else {
628            let mut default_headers = reqwest::header::HeaderMap::new();
629            default_headers.insert(
630                reqwest::header::USER_AGENT,
631                reqwest::header::HeaderValue::from_static(concat!(
632                    env!("CARGO_PKG_NAME"),
633                    "/",
634                    env!("CARGO_PKG_VERSION")
635                )),
636            );
637
638            let mut builder = reqwest::Client::builder().default_headers(default_headers);
639            if options.disable_proxy_for_local_base_url
640                && should_disable_proxy_for_base_url(options.base_url.as_deref())
641            {
642                builder = builder.no_proxy();
643            }
644
645            builder
646                .build()
647                .map_err(|error| Error::InvalidConfig(format!("创建 HTTP 客户端失败: {error}")))?
648        };
649        Ok(Client::from_parts(
650            http,
651            options.provider.clone(),
652            self.api_key_source,
653            options,
654        ))
655    }
656}
657
/// Reads environment variable `key`, trimming surrounding whitespace.
///
/// Returns `None` when the variable is unset, is not valid Unicode, or is blank
/// after trimming.
fn read_env(key: &str) -> Option<String> {
    let raw = env::var(key).ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
664
665fn should_disable_proxy_for_base_url(base_url: Option<&str>) -> bool {
666    let Some(base_url) = base_url else {
667        return false;
668    };
669
670    let Ok(url) = url::Url::parse(base_url) else {
671        return false;
672    };
673
674    matches!(
675        url.host_str(),
676        Some("localhost") | Some("127.0.0.1") | Some("[::1]") | Some("::1")
677    )
678}