openai_api_fork/
lib.rs

1/// `OpenAI` API client library
2#[macro_use]
3extern crate derive_builder;
4
5use thiserror::Error;
6
7type Result<T> = std::result::Result<T, Error>;
8
9#[allow(clippy::default_trait_access)]
10pub mod api {
11    //! Data types corresponding to requests and responses from the API
12    use std::{collections::HashMap, convert::TryFrom, fmt::Display};
13
14    use serde::{Deserialize, Serialize};
15
16    /// Container type. Used in the api, but not useful for clients of this library
17    #[derive(Deserialize, Debug)]
18    pub(crate) struct Container<T> {
19        /// Items in the page's results
20        pub data: Vec<T>,
21    }
22
23    /// Detailed information on a particular engine.
24    #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
25    pub struct EngineInfo {
26        /// The name of the engine, e.g. `"davinci"` or `"ada"`
27        pub id: String,
28        /// The owner of the model. Usually (always?) `"openai"`
29        pub owner: String,
30        /// Whether the model is ready for use. Usually (always?) `true`
31        pub ready: bool,
32    }
33
34    /// Options for the query completion
35    #[derive(Serialize, Debug, Builder, Clone)]
36    #[builder(pattern = "immutable")]
37    pub struct CompletionArgs {
38        /// The id of the engine to use for this request
39        ///
40        /// # Example
41        /// ```
42        /// # use openai_api::api::CompletionArgs;
43        /// CompletionArgs::builder().engine("davinci");
44        /// ```
45        #[builder(setter(into), default = "\"davinci\".into()")]
46        #[serde(skip_serializing)]
47        pub(super) engine: String,
48        /// The prompt to complete from.
49        ///
50        /// Defaults to `"<|endoftext|>"` which is a special token seen during training.
51        ///
52        /// # Example
53        /// ```
54        /// # use openai_api::api::CompletionArgs;
55        /// CompletionArgs::builder().prompt("Once upon a time...");
56        /// ```
57        #[builder(setter(into), default = "\"<|endoftext|>\".into()")]
58        prompt: String,
59        /// Maximum number of tokens to complete.
60        ///
61        /// Defaults to 16
62        /// # Example
63        /// ```
64        /// # use openai_api::api::CompletionArgs;
65        /// CompletionArgs::builder().max_tokens(64);
66        /// ```
67        #[builder(default = "16")]
68        max_tokens: u64,
69        /// What sampling temperature to use.
70        ///
71        /// Default is `1.0`
72        ///
73        /// Higher values means the model will take more risks.
74        /// Try 0.9 for more creative applications, and 0 (argmax sampling)
75        /// for ones with a well-defined answer.
76        ///
77        /// OpenAI recommends altering this or top_p but not both.
78        ///
79        /// # Example
80        /// ```
81        /// # use openai_api::api::{CompletionArgs, CompletionArgsBuilder};
82        /// # use std::convert::{TryInto, TryFrom};
83        /// # fn main() -> Result<(), String> {
84        /// let builder = CompletionArgs::builder().temperature(0.7);
85        /// let args: CompletionArgs = builder.try_into()?;
86        /// # Ok::<(), String>(())
87        /// # }
88        /// ```
89        #[builder(default = "1.0")]
90        temperature: f64,
91        #[builder(default = "1.0")]
92        top_p: f64,
93        #[builder(default = "1")]
94        n: u64,
95        #[builder(setter(strip_option), default)]
96        logprobs: Option<u64>,
97        #[builder(default = "false")]
98        echo: bool,
99        #[builder(setter(strip_option), default)]
100        stop: Option<Vec<String>>,
101        #[builder(default = "0.0")]
102        presence_penalty: f64,
103        #[builder(default = "0.0")]
104        frequency_penalty: f64,
105        #[builder(default)]
106        logit_bias: HashMap<String, f64>,
107    }
108
109    // TODO: add validators for the different arguments
110
111    impl From<&str> for CompletionArgs {
112        fn from(prompt_string: &str) -> Self {
113            Self {
114                prompt: prompt_string.into(),
115                ..CompletionArgsBuilder::default()
116                    .build()
117                    .expect("default should build")
118            }
119        }
120    }
121
122    impl CompletionArgs {
123        /// Build a `CompletionArgs` from the defaults
124        #[must_use]
125        pub fn builder() -> CompletionArgsBuilder {
126            CompletionArgsBuilder::default()
127        }
128    }
129
130    impl TryFrom<CompletionArgsBuilder> for CompletionArgs {
131        type Error = CompletionArgsBuilderError;
132
133        fn try_from(builder: CompletionArgsBuilder) -> Result<Self, Self::Error> {
134            builder.build()
135        }
136    }
137
138    /// Represents a non-streamed completion response
139    #[derive(Deserialize, Debug, Clone)]
140    pub struct Completion {
141        /// Completion unique identifier
142        pub id: String,
143        /// Unix timestamp when the completion was generated
144        pub created: u64,
145        /// Exact model type and version used for the completion
146        pub model: String,
147        /// List of completions generated by the model
148        pub choices: Vec<Choice>,
149    }
150
151    impl std::fmt::Display for Completion {
152        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153            write!(f, "{}", self.choices[0])
154        }
155    }
156
157    /// A single completion result
158    #[derive(Deserialize, Debug, Clone)]
159    pub struct Choice {
160        /// The text of the completion. Will contain the prompt if echo is True.
161        pub text: String,
162        /// Offset in the result where the completion began. Useful if using echo.
163        pub index: u64,
164        /// If requested, the log probabilities of the completion tokens
165        pub logprobs: Option<LogProbs>,
166        /// Why the completion ended when it did
167        pub finish_reason: String,
168    }
169
170    impl std::fmt::Display for Choice {
171        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172            self.text.fmt(f)
173        }
174    }
175
176    /// Represents a logprobs subdocument
177    #[derive(Deserialize, Debug, Clone)]
178    pub struct LogProbs {
179        pub tokens: Vec<String>,
180        pub token_logprobs: Vec<Option<f64>>,
181        pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
182        pub text_offset: Vec<u64>,
183    }
184
185    /// Error response object from the server
186    #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
187    pub struct ErrorMessage {
188        pub message: String,
189        #[serde(rename = "type")]
190        pub error_type: String,
191    }
192
193    impl Display for ErrorMessage {
194        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195            self.message.fmt(f)
196        }
197    }
198
199    /// API-level wrapper used in deserialization
200    #[derive(Deserialize, Debug)]
201    pub(crate) struct ErrorWrapper {
202        pub error: ErrorMessage,
203    }
204}
205
206/// This library's main `Error` type.
207#[derive(Error, Debug)]
208pub enum Error {
209    /// An error returned by the API itself
210    #[error("API returned an Error: {}", .0.message)]
211    Api(api::ErrorMessage),
212    /// An error the client discovers before talking to the API
213    #[error("Bad arguments: {0}")]
214    BadArguments(String),
215    /// Network / protocol related errors
216    #[error("Error at the protocol level: {0}")]
217    AsyncProtocol(reqwest::Error),
218}
219
220impl From<api::ErrorMessage> for Error {
221    fn from(e: api::ErrorMessage) -> Self {
222        Error::Api(e)
223    }
224}
225
226impl From<String> for Error {
227    fn from(e: String) -> Self {
228        Error::BadArguments(e)
229    }
230}
231
232impl From<reqwest::Error> for Error {
233    fn from(e: reqwest::Error) -> Self {
234        Error::AsyncProtocol(e)
235    }
236}
237
/// Authentication middleware
///
/// Holds an API token whose `Debug` output reveals at most a short prefix of it.
/// NOTE(review): not currently referenced by `Client`, which keeps the raw token string.
struct BearerToken {
    token: String,
}

