openai-api-fork 0.2.1

OpenAI API library for Rust
Documentation
/// `OpenAI` API client library
#[macro_use]
extern crate derive_builder;

use thiserror::Error;

// Crate-private shorthand: a `std::result::Result` whose error type is fixed to
// this library's `Error` enum.
type Result<T> = std::result::Result<T, Error>;

#[allow(clippy::default_trait_access)]
pub mod api {
    //! Data types corresponding to requests and responses from the API
    use std::{collections::HashMap, convert::TryFrom, fmt::Display};

    use serde::{Deserialize, Serialize};

    /// Container type. Used in the api, but not useful for clients of this library
    ///
    /// Deserialized from a response object that carries its results in a `data`
    /// array; callers inside the crate unwrap it and hand out the inner `Vec`.
    #[derive(Deserialize, Debug)]
    pub(crate) struct Container<T> {
        /// Items in the page's results
        pub data: Vec<T>,
    }

    /// Detailed information on a particular engine.
    ///
    /// Deserialized from the API's engine responses; all three fields must be
    /// present in the payload for deserialization to succeed.
    #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
    pub struct EngineInfo {
        /// The name of the engine, e.g. `"davinci"` or `"ada"`
        pub id: String,
        /// The owner of the model. Usually (always?) `"openai"`
        pub owner: String,
        /// Whether the model is ready for use. Usually (always?) `true`
        pub ready: bool,
    }

    /// Options for the query completion
    ///
    /// Every field has a builder default, so a builder with nothing set still
    /// builds. The struct is serialized as the JSON request body, except for
    /// `engine`, which is skipped during serialization.
    #[derive(Serialize, Debug, Builder, Clone)]
    #[builder(pattern = "immutable")]
    pub struct CompletionArgs {
        /// The id of the engine to use for this request
        ///
        /// Defaults to `"davinci"`. Not serialized into the request body.
        ///
        /// # Example
        /// ```
        /// # use openai_api::api::CompletionArgs;
        /// CompletionArgs::builder().engine("davinci");
        /// ```
        #[builder(setter(into), default = "\"davinci\".into()")]
        #[serde(skip_serializing)]
        pub(super) engine: String,
        /// The prompt to complete from.
        ///
        /// Defaults to `"<|endoftext|>"` which is a special token seen during training.
        ///
        /// # Example
        /// ```
        /// # use openai_api::api::CompletionArgs;
        /// CompletionArgs::builder().prompt("Once upon a time...");
        /// ```
        #[builder(setter(into), default = "\"<|endoftext|>\".into()")]
        prompt: String,
        /// Maximum number of tokens to complete.
        ///
        /// Defaults to 16
        /// # Example
        /// ```
        /// # use openai_api::api::CompletionArgs;
        /// CompletionArgs::builder().max_tokens(64);
        /// ```
        #[builder(default = "16")]
        max_tokens: u64,
        /// What sampling temperature to use.
        ///
        /// Default is `1.0`
        ///
        /// Higher values means the model will take more risks.
        /// Try 0.9 for more creative applications, and 0 (argmax sampling)
        /// for ones with a well-defined answer.
        ///
        /// OpenAI recommends altering this or top_p but not both.
        ///
        /// # Example
        /// ```
        /// # use openai_api::api::{CompletionArgs, CompletionArgsBuilder};
        /// # use std::convert::{TryInto, TryFrom};
        /// # fn main() -> Result<(), String> {
        /// let builder = CompletionArgs::builder().temperature(0.7);
        /// let args: CompletionArgs = builder.try_into()?;
        /// # Ok::<(), String>(())
        /// # }
        /// ```
        #[builder(default = "1.0")]
        temperature: f64,
        /// Value sent as the request's `top_p`.
        ///
        /// Default is `1.0`. As noted on `temperature`, the recommendation is to
        /// alter one of the two but not both.
        #[builder(default = "1.0")]
        top_p: f64,
        /// Value sent as the request's `n`.
        ///
        /// Default is `1`. Presumably the number of completions to generate per
        /// prompt; confirm against the API reference.
        #[builder(default = "1")]
        n: u64,
        /// Value sent as the request's `logprobs`.
        ///
        /// Unset (`None`) by default; the builder setter takes a bare `u64`.
        /// Presumably the number of most-likely tokens to return log
        /// probabilities for; confirm against the API reference.
        #[builder(setter(strip_option), default)]
        logprobs: Option<u64>,
        /// Value sent as the request's `echo`.
        ///
        /// Default is `false`. See `Choice::text`, which documents that the text
        /// contains the prompt when echo is enabled.
        #[builder(default = "false")]
        echo: bool,
        /// Value sent as the request's `stop`.
        ///
        /// Unset (`None`) by default; the builder setter takes a bare
        /// `Vec<String>`. Presumably sequences at which generation stops;
        /// confirm against the API reference.
        #[builder(setter(strip_option), default)]
        stop: Option<Vec<String>>,
        /// Value sent as the request's `presence_penalty`.
        ///
        /// Default is `0.0`.
        #[builder(default = "0.0")]
        presence_penalty: f64,
        /// Value sent as the request's `frequency_penalty`.
        ///
        /// Default is `0.0`.
        #[builder(default = "0.0")]
        frequency_penalty: f64,
        /// Value sent as the request's `logit_bias`.
        ///
        /// Defaults to an empty map. Presumably maps token ids (as strings) to a
        /// bias value; confirm against the API reference.
        #[builder(default)]
        logit_bias: HashMap<String, f64>,
    }

