1use crate::auth::*;
2use crate::error::*;
3use crate::proto::*;
4use http::HeaderName;
5use http::HeaderValue;
6use http::Method;
7use reqwest::multipart::Form;
8use reqwest::Body;
9use reqwest::Response;
10use smart_default::SmartDefault;
11use std::time::Duration;
12use sys::ModelListResponse;
13use tracing::*;
14use url::Url;
15
/// Builder for [`Client`].
///
/// Both fields start out as `None`; [`ClientBuilder::build`] fails with
/// `Error::ClientBuild` unless each of them has been set.
#[derive(SmartDefault)]
pub struct ClientBuilder {
    /// Base URL of the API, optionally extended by `with_version`.
    pub base_url: Option<Url>,
    /// Authenticator that is handed to the built [`Client`].
    pub authenticator: Option<Box<dyn AuthenticatorTrait>>,
}
32
33impl ClientBuilder {
34 pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Result<Self> {
36 let base_url = Url::parse(base_url.as_ref())?;
37 self.base_url = Some(base_url);
38 Ok(self)
39 }
40
41 pub fn with_version(mut self, version: impl AsRef<str>) -> Result<Self> {
43 let base_url = self
44 .base_url
45 .as_mut()
46 .ok_or(Error::ClientBuild)?
47 .join(version.as_ref())?;
48 self.base_url = Some(base_url);
49 Ok(self)
50 }
51
52 pub fn with_key(self, key: impl AsRef<str>) -> Result<Self> {
54 self.with_authenticator(Bearer::new(key.as_ref().to_string()))
55 }
56
57 pub fn with_authenticator(
59 mut self,
60 authenticator: impl AuthenticatorTrait + 'static,
61 ) -> Result<Self> {
62 self.authenticator = Some(Box::new(authenticator));
63 Ok(self)
64 }
65
66 pub fn build(self) -> Result<Client> {
68 let Self {
69 base_url,
70 authenticator,
71 } = self;
72
73 let base_url = base_url.ok_or(Error::ClientBuild)?;
74
75 let authenticator = authenticator.ok_or(Error::ClientBuild)?;
76
77 Ok(Client {
78 base_url,
79 authenticator,
80 client: reqwest::Client::new(),
81 })
82 }
83}
84
/// API client; created through [`ClientBuilder`] or one of the `from_*`
/// constructors.
pub struct Client {
    /// Base URL that endpoint paths are resolved against.
    base_url: Url,
    /// Authorizes every request before it is executed.
    authenticator: Box<dyn AuthenticatorTrait>,
    /// Underlying `reqwest` client used to build and execute requests.
    client: reqwest::Client,
}
91
92impl Client {
93 pub fn from_env_file(env: impl AsRef<str>) -> Result<Self> {
95 let _ = dotenv::from_filename(env.as_ref());
96 Self::from_env()
97 }
98
99 pub fn from_default_env() -> Result<Self> {
101 let _ = dotenv::dotenv();
102 Self::from_env()
103 }
104
105 pub fn from_env() -> Result<Self> {
107 let base_url = std::env::var("OPENAI_API_BASE_URL")?;
108 let key = std::env::var("OPENAI_API_KEY")?;
109 let version = std::env::var("OPENAI_API_VERSION")?;
110 Self::builder()
111 .with_base_url(base_url)?
112 .with_version(version)?
113 .with_authenticator(Bearer::new(key))?
114 .build()
115 }
116
117 pub fn builder() -> ClientBuilder {
119 ClientBuilder::default()
120 }
121
122 pub async fn models(&self, timeout: Option<Duration>) -> Result<ModelListResponse> {
124 let rep = self
125 .call_impl(Method::GET, "models", [], None, None, timeout)
126 .await?;
127
128 let status = rep.status();
129
130 let rep: serde_json::Value = serde_json::from_slice(rep.bytes().await?.as_ref())?;
131
132 for l in serde_json::to_string_pretty(&rep)?.lines() {
133 if status.is_client_error() || status.is_server_error() {
134 error!("REP: {}", l);
135 } else {
136 trace!("REP: {}", l);
137 }
138 }
139
140 if !status.is_success() {
141 return Err(Error::ApiError(status.as_u16()));
142 }
143
144 Ok(serde_json::from_value(rep)?)
145 }
146
147 pub async fn call_impl(
149 &self,
150 method: Method,
151 uri: impl AsRef<str>,
152 headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
153 body: Option<Body>,
154 form: Option<Form>,
155 timeout: Option<Duration>,
156 ) -> Result<Response> {
157 let path = std::path::PathBuf::from(self.base_url.path()).join(uri.as_ref());
158
159 let url = self.base_url.join(path.to_str().expect("?"))?;
160
161 let mut builder = self.client.request(method, url);
162
163 if let Some(timeout) = timeout {
164 builder = builder.timeout(timeout);
165 }
166
167 for (k, v) in headers.into_iter() {
168 builder = builder.header(k, v);
169 }
170
171 if let Some(body) = body {
172 builder = builder.body(body);
173 }
174
175 if let Some(form) = form {
176 builder = builder.multipart(form);
177 }
178
179 let mut req = builder.build()?;
180
181 self.authenticator.authorize(&mut req).await?;
182
183 let rep = self.client.execute(req).await?; Ok(rep)
186 }
187}