openai_gpt_rs 0.4.1

A wrapper for the OpenAI API, written in Rust.
Documentation
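
A minimal usage sketch (assuming the crate is added as a dependency, a `tokio` runtime is available, and `OPENAI_API_KEY` is set in the environment; the types and methods used are the ones defined in the source below):

```
use openai_gpt_rs::{args::CompletionArgs, client::Client, models::CompletionModels, response::Content};
use std::env;

#[tokio::main]
async fn main() {
    // Build a client from the API key stored in the environment.
    let client = Client::new(env::var("OPENAI_API_KEY").unwrap().as_str());

    // Ask for a single short completion.
    let resp = client
        .create_completion(|args| {
            args.prompt("Say hello")
                .model(CompletionModels::TextDavinci3)
                .max_tokens(16)
                .n(1)
        })
        .await
        .unwrap();

    // Print the first generated text.
    for text in resp.get_contents(0..1) {
        println!("{}", text);
    }
}
```

Each `create_*` method takes a closure that configures a default argument struct, mirroring the builder-style examples in the documentation comments below.
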
use crate::args::*;
use crate::error::ResponseError;
use crate::response::*;
use reqwest::header::{HeaderMap, AUTHORIZATION, CONTENT_TYPE};
use reqwest::Client as HttpClient;
use reqwest::Error;
use serde_json::{json, Value};

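/// A client for the OpenAI API. It wraps a `reqwest::Client` and stores the
/// API key together with the default `Content-Type` and `Authorization`
/// headers attached to every request.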
pub struct Client {
    client: HttpClient,
    api_key: String,
    header: HeaderMap,
}

impl Client {
    /// Creates a new client with the given API key.
    pub fn new(key: &str) -> Client {
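        // Every request is JSON and carries the bearer token, so both headers
        // are prepared once here and reused for each call.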
        let mut header = HeaderMap::new();
        header.insert(CONTENT_TYPE, "application/json".parse().unwrap());
        header.insert(AUTHORIZATION, format!("Bearer {}", key).parse().unwrap());

        Client {
            client: HttpClient::new(),
            api_key: String::from(key),
            header,
        }
    }

    /// Makes an API call to the OpenAI Completions API and returns the response.
    ///
    /// # Arguments
    ///
    /// * `f` - A closure that takes a mutable reference to `CompletionArgs` and returns it.
    ///
    /// # Example
    ///
    /// ```
    /// use openai_gpt_rs::{args::CompletionArgs, client::Client, response::{CompletionResp, Content}, models::CompletionModels};
    /// use std::env;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let client = Client::new(env::var("OPENAI_API_KEY").unwrap().as_str());
    ///
    ///     let resp = client
    ///         .create_completion(|args| {
    ///             args.prompt("This is a test")
    ///                 .model(CompletionModels::TextDavinci3)
    ///                 .max_tokens(32)
    ///                 .n(5)
    ///         })
    ///         .await
    ///         .unwrap();
    ///
    ///     let text = resp.get_contents(0..5);
    ///
    ///     for val in text {
    ///         assert!(!val.is_empty());
    ///     }
    /// }
    ///
    /// ```
    ///
    /// # Panics
    /// This function will panic if the request to OpenAI fails.
    ///
    pub async fn create_completion<T>(&self, f: T) -> Result<CompletionResp, ResponseError>
    where
        T: FnOnce(&mut CompletionArgs) -> &mut CompletionArgs,
    {
        let mut args = CompletionArgs::default();
        f(&mut args);

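        // Build the request body from the caller-configured args. `top_p`,
        // `stream`, and `logprobs` are fixed here rather than taken from the args.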
        let body: Value = json!({
            "model": args.model,
            "prompt": args.prompt,
            "max_tokens": args.max_tokens,
            "temperature": args.temperature,
            "top_p": 1,
            "n": args.n,
            "stream": false,
            "logprobs": null
        });

        let resp = self
            .client
            .post("https://api.openai.com/v1/completions")
            .headers(self.header.clone())
            .json(&body)
            .send()
            .await
            .unwrap();

        let json: Value = resp.json().await.unwrap();

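        // The API reports failures as an `error` object in the JSON body;
        // surface it as a `ResponseError` instead of a successful response.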
        if let Some(e) = json.as_object().unwrap().get("error") {
            return Err(serde_json::from_value(e.clone()).unwrap());
        }

        Ok(serde_json::from_value(json).unwrap())
    }

    /// Makes an API call to the OpenAI Edits API and returns the response.
    ///
    /// # Arguments
    ///
    /// * `f` - A closure that takes a mutable reference to `EditArgs` and returns it.
    ///
    /// # Example
    /// ```
    /// use openai_gpt_rs::{args::EditArgs, client::Client, response::{EditResp, Content}, models::EditModels};
    /// use std::env;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let client = Client::new(env::var("OPENAI_API_KEY").unwrap().as_str());
    ///
    ///     let resp = client
    ///         .create_edit(|args| {
    ///             args.model(EditModels::TextDavinciEdit1)
    ///                 .input("How is you dae")
    ///                 .instruction("Fix the spelling mistakes")
    ///                 .n(5)
    ///         })
    ///         .await
    ///         .unwrap();
    ///
    ///     let text = resp.get_contents(0..5);
    ///
    ///     for val in text {
    ///         assert!(!val.is_empty());
    ///     }
    /// }
    /// ```
    /// # Panics
    /// This function will panic if the request to OpenAI fails.
    ///
    pub async fn create_edit<T>(&self, f: T) -> Result<EditResp, ResponseError>
    where
        T: FnOnce(&mut EditArgs) -> &mut EditArgs,
    {
        let mut args = EditArgs::default();
        f(&mut args);

        let body = json!({
            "model": args.model,
            "input": args.input,
            "instruction": args.instruction,
            "n": args.n,
            "temperature": args.temperature,
            "top_p": args.top_p
        });

        let resp = self
            .client
            .post("https://api.openai.com/v1/edits")
            .headers(self.header.clone())
            .json(&body)
            .send()
            .await
            .unwrap();

        let json: Value = resp.json().await.unwrap();

        if let Some(e) = json.as_object().unwrap().get("error") {
            return Err(serde_json::from_value(e.clone()).unwrap());
        }

        Ok(serde_json::from_value(json).unwrap())
    }

    /// Makes an API call to the OpenAI Images API and returns the response.
    ///
    /// # Arguments
    ///
    /// * `f` - A closure that takes a mutable reference to `ImageArgs` and returns it.
    ///
    /// # Example
    /// ```
    /// use openai_gpt_rs::{args::{ImageArgs, ImageSize}, client::Client, response::{ImageResp, Content}};
    /// use std::env;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let client = Client::new(env::var("OPENAI_API_KEY").unwrap().as_str());
    ///     
    ///     let resp = client
    ///         .create_image(|args| {
    ///             args.prompt("Kitty")
    ///                 .size(ImageSize::Small)
    ///                 .n(2)
    ///         })
    ///         .await
    ///         .unwrap();
    ///
    ///     let urls = resp.get_contents(0..2);
    ///
    ///     for val in urls {
    ///         assert!(!val.is_empty());
    ///     }
    /// }
    /// ```
    /// # Panics
    /// This function will panic if the request to OpenAI fails.
    ///
    pub async fn create_image<T>(&self, f: T) -> Result<ImageResp, ResponseError>
    where
        T: FnOnce(&mut ImageArgs) -> &mut ImageArgs,
    {
        let mut args = ImageArgs::default();
        f(&mut args);

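        // The image model is fixed to "image-alpha-001"; only the prompt, n,
        // size, and response format come from the caller-configured args.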
        let body = json!({
            "model": "image-alpha-001",
            "prompt": args.prompt,
            "n": args.n,
            "size": args.size,
            "response_format": args.response_format
        });

        let resp = self
            .client
            .post("https://api.openai.com/v1/images/generations")
            .headers(self.header.clone())
            .json(&body)
            .send()
            .await
            .unwrap();

        let json: Value = resp.json().await.unwrap();

        if let Some(e) = json.as_object().unwrap().get("error") {
            return Err(serde_json::from_value(e.clone()).unwrap());
        }

        Ok(serde_json::from_value(json).unwrap())
    }

    /// Returns the client's API key.
    pub fn get_key(&self) -> &String {
        &self.api_key
    }

    /// Sets the client's API key to the given value and refreshes the
    /// stored authorization header so subsequent requests use the new key.
    pub fn set_key(&mut self, new_key: &str) {
        self.api_key = new_key.to_string();
        self.header
            .insert(AUTHORIZATION, format!("Bearer {}", new_key).parse().unwrap());
    }

    /// Returns a JSON value listing all available models.
    pub async fn get_models(&self) -> Result<Value, Error> {
        let resp = self
            .client
            .get("https://api.openai.com/v1/models")
            .headers(self.header.clone())
            .send()
            .await?
            .json::<Value>()
            .await?;

        Ok(resp)
    }

    /// Makes an API call to the OpenAI Chat Completions API and returns the response.
    ///
    /// # Arguments
    ///
    /// * `f` - A closure that takes a mutable reference to `ChatArgs` and returns it.
    ///
    /// # Example
    /// ```
    /// use openai_gpt_rs::{args::ChatArgs, client::Client, response::{ChatResp, Content}, models::ChatModels, chat::Message};
    /// use std::env;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let client = Client::new(env::var("OPENAI_API_KEY").unwrap().as_str());
    ///
    ///     let message1 = Message {
    ///         role: "user".to_string(),
    ///         content: "Who won the world series in 2020?".to_string(),
    ///     };
    ///
    ///     let message2 = Message {
    ///         role: "system".to_string(),
    ///         content: "You are a helpful assistant.".to_string(),
    ///     };
    ///
    ///     let messages = vec![message1, message2];
    ///
    ///     let resp = client
    ///         .create_chat_completion(|args| args.messages(messages.clone()))
    ///         .await
    ///         .unwrap();
    ///
    ///     let contents = resp.get_content(0).unwrap();
    ///
    ///     assert!(!contents.is_empty());
    /// }
    /// ```
    /// # Errors
    /// This function will return a `ResponseError` if the OpenAI API responds with an error.
    ///
    /// # Panics
    /// This function will panic if the request to OpenAI fails.
    ///
    pub async fn create_chat_completion<T>(&self, f: T) -> Result<ChatResp, ResponseError>
    where
        T: FnOnce(&mut ChatArgs) -> &mut ChatArgs,
    {
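        // Unlike the other endpoints, ChatArgs is built here with explicit
        // defaults (gpt-3.5-turbo, 32 max tokens, one choice) before the
        // caller's closure customizes it.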
        let mut args = ChatArgs {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![],
            n: 1,
            temperature: 1.0,
            top_p: 1.0,
            max_tokens: 32,
            presence_penalty: 0.0,
            frequency_penalty: 0.0,
        };
        f(&mut args);

        let body = json!({
            "model": args.model,
            "messages": args.messages,
            "max_tokens": args.max_tokens,
            "temperature": args.temperature,
            "top_p": args.top_p,
            "n": args.n,
            "presence_penalty": args.presence_penalty,
            "frequency_penalty": args.frequency_penalty
        });

        let resp = self
            .client
            .post("https://api.openai.com/v1/chat/completions")
            .headers(self.header.clone())
            .json(&body)
            .send()
            .await
            .unwrap();

        let json: Value = resp.json().await.unwrap();

        if let Some(e) = json.as_object().unwrap().get("error") {
            return Err(serde_json::from_value(e.clone()).unwrap());
        }

        Ok(serde_json::from_value(json).unwrap())
    }
}