impl std::fmt::Debug for BearerToken {
    /// Prints the first 8 characters of the token to help debugging without logging
    /// the whole key. Tokens of 8 characters or fewer are hidden entirely, because
    /// their "prefix" would be the whole secret.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const VISIBLE_CHARS: usize = 8;
        // Cutting at a char index (not a fixed byte range) can never fail, so short or
        // multi-byte tokens no longer produce a `fmt::Error` (which makes `format!` panic).
        let visible = match self.token.char_indices().nth(VISIBLE_CHARS) {
            Some((end, _)) => &self.token[..end],
            None => "<redacted>",
        };
        write!(f, r#"Bearer {{ token: "{}" }}"#, visible)
    }
}

impl BearerToken {
    /// Wraps an owned copy of `token`.
    fn new(token: &str) -> Self {
        Self {
            token: String::from(token),
        }
    }
}
261
262/// Client object. Must be constructed to talk to the API.
263#[derive(Debug, Clone)]
264pub struct Client {
265    client: reqwest::Client,
266    base_url: String,
267    token: String,
268}
269
270impl Client {
271    // Creates a new `Client` given an api token
272    #[must_use]
273    pub fn new(token: &str) -> Self {
274        Self {
275            client: reqwest::Client::new(),
276            base_url: "https://api.openai.com/v1/".to_string(),
277            token: token.to_string(),
278        }
279    }
280
281    /// Private helper for making gets
282    async fn get<T>(&self, endpoint: &str) -> Result<T>
283    where
284        T: serde::de::DeserializeOwned,
285    {
286        let mut response =
287            self.client
288                .get(endpoint)
289                .header("Authorization", format!("Bearer {}", self.token))
290                .send()
291                .await?;
292
293        if let reqwest::StatusCode::OK = response.status() {
294            Ok(response.json::<T>().await?)
295        } else {
296            let err = response.json::<api::ErrorWrapper>().await?.error;
297            Err(Error::Api(err))
298        }
299    }
300
301    /// Lists the currently available engines.
302    ///
303    /// Provides basic information about each one such as the owner and availability.
304    ///
305    /// # Errors
306    /// - `Error::APIError` if the server returns an error
307    pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
308        self.get(
309            &self.build_url_from_path(
310                &format!(
311                    "engines",
312                ),
313            ),
314        ).await.map(|r: api::Container<_>| r.data)
315    }
316
317    /// Retrieves an engine instance
318    ///
319    /// Provides basic information about the engine such as the owner and availability.
320    ///
321    /// # Errors
322    /// - `Error::APIError` if the server returns an error
323    pub async fn engine(&self, engine: &str) -> Result<api::EngineInfo> {
324        self.get(
325            &self.build_url_from_path(
326                &format!(
327                    "engines/{}",
328                    engine,
329                ),
330            ),
331        ).await
332    }
333
334    // Private helper to generate post requests. Needs to be a bit more flexible than
335    // get because it should support SSE eventually
336    async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
337    where
338        B: serde::ser::Serialize,
339        R: serde::de::DeserializeOwned,
340    {
341        let mut response = self
342            .client
343            .post(endpoint)
344            .header("Authorization", format!("Bearer {}", self.token))
345            .json(&body)
346            .send()
347            .await?;
348
349        match response.status() {
350            reqwest::StatusCode::OK => Ok(response.json::<R>().await?),
351            _ => Err(Error::Api(
352                response
353                    .json::<api::ErrorWrapper>()
354                    .await
355                    .expect("The API has returned something funky")
356                    .error,
357            )),
358        }
359    }
360
361    /// Build an OpenAI API url from a relative path
362    pub fn build_url_from_path(&self, path: &str) -> String {
363        format!("{}{}", self.base_url, path)
364    }
365
366    /// Get predicted completion of the prompt
367    ///
368    /// # Errors
369    ///  - `Error::APIError` if the api returns an error
370    pub async fn complete_prompt(
371        &self,
372        prompt: impl Into<api::CompletionArgs>,
373    ) -> Result<api::Completion> {
374        let args = prompt.into();
375        Ok(self
376            .post(
377                &self.build_url_from_path(
378                    &format!(
379                        "engines/{}/completions",
380                        args.engine,
381                    )
382                ),
383                args,
384            )
385            .await?)
386    }
387}