1use std::collections::BTreeMap;
4use std::env;
5use std::future::Future;
6use std::sync::Arc;
7use std::time::Duration;
8
9use secrecy::SecretString;
10use tracing::{debug, error, info, warn};
11
12use crate::auth::ApiKeySource;
13use crate::config::{ClientOptions, LogLevel, LogRecord, Logger, LoggerHandle};
14use crate::error::{Error, Result};
15use crate::pagination::{CursorPage, ListEnvelope};
16use crate::providers::{AzureOptions, CompatibilityMode, Provider, ProviderKind};
17use crate::resources::{
18 AdminResource, AudioResource, BatchesResource, BetaResource, ChatResource, CompletionsResource,
19 ContainersResource, ConversationsResource, EmbeddingsResource, EvalsResource, FilesResource,
20 FineTuningResource, GradersResource, ImagesResource, ModelsResource, ModerationsResource,
21 RealtimeResource, ResponsesResource, SkillsResource, UploadsResource, VectorStoresResource,
22 VideosResource, WebhooksResource,
23};
24use crate::transport::{
25 RequestSpec, execute_bytes, execute_json, execute_raw_http, execute_raw_sse, execute_sse,
26};
27use crate::{ApiResponse, RawSseStream, SseStream};
28
29#[derive(Debug, Clone)]
31pub struct Client {
32 pub(crate) inner: Arc<ClientInner>,
33}
34
35#[derive(Debug)]
37pub(crate) struct ClientInner {
38 pub(crate) http: reqwest::Client,
39 pub(crate) options: ClientOptions,
40 pub(crate) api_key_source: Option<ApiKeySource>,
41 pub(crate) admin_api_key_source: Option<ApiKeySource>,
42 pub(crate) provider: Provider,
43}
44
45#[derive(Debug, Clone)]
47pub struct PageRequestSpec {
48 pub client: Client,
50 pub endpoint_id: &'static str,
52 pub method: http::Method,
54 pub path: String,
56 pub query: BTreeMap<String, Option<String>>,
58}
59
60#[derive(Debug, Clone, Default)]
62pub struct ClientBuilder {
63 options: ClientOptions,
64 api_key_source: Option<ApiKeySource>,
65 admin_api_key_source: Option<ApiKeySource>,
66 azure_options: AzureOptions,
67 azure_endpoint: Option<String>,
68 azure_configured: bool,
69 http_client: Option<reqwest::Client>,
70}
71
72impl Client {
73 pub fn builder() -> ClientBuilder {
75 ClientBuilder::from_env()
76 }
77
78 pub fn provider(&self) -> &Provider {
80 &self.inner.provider
81 }
82
83 pub fn base_url(&self) -> &str {
85 self.inner.base_url()
86 }
87
88 pub fn with_options<F>(&self, mutate: F) -> Self
90 where
91 F: FnOnce(&mut ClientOptions),
92 {
93 let mut options = self.inner.options.clone();
94 mutate(&mut options);
95 Self::from_parts(
96 self.inner.http.clone(),
97 options.provider.clone(),
98 self.inner.api_key_source.clone(),
99 self.inner.admin_api_key_source.clone(),
100 options,
101 )
102 }
103
104 pub(crate) fn from_parts(
105 http: reqwest::Client,
106 provider: Provider,
107 api_key_source: Option<ApiKeySource>,
108 admin_api_key_source: Option<ApiKeySource>,
109 mut options: ClientOptions,
110 ) -> Self {
111 options.provider = provider.clone();
112 Self {
113 inner: Arc::new(ClientInner {
114 http,
115 options,
116 api_key_source,
117 admin_api_key_source,
118 provider,
119 }),
120 }
121 }
122
123 pub(crate) async fn execute_json<T>(&self, spec: RequestSpec) -> Result<ApiResponse<T>>
124 where
125 T: serde::de::DeserializeOwned,
126 {
127 execute_json(&self.inner, spec).await
128 }
129
130 pub(crate) async fn execute_bytes(
131 &self,
132 spec: RequestSpec,
133 ) -> Result<ApiResponse<bytes::Bytes>> {
134 execute_bytes(&self.inner, spec).await
135 }
136
137 pub(crate) async fn execute_sse<T>(&self, spec: RequestSpec) -> Result<SseStream<T>>
138 where
139 T: serde::de::DeserializeOwned + Send + 'static,
140 {
141 execute_sse(&self.inner, spec).await
142 }
143
144 #[allow(dead_code)]
145 pub(crate) async fn execute_raw_sse(&self, spec: RequestSpec) -> Result<RawSseStream> {
146 execute_raw_sse(&self.inner, spec).await
147 }
148
149 pub(crate) async fn execute_raw_http(
150 &self,
151 spec: RequestSpec,
152 ) -> Result<http::Response<bytes::Bytes>> {
153 execute_raw_http(&self.inner, spec).await
154 }
155
156 pub(crate) async fn fetch_cursor_page<T>(&self, page: PageRequestSpec) -> Result<CursorPage<T>>
157 where
158 T: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
159 {
160 let method = page.method.clone();
161 let mut spec = RequestSpec::new(page.endpoint_id, method.clone(), page.path.clone());
162 spec.options.extra_query = page.query;
163
164 let response = self.execute_json::<ListEnvelope<T>>(spec).await?;
165 let ListEnvelope {
166 object,
167 data,
168 first_id,
169 last_id,
170 has_more,
171 extra,
172 } = response.data;
173 let next_query = last_id
174 .as_ref()
175 .map(|last_id| {
176 let mut query = BTreeMap::new();
177 query.insert("after".into(), Some(last_id.clone()));
178 query
179 })
180 .unwrap_or_default();
181 let page_value = CursorPage::from(ListEnvelope {
182 object,
183 data,
184 first_id,
185 last_id,
186 has_more,
187 extra,
188 });
189 Ok(page_value.with_next_request(if has_more {
190 Some(PageRequestSpec {
191 client: self.clone(),
192 endpoint_id: page.endpoint_id,
193 method,
194 path: page.path,
195 query: next_query,
196 })
197 } else {
198 None
199 }))
200 }
201
202 pub fn completions(&self) -> CompletionsResource {
204 CompletionsResource::new(self.clone())
205 }
206
207 pub fn admin(&self) -> AdminResource {
209 AdminResource::new(self.clone())
210 }
211
212 pub fn chat(&self) -> ChatResource {
214 ChatResource::new(self.clone())
215 }
216
217 pub fn embeddings(&self) -> EmbeddingsResource {
219 EmbeddingsResource::new(self.clone())
220 }
221
222 pub fn files(&self) -> FilesResource {
224 FilesResource::new(self.clone())
225 }
226
227 pub fn images(&self) -> ImagesResource {
229 ImagesResource::new(self.clone())
230 }
231
232 pub fn audio(&self) -> AudioResource {
234 AudioResource::new(self.clone())
235 }
236
237 pub fn moderations(&self) -> ModerationsResource {
239 ModerationsResource::new(self.clone())
240 }
241
242 pub fn models(&self) -> ModelsResource {
244 ModelsResource::new(self.clone())
245 }
246
247 pub fn fine_tuning(&self) -> FineTuningResource {
249 FineTuningResource::new(self.clone())
250 }
251
252 pub fn graders(&self) -> GradersResource {
254 GradersResource::new(self.clone())
255 }
256
257 pub fn vector_stores(&self) -> VectorStoresResource {
259 VectorStoresResource::new(self.clone())
260 }
261
262 pub fn webhooks(&self) -> WebhooksResource {
264 WebhooksResource::new(self.clone())
265 }
266
267 pub fn batches(&self) -> BatchesResource {
269 BatchesResource::new(self.clone())
270 }
271
272 pub fn uploads(&self) -> UploadsResource {
274 UploadsResource::new(self.clone())
275 }
276
277 pub fn responses(&self) -> ResponsesResource {
279 ResponsesResource::new(self.clone())
280 }
281
282 pub fn realtime(&self) -> RealtimeResource {
284 RealtimeResource::new(self.clone())
285 }
286
287 pub fn conversations(&self) -> ConversationsResource {
289 ConversationsResource::new(self.clone())
290 }
291
292 pub fn evals(&self) -> EvalsResource {
294 EvalsResource::new(self.clone())
295 }
296
297 pub fn containers(&self) -> ContainersResource {
299 ContainersResource::new(self.clone())
300 }
301
302 pub fn skills(&self) -> SkillsResource {
304 SkillsResource::new(self.clone())
305 }
306
307 pub fn videos(&self) -> VideosResource {
309 VideosResource::new(self.clone())
310 }
311
312 pub fn beta(&self) -> BetaResource {
314 BetaResource::new(self.clone())
315 }
316}
317
318impl ClientInner {
319 pub(crate) fn base_url(&self) -> &str {
320 self.options
321 .base_url
322 .as_deref()
323 .unwrap_or_else(|| self.provider.default_base_url())
324 }
325
326 pub(crate) fn log(
327 &self,
328 level: LogLevel,
329 target: &'static str,
330 message: impl Into<String>,
331 fields: BTreeMap<String, String>,
332 ) {
333 if !self.options.log_level.allows(level) {
334 return;
335 }
336
337 let record = LogRecord {
338 level,
339 target,
340 message: message.into(),
341 fields,
342 };
343
344 if let Some(logger) = &self.options.logger {
345 logger.log(&record);
346 }
347
348 let rendered_fields = if record.fields.is_empty() {
349 String::new()
350 } else {
351 format!(
352 " {}",
353 record
354 .fields
355 .iter()
356 .map(|(key, value)| format!("{key}={value}"))
357 .collect::<Vec<_>>()
358 .join(" ")
359 )
360 };
361 let rendered = format!("[{}] {}{}", target, record.message, rendered_fields);
362 match level {
363 LogLevel::Off => {}
364 LogLevel::Error => error!("{rendered}"),
365 LogLevel::Warn => warn!("{rendered}"),
366 LogLevel::Info => info!("{rendered}"),
367 LogLevel::Debug => debug!("{rendered}"),
368 }
369 }
370}
371
372impl ClientBuilder {
373 pub fn from_env() -> Self {
375 let mut builder = Self::default();
376
377 if let Some(webhook_secret) = read_env("OPENAI_WEBHOOK_SECRET") {
378 builder.options.webhook_secret = Some(SecretString::new(webhook_secret.into()));
379 }
380 if let Some(log_level) =
381 read_env("OPENAI_LOG").and_then(|value| value.parse::<LogLevel>().ok())
382 {
383 builder.options.log_level = log_level;
384 }
385
386 if let Some(azure_endpoint) = read_env("AZURE_OPENAI_ENDPOINT") {
387 builder = builder.azure_endpoint(azure_endpoint);
388 if let Some(api_version) = read_env("OPENAI_API_VERSION") {
389 builder = builder.azure_api_version(api_version);
390 }
391 if let Some(api_key) = read_env("AZURE_OPENAI_API_KEY") {
392 builder = builder.api_key(api_key);
393 }
394 return builder;
395 }
396
397 if let Some(base_url) = read_env("OPENAI_BASE_URL") {
398 builder.options.base_url = Some(base_url);
399 }
400 if let Some(api_key) = read_env("OPENAI_API_KEY") {
401 builder.api_key_source = Some(ApiKeySource::from_static(api_key));
402 }
403 if let Some(admin_api_key) = read_env("OPENAI_ADMIN_KEY") {
404 builder.admin_api_key_source = Some(ApiKeySource::from_static(admin_api_key));
405 }
406
407 builder
408 }
409
410 pub fn provider(mut self, provider: Provider) -> Self {
412 if provider.kind() != ProviderKind::Azure {
413 self.azure_options = AzureOptions::default();
414 self.azure_endpoint = None;
415 self.azure_configured = false;
416 }
417 self.options.provider = provider;
418 self
419 }
420
421 pub fn http_client(mut self, client: reqwest::Client) -> Self {
423 self.http_client = Some(client);
424 self
425 }
426
427 pub fn log_level(mut self, log_level: LogLevel) -> Self {
429 self.options.log_level = log_level;
430 self
431 }
432
433 pub fn logger<L>(mut self, logger: L) -> Self
435 where
436 L: Logger + 'static,
437 {
438 self.options.logger = Some(LoggerHandle::new(logger));
439 self
440 }
441
442 pub fn api_key<T>(mut self, api_key: T) -> Self
444 where
445 T: Into<String>,
446 {
447 self.api_key_source = Some(ApiKeySource::from_static(api_key));
448 self
449 }
450
451 pub fn api_key_provider<F>(mut self, provider: F) -> Self
453 where
454 F: Fn() -> Result<SecretString> + Send + Sync + 'static,
455 {
456 self.api_key_source = Some(ApiKeySource::from_provider(provider));
457 self
458 }
459
460 pub fn api_key_async_provider<F, Fut>(mut self, provider: F) -> Self
462 where
463 F: Fn() -> Fut + Send + Sync + 'static,
464 Fut: Future<Output = Result<SecretString>> + Send + 'static,
465 {
466 self.api_key_source = Some(ApiKeySource::from_async_provider(provider));
467 self
468 }
469
470 pub fn admin_api_key<T>(mut self, admin_api_key: T) -> Self
472 where
473 T: Into<String>,
474 {
475 self.admin_api_key_source = Some(ApiKeySource::from_static(admin_api_key));
476 self
477 }
478
479 pub fn admin_api_key_provider<F>(mut self, provider: F) -> Self
481 where
482 F: Fn() -> Result<SecretString> + Send + Sync + 'static,
483 {
484 self.admin_api_key_source = Some(ApiKeySource::from_provider(provider));
485 self
486 }
487
488 pub fn admin_api_key_async_provider<F, Fut>(mut self, provider: F) -> Self
490 where
491 F: Fn() -> Fut + Send + Sync + 'static,
492 Fut: Future<Output = Result<SecretString>> + Send + 'static,
493 {
494 self.admin_api_key_source = Some(ApiKeySource::from_async_provider(provider));
495 self
496 }
497
498 pub fn base_url<T>(mut self, base_url: T) -> Self
500 where
501 T: Into<String>,
502 {
503 self.options.base_url = Some(base_url.into());
504 self
505 }
506
507 pub fn disable_proxy_for_local_base_url(mut self, disable: bool) -> Self {
511 self.options.disable_proxy_for_local_base_url = disable;
512 self
513 }
514
515 pub fn azure_endpoint<T>(mut self, endpoint: T) -> Self
520 where
521 T: Into<String>,
522 {
523 self.azure_endpoint = Some(endpoint.into());
524 self.azure_configured = true;
525 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
526 self
527 }
528
529 pub fn azure_api_version<T>(mut self, api_version: T) -> Self
531 where
532 T: Into<String>,
533 {
534 self.azure_options.api_version = Some(api_version.into());
535 self.azure_configured = true;
536 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
537 self
538 }
539
540 pub fn azure_deployment<T>(mut self, deployment: T) -> Self
542 where
543 T: Into<String>,
544 {
545 self.azure_options.deployment = Some(deployment.into());
546 self.azure_configured = true;
547 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
548 self
549 }
550
551 pub fn azure_bearer_auth(mut self) -> Self {
553 self.azure_options = self.azure_options.bearer_auth();
554 self.azure_configured = true;
555 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
556 self
557 }
558
559 pub fn azure_ad_token<T>(mut self, token: T) -> Self
561 where
562 T: Into<String>,
563 {
564 self.azure_options = self.azure_options.bearer_auth();
565 self.azure_configured = true;
566 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
567 self.api_key_source = Some(ApiKeySource::from_static(token));
568 self
569 }
570
571 pub fn azure_ad_token_provider<F, Fut>(mut self, provider: F) -> Self
573 where
574 F: Fn() -> Fut + Send + Sync + 'static,
575 Fut: Future<Output = Result<SecretString>> + Send + 'static,
576 {
577 self.azure_options = self.azure_options.bearer_auth();
578 self.azure_configured = true;
579 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
580 self.api_key_source = Some(ApiKeySource::from_async_provider(provider));
581 self
582 }
583
584 pub fn timeout(mut self, timeout: Duration) -> Self {
586 self.options.timeout = timeout;
587 self
588 }
589
590 pub fn max_retries(mut self, max_retries: u32) -> Self {
592 self.options.max_retries = max_retries;
593 self
594 }
595
596 pub fn default_header<T, U>(mut self, key: T, value: U) -> Self
598 where
599 T: Into<String>,
600 U: Into<String>,
601 {
602 self.options
603 .default_headers
604 .insert(key.into(), value.into());
605 self
606 }
607
608 pub fn default_headers(mut self, headers: BTreeMap<String, String>) -> Self {
610 self.options.default_headers = headers;
611 self
612 }
613
614 pub fn default_query<T, U>(mut self, key: T, value: U) -> Self
616 where
617 T: Into<String>,
618 U: Into<String>,
619 {
620 self.options.default_query.insert(key.into(), value.into());
621 self
622 }
623
624 pub fn default_query_map(mut self, query: BTreeMap<String, String>) -> Self {
626 self.options.default_query = query;
627 self
628 }
629
630 pub fn webhook_secret<T>(mut self, secret: T) -> Self
632 where
633 T: Into<String>,
634 {
635 self.options.webhook_secret = Some(SecretString::new(secret.into().into()));
636 self
637 }
638
639 pub fn compatibility_mode(mut self, mode: CompatibilityMode) -> Self {
641 self.options.compatibility_mode = mode;
642 self
643 }
644
645 pub fn build(self) -> Result<Client> {
651 let mut options = self.options;
652 if options.provider.kind() == ProviderKind::Azure
653 && (self.azure_configured || self.azure_endpoint.is_some())
654 {
655 options.provider = Provider::azure_with_options(self.azure_options.clone());
656 if let Some(endpoint) = self.azure_endpoint {
657 if options.base_url.is_some() {
658 return Err(Error::InvalidConfig(
659 "`base_url` 和 `azure_endpoint` 不能同时设置".into(),
660 ));
661 }
662 options.base_url = Some(endpoint);
663 }
664 }
665
666 let http = if let Some(client) = self.http_client {
667 client
668 } else {
669 let mut default_headers = reqwest::header::HeaderMap::new();
670 default_headers.insert(
671 reqwest::header::USER_AGENT,
672 reqwest::header::HeaderValue::from_static(concat!(
673 env!("CARGO_PKG_NAME"),
674 "/",
675 env!("CARGO_PKG_VERSION")
676 )),
677 );
678
679 let mut builder = reqwest::Client::builder().default_headers(default_headers);
680 if options.disable_proxy_for_local_base_url
681 && should_disable_proxy_for_base_url(options.base_url.as_deref())
682 {
683 builder = builder.no_proxy();
684 }
685
686 builder
687 .build()
688 .map_err(|error| Error::InvalidConfig(format!("创建 HTTP 客户端失败: {error}")))?
689 };
690 Ok(Client::from_parts(
691 http,
692 options.provider.clone(),
693 self.api_key_source,
694 self.admin_api_key_source,
695 options,
696 ))
697 }
698}
699
700fn read_env(key: &str) -> Option<String> {
701 env::var(key)
702 .ok()
703 .map(|value| value.trim().to_owned())
704 .filter(|value| !value.is_empty())
705}
706
707fn should_disable_proxy_for_base_url(base_url: Option<&str>) -> bool {
708 let Some(base_url) = base_url else {
709 return false;
710 };
711
712 let Ok(url) = url::Url::parse(base_url) else {
713 return false;
714 };
715
716 matches!(
717 url.host_str(),
718 Some("localhost") | Some("127.0.0.1") | Some("[::1]") | Some("::1")
719 )
720}