// openai_ng/client.rs

1use crate::auth::*;
2use crate::error::*;
3use crate::proto::*;
4use http::HeaderName;
5use http::HeaderValue;
6use http::Method;
7use reqwest::multipart::Form;
8use reqwest::Body;
9use reqwest::Response;
10use smart_default::SmartDefault;
11use std::time::Duration;
12use sys::ModelListResponse;
13use tracing::*;
14use url::Url;
15
16/// Client builder
17/// ```rust
18/// use openai_ng::prelude::*;
19///
20/// let builder = Client::builder();
21/// let client = builder
22///                 .with_base_url("https://api.openai.com")?
23///                 .with_version("v1")?
24///                 .with_key("you client key")?
25///                 .build()?;
26/// ```
27#[derive(SmartDefault)]
28pub struct ClientBuilder {
29    pub base_url: Option<Url>,
30    pub authenticator: Option<Box<dyn AuthenticatorTrait>>,
31}
32
33impl ClientBuilder {
34    /// config base_url
35    pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Result<Self> {
36        let base_url = Url::parse(base_url.as_ref())?;
37        self.base_url = Some(base_url);
38        Ok(self)
39    }
40
41    /// config version
42    pub fn with_version(mut self, version: impl AsRef<str>) -> Result<Self> {
43        let base_url = self
44            .base_url
45            .as_mut()
46            .ok_or(Error::ClientBuild)?
47            .join(version.as_ref())?;
48        self.base_url = Some(base_url);
49        Ok(self)
50    }
51
52    /// config bearer authenticator with key
53    pub fn with_key(self, key: impl AsRef<str>) -> Result<Self> {
54        self.with_authenticator(Bearer::new(key.as_ref().to_string()))
55    }
56
57    /// config authenticator with custom authenticator
58    pub fn with_authenticator(
59        mut self,
60        authenticator: impl AuthenticatorTrait + 'static,
61    ) -> Result<Self> {
62        self.authenticator = Some(Box::new(authenticator));
63        Ok(self)
64    }
65
66    /// build client
67    pub fn build(self) -> Result<Client> {
68        let Self {
69            base_url,
70            authenticator,
71        } = self;
72
73        let base_url = base_url.ok_or(Error::ClientBuild)?;
74
75        let authenticator = authenticator.ok_or(Error::ClientBuild)?;
76
77        Ok(Client {
78            base_url,
79            authenticator,
80            client: reqwest::Client::new(),
81        })
82    }
83}
84
85/// OpenAI API client
86pub struct Client {
87    base_url: Url,
88    authenticator: Box<dyn AuthenticatorTrait>,
89    client: reqwest::Client,
90}
91
92impl Client {
93    /// create client from customized env file, convenient for development, use `dotenv` crate
94    pub fn from_env_file(env: impl AsRef<str>) -> Result<Self> {
95        let _ = dotenv::from_filename(env.as_ref());
96        Self::from_env()
97    }
98
99    /// create client from default env file: `.env`, convenient for development, use `dotenv` crate
100    pub fn from_default_env() -> Result<Self> {
101        let _ = dotenv::dotenv();
102        Self::from_env()
103    }
104
105    /// create client from environment variables
106    pub fn from_env() -> Result<Self> {
107        let base_url = std::env::var("OPENAI_API_BASE_URL")?;
108        let key = std::env::var("OPENAI_API_KEY")?;
109        let version = std::env::var("OPENAI_API_VERSION")?;
110        Self::builder()
111            .with_base_url(base_url)?
112            .with_version(version)?
113            .with_authenticator(Bearer::new(key))?
114            .build()
115    }
116
117    /// create a client builder
118    pub fn builder() -> ClientBuilder {
119        ClientBuilder::default()
120    }
121
122    /// list all models available
123    pub async fn models(&self, timeout: Option<Duration>) -> Result<ModelListResponse> {
124        let rep = self
125            .call_impl(Method::GET, "models", [], None, None, timeout)
126            .await?;
127
128        let status = rep.status();
129
130        let rep: serde_json::Value = serde_json::from_slice(rep.bytes().await?.as_ref())?;
131
132        for l in serde_json::to_string_pretty(&rep)?.lines() {
133            if status.is_client_error() || status.is_server_error() {
134                error!("REP: {}", l);
135            } else {
136                trace!("REP: {}", l);
137            }
138        }
139
140        if !status.is_success() {
141            return Err(Error::ApiError(status.as_u16()));
142        }
143
144        Ok(serde_json::from_value(rep)?)
145    }
146
147    /// do the actual call
148    pub async fn call_impl(
149        &self,
150        method: Method,
151        uri: impl AsRef<str>,
152        headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
153        body: Option<Body>,
154        form: Option<Form>,
155        timeout: Option<Duration>,
156    ) -> Result<Response> {
157        let path = std::path::PathBuf::from(self.base_url.path()).join(uri.as_ref());
158
159        let url = self.base_url.join(path.to_str().expect("?"))?;
160
161        let mut builder = self.client.request(method, url);
162
163        if let Some(timeout) = timeout {
164            builder = builder.timeout(timeout);
165        }
166
167        for (k, v) in headers.into_iter() {
168            builder = builder.header(k, v);
169        }
170
171        if let Some(body) = body {
172            builder = builder.body(body);
173        }
174
175        if let Some(form) = form {
176            builder = builder.multipart(form);
177        }
178
179        let mut req = builder.build()?;
180
181        self.authenticator.authorize(&mut req).await?;
182
183        let rep = self.client.execute(req).await?; //.error_for_status()?;
184
185        Ok(rep)
186    }
187}