1use std::collections::BTreeMap;
4use std::env;
5use std::future::Future;
6use std::sync::Arc;
7use std::time::Duration;
8
9use secrecy::SecretString;
10use tracing::{debug, error, info, warn};
11
12use crate::auth::ApiKeySource;
13use crate::config::{ClientOptions, LogLevel, LogRecord, Logger, LoggerHandle};
14use crate::error::{Error, Result};
15use crate::pagination::{CursorPage, ListEnvelope};
16use crate::providers::{AzureOptions, CompatibilityMode, Provider, ProviderKind};
17use crate::resources::{
18 AudioResource, BatchesResource, BetaResource, ChatResource, CompletionsResource,
19 ContainersResource, ConversationsResource, EmbeddingsResource, EvalsResource, FilesResource,
20 FineTuningResource, GradersResource, ImagesResource, ModelsResource, ModerationsResource,
21 RealtimeResource, ResponsesResource, SkillsResource, UploadsResource, VectorStoresResource,
22 VideosResource, WebhooksResource,
23};
24use crate::transport::{
25 RequestSpec, execute_bytes, execute_json, execute_raw_http, execute_raw_sse, execute_sse,
26};
27use crate::{ApiResponse, RawSseStream, SseStream};
28
29#[derive(Debug, Clone)]
31pub struct Client {
32 pub(crate) inner: Arc<ClientInner>,
33}
34
35#[derive(Debug)]
37pub(crate) struct ClientInner {
38 pub(crate) http: reqwest::Client,
39 pub(crate) options: ClientOptions,
40 pub(crate) api_key_source: Option<ApiKeySource>,
41 pub(crate) provider: Provider,
42}
43
44#[derive(Debug, Clone)]
46pub struct PageRequestSpec {
47 pub client: Client,
49 pub endpoint_id: &'static str,
51 pub method: http::Method,
53 pub path: String,
55 pub query: BTreeMap<String, Option<String>>,
57}
58
59#[derive(Debug, Clone, Default)]
61pub struct ClientBuilder {
62 options: ClientOptions,
63 api_key_source: Option<ApiKeySource>,
64 azure_options: AzureOptions,
65 azure_endpoint: Option<String>,
66 azure_configured: bool,
67 http_client: Option<reqwest::Client>,
68}
69
70impl Client {
71 pub fn builder() -> ClientBuilder {
73 ClientBuilder::from_env()
74 }
75
76 pub fn provider(&self) -> &Provider {
78 &self.inner.provider
79 }
80
81 pub fn base_url(&self) -> &str {
83 self.inner.base_url()
84 }
85
86 pub fn with_options<F>(&self, mutate: F) -> Self
88 where
89 F: FnOnce(&mut ClientOptions),
90 {
91 let mut options = self.inner.options.clone();
92 mutate(&mut options);
93 Self::from_parts(
94 self.inner.http.clone(),
95 options.provider.clone(),
96 self.inner.api_key_source.clone(),
97 options,
98 )
99 }
100
101 pub(crate) fn from_parts(
102 http: reqwest::Client,
103 provider: Provider,
104 api_key_source: Option<ApiKeySource>,
105 mut options: ClientOptions,
106 ) -> Self {
107 options.provider = provider.clone();
108 Self {
109 inner: Arc::new(ClientInner {
110 http,
111 options,
112 api_key_source,
113 provider,
114 }),
115 }
116 }
117
118 pub(crate) async fn execute_json<T>(&self, spec: RequestSpec) -> Result<ApiResponse<T>>
119 where
120 T: serde::de::DeserializeOwned,
121 {
122 execute_json(&self.inner, spec).await
123 }
124
125 pub(crate) async fn execute_bytes(
126 &self,
127 spec: RequestSpec,
128 ) -> Result<ApiResponse<bytes::Bytes>> {
129 execute_bytes(&self.inner, spec).await
130 }
131
132 pub(crate) async fn execute_sse<T>(&self, spec: RequestSpec) -> Result<SseStream<T>>
133 where
134 T: serde::de::DeserializeOwned + Send + 'static,
135 {
136 execute_sse(&self.inner, spec).await
137 }
138
139 #[allow(dead_code)]
140 pub(crate) async fn execute_raw_sse(&self, spec: RequestSpec) -> Result<RawSseStream> {
141 execute_raw_sse(&self.inner, spec).await
142 }
143
144 pub(crate) async fn execute_raw_http(
145 &self,
146 spec: RequestSpec,
147 ) -> Result<http::Response<bytes::Bytes>> {
148 execute_raw_http(&self.inner, spec).await
149 }
150
151 pub(crate) async fn fetch_cursor_page<T>(&self, page: PageRequestSpec) -> Result<CursorPage<T>>
152 where
153 T: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
154 {
155 let method = page.method.clone();
156 let mut spec = RequestSpec::new(page.endpoint_id, method.clone(), page.path.clone());
157 spec.options.extra_query = page.query;
158
159 let response = self.execute_json::<ListEnvelope<T>>(spec).await?;
160 let ListEnvelope {
161 object,
162 data,
163 first_id,
164 last_id,
165 has_more,
166 extra,
167 } = response.data;
168 let next_query = last_id
169 .as_ref()
170 .map(|last_id| {
171 let mut query = BTreeMap::new();
172 query.insert("after".into(), Some(last_id.clone()));
173 query
174 })
175 .unwrap_or_default();
176 let page_value = CursorPage::from(ListEnvelope {
177 object,
178 data,
179 first_id,
180 last_id,
181 has_more,
182 extra,
183 });
184 Ok(page_value.with_next_request(if has_more {
185 Some(PageRequestSpec {
186 client: self.clone(),
187 endpoint_id: page.endpoint_id,
188 method,
189 path: page.path,
190 query: next_query,
191 })
192 } else {
193 None
194 }))
195 }
196
197 pub fn completions(&self) -> CompletionsResource {
199 CompletionsResource::new(self.clone())
200 }
201
202 pub fn chat(&self) -> ChatResource {
204 ChatResource::new(self.clone())
205 }
206
207 pub fn embeddings(&self) -> EmbeddingsResource {
209 EmbeddingsResource::new(self.clone())
210 }
211
212 pub fn files(&self) -> FilesResource {
214 FilesResource::new(self.clone())
215 }
216
217 pub fn images(&self) -> ImagesResource {
219 ImagesResource::new(self.clone())
220 }
221
222 pub fn audio(&self) -> AudioResource {
224 AudioResource::new(self.clone())
225 }
226
227 pub fn moderations(&self) -> ModerationsResource {
229 ModerationsResource::new(self.clone())
230 }
231
232 pub fn models(&self) -> ModelsResource {
234 ModelsResource::new(self.clone())
235 }
236
237 pub fn fine_tuning(&self) -> FineTuningResource {
239 FineTuningResource::new(self.clone())
240 }
241
242 pub fn graders(&self) -> GradersResource {
244 GradersResource::new(self.clone())
245 }
246
247 pub fn vector_stores(&self) -> VectorStoresResource {
249 VectorStoresResource::new(self.clone())
250 }
251
252 pub fn webhooks(&self) -> WebhooksResource {
254 WebhooksResource::new(self.clone())
255 }
256
257 pub fn batches(&self) -> BatchesResource {
259 BatchesResource::new(self.clone())
260 }
261
262 pub fn uploads(&self) -> UploadsResource {
264 UploadsResource::new(self.clone())
265 }
266
267 pub fn responses(&self) -> ResponsesResource {
269 ResponsesResource::new(self.clone())
270 }
271
272 pub fn realtime(&self) -> RealtimeResource {
274 RealtimeResource::new(self.clone())
275 }
276
277 pub fn conversations(&self) -> ConversationsResource {
279 ConversationsResource::new(self.clone())
280 }
281
282 pub fn evals(&self) -> EvalsResource {
284 EvalsResource::new(self.clone())
285 }
286
287 pub fn containers(&self) -> ContainersResource {
289 ContainersResource::new(self.clone())
290 }
291
292 pub fn skills(&self) -> SkillsResource {
294 SkillsResource::new(self.clone())
295 }
296
297 pub fn videos(&self) -> VideosResource {
299 VideosResource::new(self.clone())
300 }
301
302 pub fn beta(&self) -> BetaResource {
304 BetaResource::new(self.clone())
305 }
306}
307
308impl ClientInner {
309 pub(crate) fn base_url(&self) -> &str {
310 self.options
311 .base_url
312 .as_deref()
313 .unwrap_or_else(|| self.provider.default_base_url())
314 }
315
316 pub(crate) fn log(
317 &self,
318 level: LogLevel,
319 target: &'static str,
320 message: impl Into<String>,
321 fields: BTreeMap<String, String>,
322 ) {
323 if !self.options.log_level.allows(level) {
324 return;
325 }
326
327 let record = LogRecord {
328 level,
329 target,
330 message: message.into(),
331 fields,
332 };
333
334 if let Some(logger) = &self.options.logger {
335 logger.log(&record);
336 }
337
338 let rendered_fields = if record.fields.is_empty() {
339 String::new()
340 } else {
341 format!(
342 " {}",
343 record
344 .fields
345 .iter()
346 .map(|(key, value)| format!("{key}={value}"))
347 .collect::<Vec<_>>()
348 .join(" ")
349 )
350 };
351 let rendered = format!("[{}] {}{}", target, record.message, rendered_fields);
352 match level {
353 LogLevel::Off => {}
354 LogLevel::Error => error!("{rendered}"),
355 LogLevel::Warn => warn!("{rendered}"),
356 LogLevel::Info => info!("{rendered}"),
357 LogLevel::Debug => debug!("{rendered}"),
358 }
359 }
360}
361
362impl ClientBuilder {
363 pub fn from_env() -> Self {
365 let mut builder = Self::default();
366
367 if let Some(webhook_secret) = read_env("OPENAI_WEBHOOK_SECRET") {
368 builder.options.webhook_secret = Some(SecretString::new(webhook_secret.into()));
369 }
370 if let Some(log_level) =
371 read_env("OPENAI_LOG").and_then(|value| value.parse::<LogLevel>().ok())
372 {
373 builder.options.log_level = log_level;
374 }
375
376 if let Some(azure_endpoint) = read_env("AZURE_OPENAI_ENDPOINT") {
377 builder = builder.azure_endpoint(azure_endpoint);
378 if let Some(api_version) = read_env("OPENAI_API_VERSION") {
379 builder = builder.azure_api_version(api_version);
380 }
381 if let Some(api_key) = read_env("AZURE_OPENAI_API_KEY") {
382 builder = builder.api_key(api_key);
383 }
384 return builder;
385 }
386
387 if let Some(base_url) = read_env("OPENAI_BASE_URL") {
388 builder.options.base_url = Some(base_url);
389 }
390 if let Some(api_key) = read_env("OPENAI_API_KEY") {
391 builder.api_key_source = Some(ApiKeySource::from_static(api_key));
392 }
393
394 builder
395 }
396
397 pub fn provider(mut self, provider: Provider) -> Self {
399 if provider.kind() != ProviderKind::Azure {
400 self.azure_options = AzureOptions::default();
401 self.azure_endpoint = None;
402 self.azure_configured = false;
403 }
404 self.options.provider = provider;
405 self
406 }
407
408 pub fn http_client(mut self, client: reqwest::Client) -> Self {
410 self.http_client = Some(client);
411 self
412 }
413
414 pub fn log_level(mut self, log_level: LogLevel) -> Self {
416 self.options.log_level = log_level;
417 self
418 }
419
420 pub fn logger<L>(mut self, logger: L) -> Self
422 where
423 L: Logger + 'static,
424 {
425 self.options.logger = Some(LoggerHandle::new(logger));
426 self
427 }
428
429 pub fn api_key<T>(mut self, api_key: T) -> Self
431 where
432 T: Into<String>,
433 {
434 self.api_key_source = Some(ApiKeySource::from_static(api_key));
435 self
436 }
437
438 pub fn api_key_provider<F>(mut self, provider: F) -> Self
440 where
441 F: Fn() -> Result<SecretString> + Send + Sync + 'static,
442 {
443 self.api_key_source = Some(ApiKeySource::from_provider(provider));
444 self
445 }
446
447 pub fn api_key_async_provider<F, Fut>(mut self, provider: F) -> Self
449 where
450 F: Fn() -> Fut + Send + Sync + 'static,
451 Fut: Future<Output = Result<SecretString>> + Send + 'static,
452 {
453 self.api_key_source = Some(ApiKeySource::from_async_provider(provider));
454 self
455 }
456
457 pub fn base_url<T>(mut self, base_url: T) -> Self
459 where
460 T: Into<String>,
461 {
462 self.options.base_url = Some(base_url.into());
463 self
464 }
465
466 pub fn disable_proxy_for_local_base_url(mut self, disable: bool) -> Self {
470 self.options.disable_proxy_for_local_base_url = disable;
471 self
472 }
473
474 pub fn azure_endpoint<T>(mut self, endpoint: T) -> Self
479 where
480 T: Into<String>,
481 {
482 self.azure_endpoint = Some(endpoint.into());
483 self.azure_configured = true;
484 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
485 self
486 }
487
488 pub fn azure_api_version<T>(mut self, api_version: T) -> Self
490 where
491 T: Into<String>,
492 {
493 self.azure_options.api_version = Some(api_version.into());
494 self.azure_configured = true;
495 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
496 self
497 }
498
499 pub fn azure_deployment<T>(mut self, deployment: T) -> Self
501 where
502 T: Into<String>,
503 {
504 self.azure_options.deployment = Some(deployment.into());
505 self.azure_configured = true;
506 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
507 self
508 }
509
510 pub fn azure_bearer_auth(mut self) -> Self {
512 self.azure_options = self.azure_options.bearer_auth();
513 self.azure_configured = true;
514 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
515 self
516 }
517
518 pub fn azure_ad_token<T>(mut self, token: T) -> Self
520 where
521 T: Into<String>,
522 {
523 self.azure_options = self.azure_options.bearer_auth();
524 self.azure_configured = true;
525 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
526 self.api_key_source = Some(ApiKeySource::from_static(token));
527 self
528 }
529
530 pub fn azure_ad_token_provider<F, Fut>(mut self, provider: F) -> Self
532 where
533 F: Fn() -> Fut + Send + Sync + 'static,
534 Fut: Future<Output = Result<SecretString>> + Send + 'static,
535 {
536 self.azure_options = self.azure_options.bearer_auth();
537 self.azure_configured = true;
538 self.options.provider = Provider::azure_with_options(self.azure_options.clone());
539 self.api_key_source = Some(ApiKeySource::from_async_provider(provider));
540 self
541 }
542
543 pub fn timeout(mut self, timeout: Duration) -> Self {
545 self.options.timeout = timeout;
546 self
547 }
548
549 pub fn max_retries(mut self, max_retries: u32) -> Self {
551 self.options.max_retries = max_retries;
552 self
553 }
554
555 pub fn default_header<T, U>(mut self, key: T, value: U) -> Self
557 where
558 T: Into<String>,
559 U: Into<String>,
560 {
561 self.options
562 .default_headers
563 .insert(key.into(), value.into());
564 self
565 }
566
567 pub fn default_headers(mut self, headers: BTreeMap<String, String>) -> Self {
569 self.options.default_headers = headers;
570 self
571 }
572
573 pub fn default_query<T, U>(mut self, key: T, value: U) -> Self
575 where
576 T: Into<String>,
577 U: Into<String>,
578 {
579 self.options.default_query.insert(key.into(), value.into());
580 self
581 }
582
583 pub fn default_query_map(mut self, query: BTreeMap<String, String>) -> Self {
585 self.options.default_query = query;
586 self
587 }
588
589 pub fn webhook_secret<T>(mut self, secret: T) -> Self
591 where
592 T: Into<String>,
593 {
594 self.options.webhook_secret = Some(SecretString::new(secret.into().into()));
595 self
596 }
597
598 pub fn compatibility_mode(mut self, mode: CompatibilityMode) -> Self {
600 self.options.compatibility_mode = mode;
601 self
602 }
603
604 pub fn build(self) -> Result<Client> {
610 let mut options = self.options;
611 if options.provider.kind() == ProviderKind::Azure
612 && (self.azure_configured || self.azure_endpoint.is_some())
613 {
614 options.provider = Provider::azure_with_options(self.azure_options.clone());
615 if let Some(endpoint) = self.azure_endpoint {
616 if options.base_url.is_some() {
617 return Err(Error::InvalidConfig(
618 "`base_url` 和 `azure_endpoint` 不能同时设置".into(),
619 ));
620 }
621 options.base_url = Some(endpoint);
622 }
623 }
624
625 let http = if let Some(client) = self.http_client {
626 client
627 } else {
628 let mut default_headers = reqwest::header::HeaderMap::new();
629 default_headers.insert(
630 reqwest::header::USER_AGENT,
631 reqwest::header::HeaderValue::from_static(concat!(
632 env!("CARGO_PKG_NAME"),
633 "/",
634 env!("CARGO_PKG_VERSION")
635 )),
636 );
637
638 let mut builder = reqwest::Client::builder().default_headers(default_headers);
639 if options.disable_proxy_for_local_base_url
640 && should_disable_proxy_for_base_url(options.base_url.as_deref())
641 {
642 builder = builder.no_proxy();
643 }
644
645 builder
646 .build()
647 .map_err(|error| Error::InvalidConfig(format!("创建 HTTP 客户端失败: {error}")))?
648 };
649 Ok(Client::from_parts(
650 http,
651 options.provider.clone(),
652 self.api_key_source,
653 options,
654 ))
655 }
656}
657
/// Reads the environment variable `key`, trimmed of surrounding whitespace.
///
/// Returns `None` when the variable is unset, is not valid Unicode, or is
/// empty after trimming.
fn read_env(key: &str) -> Option<String> {
    let raw = env::var(key).ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
664
665fn should_disable_proxy_for_base_url(base_url: Option<&str>) -> bool {
666 let Some(base_url) = base_url else {
667 return false;
668 };
669
670 let Ok(url) = url::Url::parse(base_url) else {
671 return false;
672 };
673
674 matches!(
675 url.host_str(),
676 Some("localhost") | Some("127.0.0.1") | Some("[::1]") | Some("::1")
677 )
678}