1#[macro_use]
3extern crate derive_builder;
4
5use thiserror::Error;
6
7type Result<T> = std::result::Result<T, Error>;
8
9#[allow(clippy::default_trait_access)]
10pub mod api {
11 use std::{collections::HashMap, convert::TryFrom, fmt::Display};
13
14 use serde::{Deserialize, Serialize};
15
16 #[derive(Deserialize, Debug)]
18 pub(crate) struct Container<T> {
19 pub data: Vec<T>,
21 }
22
23 #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
25 pub struct EngineInfo {
26 pub id: String,
28 pub owner: String,
30 pub ready: bool,
32 }
33
34 #[derive(Serialize, Debug, Builder, Clone)]
36 #[builder(pattern = "immutable")]
37 pub struct CompletionArgs {
38 #[builder(setter(into), default = "\"davinci\".into()")]
46 #[serde(skip_serializing)]
47 pub(super) engine: String,
48 #[builder(setter(into), default = "\"<|endoftext|>\".into()")]
58 prompt: String,
59 #[builder(default = "16")]
68 max_tokens: u64,
69 #[builder(default = "1.0")]
90 temperature: f64,
91 #[builder(default = "1.0")]
92 top_p: f64,
93 #[builder(default = "1")]
94 n: u64,
95 #[builder(setter(strip_option), default)]
96 logprobs: Option<u64>,
97 #[builder(default = "false")]
98 echo: bool,
99 #[builder(setter(strip_option), default)]
100 stop: Option<Vec<String>>,
101 #[builder(default = "0.0")]
102 presence_penalty: f64,
103 #[builder(default = "0.0")]
104 frequency_penalty: f64,
105 #[builder(default)]
106 logit_bias: HashMap<String, f64>,
107 }
108
109 impl From<&str> for CompletionArgs {
112 fn from(prompt_string: &str) -> Self {
113 Self {
114 prompt: prompt_string.into(),
115 ..CompletionArgsBuilder::default()
116 .build()
117 .expect("default should build")
118 }
119 }
120 }
121
122 impl CompletionArgs {
123 #[must_use]
125 pub fn builder() -> CompletionArgsBuilder {
126 CompletionArgsBuilder::default()
127 }
128 }
129
130 impl TryFrom<CompletionArgsBuilder> for CompletionArgs {
131 type Error = CompletionArgsBuilderError;
132
133 fn try_from(builder: CompletionArgsBuilder) -> Result<Self, Self::Error> {
134 builder.build()
135 }
136 }
137
138 #[derive(Deserialize, Debug, Clone)]
140 pub struct Completion {
141 pub id: String,
143 pub created: u64,
145 pub model: String,
147 pub choices: Vec<Choice>,
149 }
150
151 impl std::fmt::Display for Completion {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 write!(f, "{}", self.choices[0])
154 }
155 }
156
157 #[derive(Deserialize, Debug, Clone)]
159 pub struct Choice {
160 pub text: String,
162 pub index: u64,
164 pub logprobs: Option<LogProbs>,
166 pub finish_reason: String,
168 }
169
170 impl std::fmt::Display for Choice {
171 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172 self.text.fmt(f)
173 }
174 }
175
176 #[derive(Deserialize, Debug, Clone)]
178 pub struct LogProbs {
179 pub tokens: Vec<String>,
180 pub token_logprobs: Vec<Option<f64>>,
181 pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
182 pub text_offset: Vec<u64>,
183 }
184
185 #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
187 pub struct ErrorMessage {
188 pub message: String,
189 #[serde(rename = "type")]
190 pub error_type: String,
191 }
192
193 impl Display for ErrorMessage {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 self.message.fmt(f)
196 }
197 }
198
199 #[derive(Deserialize, Debug)]
201 pub(crate) struct ErrorWrapper {
202 pub error: ErrorMessage,
203 }
204}
205
206#[derive(Error, Debug)]
208pub enum Error {
209 #[error("API returned an Error: {}", .0.message)]
211 Api(api::ErrorMessage),
212 #[error("Bad arguments: {0}")]
214 BadArguments(String),
215 #[error("Error at the protocol level: {0}")]
217 AsyncProtocol(reqwest::Error),
218}
219
220impl From<api::ErrorMessage> for Error {
221 fn from(e: api::ErrorMessage) -> Self {
222 Error::Api(e)
223 }
224}
225
226impl From<String> for Error {
227 fn from(e: String) -> Self {
228 Error::BadArguments(e)
229 }
230}
231
232impl From<reqwest::Error> for Error {
233 fn from(e: reqwest::Error) -> Self {
234 Error::AsyncProtocol(e)
235 }
236}
237
/// An API token whose `Debug` output reveals only a short prefix.
///
/// NOTE(review): appears unused in this file — `Client` keeps its token as a plain `String`.
struct BearerToken {
    /// The full secret token.
    token: String,
}

