ai_provider_sdk/
client.rs1use std::collections::HashMap;
4use std::env;
5use std::sync::Arc;
6use std::time::Duration;
7
8use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
9use url::Url;
10
11use crate::error::{Error, Result};
12use crate::resources::{Chat, Embeddings, Files, Models, Moderations, Responses};
13use crate::transport::Transport;
14
/// Base URL used when neither `ClientOptions::base_url` nor the
/// `OPENAI_BASE_URL` environment variable supplies one.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";

/// Request timeout stored in `ClientOptions::default()` (ten minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10 * 60);

/// Connect timeout given to the HTTP client builder; not configurable through
/// `ClientOptions`.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);

/// Retry count stored in `ClientOptions::default()`.
const DEFAULT_MAX_RETRIES: u32 = 2;
19
/// Settings consumed by `OpenAI::with_options` when building a client.
///
/// `Debug` is implemented by hand so that formatting the options (for example
/// in a log line) never prints the API key.
#[derive(Clone)]
pub struct ClientOptions {
    /// API key sent as a bearer token. When `None`, `with_options` falls back
    /// to the `OPENAI_API_KEY` environment variable.
    pub api_key: Option<String>,
    /// Value for the `openai-organization` header. When `None`, `with_options`
    /// falls back to the `OPENAI_ORG_ID` environment variable.
    pub organization: Option<String>,
    /// Value for the `openai-project` header. When `None`, `with_options`
    /// falls back to the `OPENAI_PROJECT_ID` environment variable.
    pub project: Option<String>,
    /// Base URL for requests. When `None`, `with_options` falls back to the
    /// `OPENAI_BASE_URL` environment variable and then to the built-in default.
    pub base_url: Option<String>,
    /// Timeout passed to the HTTP client builder.
    pub timeout: Duration,
    /// Retry count handed to the transport.
    pub max_retries: u32,
    /// Extra headers inserted after the built-in ones, so an entry with the
    /// same name replaces the built-in value.
    pub default_headers: HashMap<String, String>,
    /// Query parameters handed to the transport.
    pub default_query: HashMap<String, String>,
}

impl std::fmt::Debug for ClientOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show whether a key is present without revealing its contents.
        let redacted_key = self.api_key.as_ref().map(|_| "<redacted>");

        f.debug_struct("ClientOptions")
            .field("api_key", &redacted_key)
            .field("organization", &self.organization)
            .field("project", &self.project)
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .field("default_headers", &self.default_headers)
            .field("default_query", &self.default_query)
            .finish()
    }
}
36
37impl Default for ClientOptions {
38 fn default() -> Self {
39 Self {
40 api_key: None,
41 organization: None,
42 project: None,
43 base_url: None,
44 timeout: DEFAULT_TIMEOUT,
45 max_retries: DEFAULT_MAX_RETRIES,
46 default_headers: HashMap::new(),
47 default_query: HashMap::new(),
48 }
49 }
50}
51
52#[derive(Clone)]
53pub struct OpenAI {
54 pub(crate) inner: Arc<Transport>,
55}
56
57impl OpenAI {
58 pub fn new(api_key: impl Into<String>) -> Result<Self> {
60 Self::with_options(ClientOptions {
61 api_key: Some(api_key.into()),
62 ..ClientOptions::default()
63 })
64 }
65
66 pub fn from_env() -> Result<Self> {
68 Self::with_options(ClientOptions::default())
69 }
70
71 pub fn with_options(mut options: ClientOptions) -> Result<Self> {
73 let api_key = options
74 .api_key
75 .take()
76 .or_else(|| env::var("OPENAI_API_KEY").ok())
77 .ok_or_else(|| {
78 Error::Config("api_key must be provided or OPENAI_API_KEY must be set".to_string())
79 })?;
80
81 if options.organization.is_none() {
82 options.organization = env::var("OPENAI_ORG_ID").ok();
83 }
84 if options.project.is_none() {
85 options.project = env::var("OPENAI_PROJECT_ID").ok();
86 }
87
88 let base_url = options
89 .base_url
90 .take()
91 .or_else(|| env::var("OPENAI_BASE_URL").ok())
92 .unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
93
94 let base_url = normalize_base_url(&base_url)?;
95 let headers = build_default_headers(&api_key, &options)?;
96 let http = reqwest::Client::builder()
97 .timeout(options.timeout)
98 .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
99 .build()
100 .map_err(|err| Error::Connection(err.to_string()))?;
101
102 Ok(Self {
103 inner: Arc::new(Transport::new(
104 http,
105 base_url,
106 headers,
107 options.default_query,
108 options.max_retries,
109 )),
110 })
111 }
112
113 pub fn responses(&self) -> Responses {
114 Responses::new(self.inner.clone())
115 }
116
117 pub fn chat(&self) -> Chat {
118 Chat::new(self.inner.clone())
119 }
120
121 pub fn models(&self) -> Models {
122 Models::new(self.inner.clone())
123 }
124
125 pub fn embeddings(&self) -> Embeddings {
126 Embeddings::new(self.inner.clone())
127 }
128
129 pub fn files(&self) -> Files {
130 Files::new(self.inner.clone())
131 }
132
133 pub fn moderations(&self) -> Moderations {
134 Moderations::new(self.inner.clone())
135 }
136}
137
138fn normalize_base_url(base_url: &str) -> Result<Url> {
139 let mut url = Url::parse(base_url)?;
140 if !url.path().ends_with('/') {
141 let path = format!("{}/", url.path().trim_end_matches('/'));
142 url.set_path(&path);
143 }
144 Ok(url)
145}
146
147fn build_default_headers(api_key: &str, options: &ClientOptions) -> Result<HeaderMap> {
148 let mut headers = HeaderMap::new();
149 headers.insert(
150 AUTHORIZATION,
151 HeaderValue::from_str(&format!("Bearer {api_key}"))?,
152 );
153 headers.insert("x-stainless-async", HeaderValue::from_static("true"));
154 headers.insert("content-type", HeaderValue::from_static("application/json"));
155
156 if let Some(organization) = &options.organization {
157 headers.insert("openai-organization", HeaderValue::from_str(organization)?);
158 }
159 if let Some(project) = &options.project {
160 headers.insert("openai-project", HeaderValue::from_str(project)?);
161 }
162
163 for (key, value) in &options.default_headers {
164 let name = HeaderName::from_bytes(key.as_bytes())
165 .map_err(|err| Error::Config(format!("invalid header name `{key}`: {err}")))?;
166 headers.insert(name, HeaderValue::from_str(value)?);
167 }
168
169 Ok(headers)
170}
171
#[cfg(test)]
mod tests {
    use super::normalize_base_url;

    /// A base URL without a trailing slash gains exactly one.
    #[test]
    fn normalizes_base_url_with_trailing_slash() {
        let normalized = normalize_base_url("https://api.example.com/v1").unwrap();

        assert_eq!(normalized.as_str(), "https://api.example.com/v1/");
    }
}