// ai_provider_sdk/client.rs

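//! Client construction and resource-access entry point. Responsible for normalizing
//! configuration, injecting default headers, and assembling the `Transport`.
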
use std::collections::HashMap;
use std::env;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;

use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
use url::Url;

use crate::error::{Error, Result};
use crate::resources::{Chat, Embeddings, Files, Models, Moderations, Responses};
use crate::transport::Transport;

/// API root used when neither `ClientOptions::base_url` nor `OPENAI_BASE_URL` is provided.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Total per-request timeout given to the HTTP client by default (ten minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// Timeout for only the connect phase of the HTTP client.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Retry budget forwarded to the transport unless the caller overrides it.
const DEFAULT_MAX_RETRIES: u32 = 2;

20#[derive(Debug, Clone)]
21/// 客户端初始化选项。
22///
23/// 边界约束:
24/// - `api_key` 可通过显式传入或环境变量 `OPENAI_API_KEY` 提供。
25/// - `default_headers` 与 `default_query` 会应用于每次请求。
26pub struct ClientOptions {
27    pub api_key: Option<String>,
28    pub organization: Option<String>,
29    pub project: Option<String>,
30    pub base_url: Option<String>,
31    pub timeout: Duration,
32    pub max_retries: u32,
33    pub default_headers: HashMap<String, String>,
34    pub default_query: HashMap<String, String>,
35}
36
37impl Default for ClientOptions {
38    fn default() -> Self {
39        Self {
40            api_key: None,
41            organization: None,
42            project: None,
43            base_url: None,
44            timeout: DEFAULT_TIMEOUT,
45            max_retries: DEFAULT_MAX_RETRIES,
46            default_headers: HashMap::new(),
47            default_query: HashMap::new(),
48        }
49    }
50}
51
52#[derive(Clone)]
53pub struct OpenAI {
54    pub(crate) inner: Arc<Transport>,
55}
56
57impl OpenAI {
58    /// 使用显式 API Key 创建客户端。
59    pub fn new(api_key: impl Into<String>) -> Result<Self> {
60        Self::with_options(ClientOptions {
61            api_key: Some(api_key.into()),
62            ..ClientOptions::default()
63        })
64    }
65
66    /// 仅从环境变量读取配置创建客户端。
67    pub fn from_env() -> Result<Self> {
68        Self::with_options(ClientOptions::default())
69    }
70
71    /// 使用完整选项创建客户端并完成配置归一化。
72    pub fn with_options(mut options: ClientOptions) -> Result<Self> {
73        let api_key = options
74            .api_key
75            .take()
76            .or_else(|| env::var("OPENAI_API_KEY").ok())
77            .ok_or_else(|| {
78                Error::Config("api_key must be provided or OPENAI_API_KEY must be set".to_string())
79            })?;
80
81        if options.organization.is_none() {
82            options.organization = env::var("OPENAI_ORG_ID").ok();
83        }
84        if options.project.is_none() {
85            options.project = env::var("OPENAI_PROJECT_ID").ok();
86        }
87
88        let base_url = options
89            .base_url
90            .take()
91            .or_else(|| env::var("OPENAI_BASE_URL").ok())
92            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
93
94        let base_url = normalize_base_url(&base_url)?;
95        let headers = build_default_headers(&api_key, &options)?;
96        let http = reqwest::Client::builder()
97            .timeout(options.timeout)
98            .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
99            .build()
100            .map_err(|err| Error::Connection(err.to_string()))?;
101
102        Ok(Self {
103            inner: Arc::new(Transport::new(
104                http,
105                base_url,
106                headers,
107                options.default_query,
108                options.max_retries,
109            )),
110        })
111    }
112
113    pub fn responses(&self) -> Responses {
114        Responses::new(self.inner.clone())
115    }
116
117    pub fn chat(&self) -> Chat {
118        Chat::new(self.inner.clone())
119    }
120
121    pub fn models(&self) -> Models {
122        Models::new(self.inner.clone())
123    }
124
125    pub fn embeddings(&self) -> Embeddings {
126        Embeddings::new(self.inner.clone())
127    }
128
129    pub fn files(&self) -> Files {
130        Files::new(self.inner.clone())
131    }
132
133    pub fn moderations(&self) -> Moderations {
134        Moderations::new(self.inner.clone())
135    }
136}
137
138fn normalize_base_url(base_url: &str) -> Result<Url> {
139    let mut url = Url::parse(base_url)?;
140    if !url.path().ends_with('/') {
141        let path = format!("{}/", url.path().trim_end_matches('/'));
142        url.set_path(&path);
143    }
144    Ok(url)
145}
146
147fn build_default_headers(api_key: &str, options: &ClientOptions) -> Result<HeaderMap> {
148    let mut headers = HeaderMap::new();
149    headers.insert(
150        AUTHORIZATION,
151        HeaderValue::from_str(&format!("Bearer {api_key}"))?,
152    );
153    headers.insert("x-stainless-async", HeaderValue::from_static("true"));
154    headers.insert("content-type", HeaderValue::from_static("application/json"));
155
156    if let Some(organization) = &options.organization {
157        headers.insert("openai-organization", HeaderValue::from_str(organization)?);
158    }
159    if let Some(project) = &options.project {
160        headers.insert("openai-project", HeaderValue::from_str(project)?);
161    }
162
163    for (key, value) in &options.default_headers {
164        let name = HeaderName::from_bytes(key.as_bytes())
165            .map_err(|err| Error::Config(format!("invalid header name `{key}`: {err}")))?;
166        headers.insert(name, HeaderValue::from_str(value)?);
167    }
168
169    Ok(headers)
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalizes_base_url_with_trailing_slash() {
        let url = normalize_base_url("https://api.example.com/v1").unwrap();
        assert_eq!(url.as_str(), "https://api.example.com/v1/");
    }

    #[test]
    fn keeps_existing_trailing_slash() {
        let url = normalize_base_url("https://api.example.com/v1/").unwrap();
        assert_eq!(url.as_str(), "https://api.example.com/v1/");
    }

    #[test]
    fn host_only_base_url_has_root_path() {
        let url = normalize_base_url("http://localhost:8080").unwrap();
        assert_eq!(url.as_str(), "http://localhost:8080/");
    }

    #[test]
    fn rejects_relative_base_url() {
        assert!(normalize_base_url("api.example.com/v1").is_err());
    }

    #[test]
    fn default_headers_include_auth_org_and_project() {
        let options = ClientOptions {
            organization: Some("org-123".to_string()),
            project: Some("proj-456".to_string()),
            ..ClientOptions::default()
        };
        let headers = build_default_headers("sk-test", &options).unwrap();
        assert_eq!(
            headers.get(AUTHORIZATION).unwrap().to_str().unwrap(),
            "Bearer sk-test"
        );
        assert_eq!(
            headers.get("openai-organization").unwrap().to_str().unwrap(),
            "org-123"
        );
        assert_eq!(
            headers.get("openai-project").unwrap().to_str().unwrap(),
            "proj-456"
        );
        assert_eq!(
            headers.get("content-type").unwrap().to_str().unwrap(),
            "application/json"
        );
    }

    #[test]
    fn custom_default_headers_override_builtins() {
        let mut options = ClientOptions::default();
        options
            .default_headers
            .insert("content-type".to_string(), "text/plain".to_string());
        let headers = build_default_headers("sk-test", &options).unwrap();
        assert_eq!(
            headers.get("content-type").unwrap().to_str().unwrap(),
            "text/plain"
        );
    }

    #[test]
    fn rejects_invalid_default_header_name() {
        let mut options = ClientOptions::default();
        options
            .default_headers
            .insert("bad header".to_string(), "x".to_string());
        assert!(build_default_headers("sk-test", &options).is_err());
    }
}