impl std::fmt::Debug for BearerToken {
    /// Writes `Bearer { token: "<first 8 characters>" }`. Tokens shorter than 8 characters
    /// are shown as an empty string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Number of leading characters that may appear in logs.
        const VISIBLE_CHARS: usize = 8;

        // Work in characters rather than bytes so a multi-byte UTF-8 sequence is never split,
        // and never return `fmt::Error`: an error from a `Debug` impl makes `format!` panic.
        // Tokens too short to have a full prefix are hidden entirely instead of shown whole.
        let prefix: String = if self.token.chars().count() >= VISIBLE_CHARS {
            self.token.chars().take(VISIBLE_CHARS).collect()
        } else {
            String::new()
        };
        write!(f, r#"Bearer {{ token: "{}" }}"#, prefix)
    }
}

impl BearerToken {
    /// Wraps an owned copy of `token`.
    fn new(token: &str) -> Self {
        Self {
            token: String::from(token),
        }
    }
}
261
262#[derive(Debug, Clone)]
264pub struct Client {
265 client: reqwest::Client,
266 base_url: String,
267 token: String,
268}
269
270impl Client {
271 #[must_use]
273 pub fn new(token: &str) -> Self {
274 Self {
275 client: reqwest::Client::new(),
276 base_url: "https://api.openai.com/v1/".to_string(),
277 token: token.to_string(),
278 }
279 }
280
281 async fn get<T>(&self, endpoint: &str) -> Result<T>
283 where
284 T: serde::de::DeserializeOwned,
285 {
286 let mut response =
287 self.client
288 .get(endpoint)
289 .header("Authorization", format!("Bearer {}", self.token))
290 .send()
291 .await?;
292
293 if let reqwest::StatusCode::OK = response.status() {
294 Ok(response.json::<T>().await?)
295 } else {
296 let err = response.json::<api::ErrorWrapper>().await?.error;
297 Err(Error::Api(err))
298 }
299 }
300
301 pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
308 self.get(
309 &self.build_url_from_path(
310 &format!(
311 "engines",
312 ),
313 ),
314 ).await.map(|r: api::Container<_>| r.data)
315 }
316
317 pub async fn engine(&self, engine: &str) -> Result<api::EngineInfo> {
324 self.get(
325 &self.build_url_from_path(
326 &format!(
327 "engines/{}",
328 engine,
329 ),
330 ),
331 ).await
332 }
333
334 async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
337 where
338 B: serde::ser::Serialize,
339 R: serde::de::DeserializeOwned,
340 {
341 let mut response = self
342 .client
343 .post(endpoint)
344 .header("Authorization", format!("Bearer {}", self.token))
345 .json(&body)
346 .send()
347 .await?;
348
349 match response.status() {
350 reqwest::StatusCode::OK => Ok(response.json::<R>().await?),
351 _ => Err(Error::Api(
352 response
353 .json::<api::ErrorWrapper>()
354 .await
355 .expect("The API has returned something funky")
356 .error,
357 )),
358 }
359 }
360
361 pub fn build_url_from_path(&self, path: &str) -> String {
363 format!("{}{}", self.base_url, path)
364 }
365
366 pub async fn complete_prompt(
371 &self,
372 prompt: impl Into<api::CompletionArgs>,
373 ) -> Result<api::Completion> {
374 let args = prompt.into();
375 Ok(self
376 .post(
377 &self.build_url_from_path(
378 &format!(
379 "engines/{}/completions",
380 args.engine,
381 )
382 ),
383 args,
384 )
385 .await?)
386 }
387}