// rust-gpt 0.0.3
//
// A library for interacting with the OpenAI Completion and Chat APIs.
// Documentation for the `chat` module follows.
//! # Chat API
//!
//! The chat API is used to hold a conversation with the GPT-3.5 model that powers ChatGPT.
//!
//! The main structs used in here are [`ChatResponse`] and [`ChatMessage`].
//!
//! ## Chat
//! [`Chat`] is a new, experimental struct for holding a conversation with the GPT-3.5 model.
//! It automatically remembers both the messages you send and the model's replies, so the model keeps the context of the conversation.
//!
//! See the [`ChatBuilder`] and [`Chat`] structs for more information.
use std::{collections::VecDeque, error::Error};
use tokio::sync::Mutex;

use serde::{Deserialize, Serialize};

use crate::SendRequest;

/// One message in a conversation, either sent to the chat API or received from it.
///
/// `content` is an `Option`, so a `null` or missing `content` field deserializes as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: Role,
    pub content: Option<String>,
}

impl Default for ChatMessage {
    /// Returns a `user` message whose content is an empty string (not `None`).
    fn default() -> Self {
        let content = Some(String::default());
        ChatMessage {
            role: Role::User,
            content,
        }
    }
}

/// Token counts returned by the chat API for a single request.
///
/// All fields are plain integers, so the type is `Copy` and cheap to pass around.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
pub struct Usage {
    /// Prompt token count, as reported by the API.
    pub prompt_tokens: u32,
    /// Completion token count, as reported by the API.
    pub completion_tokens: u32,
    /// Total token count, as reported by the API.
    pub total_tokens: u32,
}

/// One candidate reply contained in a [`ChatResponse`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
    /// Position of this choice within [`ChatResponse::choices`].
    pub index: u32,
    /// The message generated for this choice.
    pub message: ChatMessage,
    /// Why generation stopped, if the API reported it.
    pub finish_reason: Option<String>,
}

/// A complete response body from the chat API.
///
/// Derives `Clone` so callers can keep a copy of a response (or of individual choices)
/// without re-parsing it.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    /// Candidate replies; may in principle be empty, so prefer `.first()` over indexing.
    pub choices: Vec<ChatChoice>,
    pub usage: Usage,
}
#[derive(Debug, Clone)]
/// Represents one of the roles that can be used in the chat API.
pub enum Role {
    User,
    Assistant,
    System,
}

impl Serialize for Role {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(&self.to_string())
    }
}

impl<'de> Deserialize<'de> for Role {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;

        Role::try_from(s.as_str()).map_err(serde::de::Error::custom)
    }
}

impl ToString for Role {
    fn to_string(&self) -> String {
        match self {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
        }
        .to_string()
    }
}

impl TryFrom<&str> for Role {
    type Error = Box<dyn Error>;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        match value {
            "user" => Ok(Role::User),
            "assistant" => Ok(Role::Assistant),
            "system" => Ok(Role::System),
            _ => Err("Invalid Role".into()),
        }
    }
}

// ----------------------------------------------------
// Experimental (unstable) stateful chat-session API

/// Builds a [`Chat`] session.
///
/// Defaults:
/// - an empty `system` message;
/// - a context length of 5 exchanges;
/// - no optional sampling parameters.
///
/// Every setter consumes and returns the builder, so calls can be chained; finish with
/// [`ChatBuilder::build`].
#[must_use = "a ChatBuilder does nothing until `build` is called"]
pub struct ChatBuilder {
    system: ChatMessage,
    chat_parameters: ChatParameters,
    api_key: String,
    model: crate::ChatModel,
    len: usize,
}

impl ChatBuilder {
    /// Creates a new [`ChatBuilder`] with the given model and API key.
    pub fn new(model: crate::ChatModel, api_key: String) -> Self {
        // System message with empty content (`ChatMessage::default()` sets `Some("")`).
        let default_msg = ChatMessage {
            role: Role::System,
            ..Default::default()
        };

        ChatBuilder {
            model,
            api_key,
            system: default_msg,
            chat_parameters: ChatParameters::default(),
            len: 5,
        }
    }

    /// Sets how many previous user/assistant exchanges are sent as context with each new
    /// question (default `5`).
    ///
    /// Once the history exceeds this limit, the oldest exchange is left out.
    pub fn len(mut self, len: usize) -> Self {
        self.len = len;
        self
    }

    /// Sets the system message that is sent first with every request.
    pub fn system(mut self, system: ChatMessage) -> Self {
        self.system = system;
        self
    }

    /// Sets the `temperature` parameter sent with every request.
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.chat_parameters.temperature = Some(temperature);
        self
    }

    /// Sets the maximum number of tokens the chat API may generate per response.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.chat_parameters.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the `top_p` parameter sent with every request.
    pub fn top_p(mut self, top_p: f32) -> Self {
        self.chat_parameters.top_p = Some(top_p);
        self
    }

    /// Sets the `presence_penalty` parameter sent with every request.
    pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
        self.chat_parameters.presence_penalty = Some(presence_penalty);
        self
    }

    /// Sets the `frequency_penalty` parameter sent with every request.
    pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
        self.chat_parameters.frequency_penalty = Some(frequency_penalty);
        self
    }

    /// Sets the `user` parameter stored in the session's chat parameters.
    pub fn user(mut self, user: String) -> Self {
        self.chat_parameters.user = Some(user);
        self
    }

    /// Consumes the builder and creates the [`Chat`] session.
    #[must_use]
    pub fn build(self) -> Chat {
        Chat::new(
            self.system,
            self.model,
            self.len,
            self.api_key,
            self.chat_parameters,
        )
    }
}