    // TODO: add validators for the different arguments

    impl From<&str> for CompletionArgs {
        /// Builds arguments that use `prompt_string` as the prompt and the
        /// builder defaults for every other field.
        fn from(prompt_string: &str) -> Self {
            // Every field declares a builder default, so an untouched builder
            // is expected to build successfully.
            let defaults = CompletionArgsBuilder::default()
                .build()
                .expect("default should build");

            Self {
                prompt: String::from(prompt_string),
                ..defaults
            }
        }
    }

    impl CompletionArgs {
        /// Build a `CompletionArgs` from the defaults
        #[must_use]
        pub fn builder() -> CompletionArgsBuilder {
            CompletionArgsBuilder::default()
        }
    }

    impl TryFrom<CompletionArgsBuilder> for CompletionArgs {
        type Error = CompletionArgsBuilderError;

        fn try_from(builder: CompletionArgsBuilder) -> Result<Self, Self::Error> {
            builder.build()
        }
    }

    /// Represents a non-streamed completion response
    ///
    /// Deserialized from the body returned by the completions endpoint. Its
    /// `Display` implementation formats the first entry of `choices`.
    #[derive(Deserialize, Debug, Clone)]
    pub struct Completion {
        /// Completion unique identifier
        pub id: String,
        /// Unix timestamp when the completion was generated
        pub created: u64,
        /// Exact model type and version used for the completion
        pub model: String,
        /// List of completions generated by the model
        pub choices: Vec<Choice>,
    }

    impl std::fmt::Display for Completion {
        /// Writes the first choice of the completion.
        ///
        /// Writes nothing (and succeeds) when `choices` is empty, so formatting a
        /// completion never panics on an out-of-bounds index.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self.choices.first() {
                Some(first) => write!(f, "{}", first),
                // An empty choice list renders as the empty string.
                None => Ok(()),
            }
        }
    }

    /// A single completion result
    ///
    /// Its `Display` implementation formats just the `text` field.
    #[derive(Deserialize, Debug, Clone)]
    pub struct Choice {
        /// The text of the completion. Will contain the prompt if echo is True.
        pub text: String,
        /// Offset in the result where the completion began. Useful if using echo.
        ///
        /// NOTE(review): this may instead be the position of this choice within
        /// `Completion::choices` — confirm against the API reference.
        pub index: u64,
        /// If requested, the log probabilities of the completion tokens
        pub logprobs: Option<LogProbs>,
        /// Why the completion ended when it did
        pub finish_reason: String,
    }

    impl std::fmt::Display for Choice {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            self.text.fmt(f)
        }
    }

    /// Represents a logprobs subdocument
    ///
    /// Carried in `Choice::logprobs` when present in the response.
    /// NOTE(review): the four vectors look like parallel per-token arrays —
    /// confirm against the API reference.
    #[derive(Deserialize, Debug, Clone)]
    pub struct LogProbs {
        /// Deserialized from the `tokens` array (presumably the token strings).
        pub tokens: Vec<String>,
        /// Deserialized from `token_logprobs`; individual entries may be null.
        pub token_logprobs: Vec<Option<f64>>,
        /// Deserialized from `top_logprobs`; each entry is either null or a map
        /// from a string key to a float value.
        pub top_logprobs: Vec<Option<HashMap<String, f64>>>,
        /// Deserialized from `text_offset` (presumably offsets into the text).
        pub text_offset: Vec<u64>,
    }

    /// Error response object from the server
    ///
    /// Its `Display` implementation formats only `message`.
    #[derive(Deserialize, Debug, Eq, PartialEq, Clone)]
    pub struct ErrorMessage {
        /// Human-readable description of the error, as sent by the server.
        pub message: String,
        /// The error's `type` field from the JSON payload (renamed because
        /// `type` is a Rust keyword).
        #[serde(rename = "type")]
        pub error_type: String,
    }

    impl Display for ErrorMessage {
        /// Formats the error as its `message` text; `error_type` is not shown.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            Display::fmt(&self.message, f)
        }
    }

    /// API-level wrapper used in deserialization
    ///
    /// Matches a response body whose top-level object has an `error` key; the
    /// client unwraps it to reach the inner `ErrorMessage`.
    #[derive(Deserialize, Debug)]
    pub(crate) struct ErrorWrapper {
        /// The actual error payload.
        pub error: ErrorMessage,
    }
}

/// This library's main `Error` type.
///
/// `Display` text for each variant comes from its `#[error(...)]` attribute.
/// Conversions from `api::ErrorMessage`, `String` and `reqwest::Error` are
/// provided, so each can be raised with `?`.
#[derive(Error, Debug)]
pub enum Error {
    /// An error returned by the API itself
    ///
    /// Displays the server-provided `message`.
    #[error("API returned an Error: {}", .0.message)]
    Api(api::ErrorMessage),
    /// An error the client discovers before talking to the API
    #[error("Bad arguments: {0}")]
    BadArguments(String),
    /// Network / protocol related errors
    ///
    /// Wraps the underlying `reqwest::Error`.
    #[error("Error at the protocol level: {0}")]
    AsyncProtocol(reqwest::Error),
}

impl From<api::ErrorMessage> for Error {
    /// Wraps a server-reported error message in the `Api` variant.
    fn from(api_error: api::ErrorMessage) -> Self {
        Self::Api(api_error)
    }
}

impl From<String> for Error {
    fn from(e: String) -> Self {
        Error::BadArguments(e)
    }
}

impl From<reqwest::Error> for Error {
    fn from(e: reqwest::Error) -> Self {
        Error::AsyncProtocol(e)
    }
}

/// Authentication middleware
///
/// Holds the raw API token. Its `Debug` output deliberately shows only a short
/// prefix of the token so that the full key is not written to logs.
struct BearerToken {
    token: String,
}

impl std::fmt::Debug for BearerToken {
    /// Formats as `Bearer { token: "<prefix>" }`, where `<prefix>` is at most the
    /// first 8 bytes of the token.
    ///
    /// Never fails because of the token's contents: tokens shorter than 8 bytes
    /// are shown as-is, and the cut is moved back to the nearest character
    /// boundary when byte 8 falls inside a multi-byte character.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Get the first few characters to help debug, but not accidentally log key
        const PREFIX_LEN: usize = 8;

        let mut end = self.token.len().min(PREFIX_LEN);
        // Index 0 is always a boundary, so this loop terminates.
        while !self.token.is_char_boundary(end) {
            end -= 1;
        }

        write!(f, r#"Bearer {{ token: "{}" }}"#, &self.token[..end])
    }
}

impl BearerToken {
    /// Creates a `BearerToken` holding an owned copy of `token`.
    fn new(token: &str) -> Self {
        Self {
            token: String::from(token),
        }
    }
}

/// Client object. Must be constructed to talk to the API.
///
/// Cloning copies the stored strings and clones the inner `reqwest::Client`.
///
/// NOTE(review): the derived `Debug` prints every field, including the raw
/// `token` — consider whether that is acceptable for logging.
#[derive(Debug, Clone)]
pub struct Client {
    // HTTP client used for every request.
    client: reqwest::Client,
    // Prefix prepended to relative paths by `build_url_from_path`.
    base_url: String,
    // API token sent in the `Authorization: Bearer …` header.
    token: String,
}