/// Optional generation parameters attached to chat requests.
///
/// Any field left as `None` is omitted from the serialized output entirely, rather than
/// being written as `null`.
#[doc(hidden)]
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChatParameters {
    /// `temperature` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// `max_tokens` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// `top_p` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// `presence_penalty` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// `frequency_penalty` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// `user` request parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}

/// A chat session that remembers earlier messages and sends them as context with each request.
///
/// Questions are queued with [`ask`], and nothing is sent to the API until [`get_response`] is
/// called. Each call to [`get_response`] sends the oldest queued question.
///
/// Questions are sent and stored in the order [`ask`] was called, even when several tasks call
/// [`get_response`] concurrently. This holds because the history lock is taken before a question
/// is dequeued and is held for the whole exchange.
///
/// [`ask`]: Chat::ask
/// [`get_response`]: Chat::get_response
///
/// Build a session with [`ChatBuilder`].
pub struct Chat {
    /// Sent first with every request; never stored in or trimmed from `messages`.
    system: ChatMessage,
    chat_parameters: ChatParameters,
    api_key: String,
    model: crate::ChatModel,
    /// Stored-history size at which the oldest user/assistant pair is dropped from requests.
    len: usize,
    /// Completed user/assistant messages, oldest first.
    messages: Mutex<VecDeque<ChatMessage>>,
    /// Questions queued by `ask` that have not been sent yet, oldest first.
    message_queue: Mutex<VecDeque<ChatMessage>>,
}

impl Chat {
    /// Creates a session that sends at most `len` previous exchanges as context.
    fn new<T: ToString>(
        system: ChatMessage,
        model: crate::ChatModel,
        len: usize,
        api_key: T,
        chat_parameters: ChatParameters,
    ) -> Self {
        Self {
            system,
            chat_parameters,
            api_key: api_key.to_string(),
            model,
            // `len` pairs of messages plus the exchange in progress. Saturating arithmetic makes
            // a huge `len` mean "effectively unbounded" instead of overflowing.
            len: len.saturating_mul(2).saturating_add(2),
            messages: Mutex::new(VecDeque::new()),
            message_queue: Mutex::new(VecDeque::new()),
        }
    }

    /// Returns a snapshot of the conversation.
    ///
    /// The snapshot contains the system message first, then every stored user and assistant
    /// message, oldest first. Questions queued with [`Chat::ask`] but not yet sent are not
    /// included.
    pub async fn get_messages(&self) -> Vec<ChatMessage> {
        let history = self.messages.lock().await;

        let mut snapshot = Vec::with_capacity(history.len() + 1);
        snapshot.push(self.system.clone());
        snapshot.extend(history.iter().cloned());
        snapshot
    }

    /// Queues `message` as a user question to be sent by the next [`Chat::get_response`] call.
    ///
    /// This never fails; the `Result` is kept for API compatibility.
    pub async fn ask(&self, message: &str) -> Result<(), Box<dyn Error>> {
        let msg = ChatMessage {
            role: Role::User,
            content: Some(message.to_string()),
        };

        self.message_queue.lock().await.push_back(msg);
        Ok(())
    }

    /// Sends the oldest queued question, together with the system message and recent history,
    /// and returns the assistant's reply.
    ///
    /// On success, the question and the reply are appended to the history. Once the history is
    /// full, the oldest user/assistant pair is dropped.
    ///
    /// # Errors
    /// - Returns `"No message to send"` if nothing has been queued with [`Chat::ask`].
    /// - Returns an error if the request fails or if the response contains no choices.
    ///
    /// In either failure case the history is left unchanged. The dequeued question is discarded,
    /// not re-queued.
    pub async fn get_response(&self, user: Option<String>) -> Result<ChatMessage, Box<dyn Error>> {
        // Lock the history *before* dequeuing and hold it for the whole exchange. Otherwise two
        // concurrent callers could dequeue questions in one order and send them in the other.
        let mut messages = self.messages.lock().await;

        let question = self
            .message_queue
            .lock()
            .await
            .pop_front()
            .ok_or("No message to send")?;

        // Once the history is full, leave the oldest user + assistant pair out of the request.
        // The stored history is only trimmed after a successful response, so a failed request
        // does not silently lose context.
        let trimmed = if messages.len() >= self.len {
            messages.len().min(2)
        } else {
            0
        };

        // Request order: system message, remaining history oldest-to-newest, then the question.
        let to_send: VecDeque<ChatMessage> = std::iter::once(self.system.clone())
            .chain(messages.iter().skip(trimmed).cloned())
            .chain(std::iter::once(question.clone()))
            .collect();

        let builder = crate::RequestBuilder::new(self.model.clone(), self.api_key.clone())
            .messages(to_send.into())
            .chat_parameters(self.chat_parameters.clone());

        let builder = match user {
            Some(user) => builder.user(user),
            None => builder,
        };

        let req = builder.build_chat();

        let resp = match req.send().await {
            Ok(resp) => resp,
            Err(e) => return Err(e.into()),
        };

        // An empty `choices` list would previously panic here; report it as an error instead.
        let reply = resp
            .choices
            .first()
            .map(|choice| choice.message.clone())
            .ok_or("The API response contained no choices")?;

        // Commit: apply the trim and record the completed exchange.
        messages.drain(..trimmed);
        messages.push_back(question);
        messages.push_back(reply.clone());

        Ok(reply)
    }
}