impl Client {
    /// Creates a new `Client` given an api token
    ///
    /// The client targets `https://api.openai.com/v1/` and keeps its own copy
    /// of `token` for authenticating requests.
    #[must_use]
    pub fn new(token: &str) -> Self {
        let base_url = String::from("https://api.openai.com/v1/");
        let client = reqwest::Client::new();

        Self {
            client,
            base_url,
            token: token.to_owned(),
        }
    }

    /// Private helper for making gets
    ///
    /// Sends an authenticated `GET` to `endpoint` (a full URL) and decodes the
    /// JSON body as `T` when the status is `200 OK`. Any other status is
    /// decoded as an API error payload.
    ///
    /// # Errors
    /// - `Error::Api` if the server answers with a non-200 status and an error body
    /// - `Error::AsyncProtocol` if the request fails or a body cannot be decoded
    async fn get<T>(&self, endpoint: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        // `json` consumes the response, so the binding does not need `mut`.
        let response = self
            .client
            .get(endpoint)
            .header("Authorization", format!("Bearer {}", self.token))
            .send()
            .await?;

        if response.status() == reqwest::StatusCode::OK {
            Ok(response.json::<T>().await?)
        } else {
            let err = response.json::<api::ErrorWrapper>().await?.error;
            Err(Error::Api(err))
        }
    }

    /// Lists the currently available engines.
    ///
    /// Provides basic information about each one such as the owner and availability.
    ///
    /// # Errors
    /// - `Error::Api` if the server returns an error
    /// - `Error::AsyncProtocol` if the request fails or the body cannot be decoded
    pub async fn engines(&self) -> Result<Vec<api::EngineInfo>> {
        // The path is a constant, so no `format!` allocation is needed.
        let url = self.build_url_from_path("engines");
        let page: api::Container<api::EngineInfo> = self.get(&url).await?;
        Ok(page.data)
    }

    /// Retrieves a single engine by its id.
    ///
    /// Yields the engine's basic details, such as its owner and whether it is ready.
    ///
    /// # Errors
    /// - `Error::Api` if the server returns an error
    pub async fn engine(&self, engine: &str) -> Result<api::EngineInfo> {
        let path = format!("engines/{}", engine);
        let url = self.build_url_from_path(&path);
        self.get(&url).await
    }

    // Private helper to generate post requests. Needs to be a bit more flexible than
    // get because it should support SSE eventually
    async fn post<B, R>(&self, endpoint: &str, body: B) -> Result<R>
    where
        B: serde::ser::Serialize,
        R: serde::de::DeserializeOwned,
    {
        let mut response = self
            .client
            .post(endpoint)
            .header("Authorization", format!("Bearer {}", self.token))
            .json(&body)
            .send()
            .await?;

        match response.status() {
            reqwest::StatusCode::OK => Ok(response.json::<R>().await?),
            _ => Err(Error::Api(
                response
                    .json::<api::ErrorWrapper>()
                    .await
                    .expect("The API has returned something funky")
                    .error,
            )),
        }
    }

    /// Joins the client's base URL and a relative `path` into a full API url.
    ///
    /// The two pieces are concatenated verbatim; no separator is inserted and
    /// nothing is escaped.
    pub fn build_url_from_path(&self, path: &str) -> String {
        let mut url = String::with_capacity(self.base_url.len() + path.len());
        url.push_str(&self.base_url);
        url.push_str(path);
        url
    }

    /// Requests a completion for the given prompt or completion arguments.
    ///
    /// Anything convertible into `api::CompletionArgs` is accepted. The
    /// arguments' engine selects the endpoint, and the arguments themselves
    /// are sent as the request body.
    ///
    /// # Errors
    ///  - `Error::Api` if the api returns an error
    pub async fn complete_prompt(
        &self,
        prompt: impl Into<api::CompletionArgs>,
    ) -> Result<api::Completion> {
        let args: api::CompletionArgs = prompt.into();
        // Build the URL first: it borrows `args.engine` before `args` is moved
        // into the request body.
        let path = format!("engines/{}/completions", args.engine);
        let url = self.build_url_from_path(&path);
        self.post(&url, args).await
    }
}