use std::{collections::VecDeque, error::Error};
use tokio::sync::Mutex;
use serde::{Deserialize, Serialize};
use crate::SendRequest;
/// A single message in a chat conversation.
///
/// There is no `skip_serializing_if` on `content`, so a `None` content is still emitted as a
/// field when serialized.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatMessage {
    /// Author of the message; serialized as `"user"`, `"assistant"` or `"system"`.
    pub role: Role,
    /// Message text, if any.
    pub content: Option<String>,
}
/// The default message is an empty user message.
///
/// Note that `content` is `Some("")`, not `None`.
impl Default for ChatMessage {
    fn default() -> Self {
        let content = Some(String::default());
        Self {
            content,
            role: Role::User,
        }
    }
}
/// The `usage` object attached to a [`ChatResponse`].
///
/// Field meanings follow the API's naming; presumably these are token counts (TODO confirm
/// against the API reference).
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    /// The response's `prompt_tokens` value.
    pub prompt_tokens: u32,
    /// The response's `completion_tokens` value.
    pub completion_tokens: u32,
    /// The response's `total_tokens` value.
    pub total_tokens: u32,
}
/// One entry of [`ChatResponse::choices`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatChoice {
    /// The choice's `index` value; presumably its position in `choices`.
    pub index: u32,
    /// The generated message carried by this choice.
    pub message: ChatMessage,
    /// The `finish_reason` reported for this choice, if any (free-form string).
    pub finish_reason: Option<String>,
}
/// Reply body for a chat request, (de)serialized with serde.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Identifier of the response as reported by the API.
    pub id: String,
    /// The response's `object` type tag.
    pub object: String,
    /// The response's `created` value; presumably a Unix timestamp (TODO confirm).
    pub created: u64,
    /// Candidate replies; the session code uses the first one.
    pub choices: Vec<ChatChoice>,
    /// Usage accounting reported with the reply.
    pub usage: Usage,
}
/// Author of a [`ChatMessage`].
///
/// Converted to and from the lowercase strings `"user"`, `"assistant"` and `"system"`.
/// The enum is a plain fieldless tag, so it is `Copy` and can be compared or hashed directly
/// (e.g. `msg.role == Role::User`) instead of needing `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
    /// Serialized as `"user"`; questions queued by `Chat::ask` use this role.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"system"`; the builder's default system message uses this role.
    System,
}
impl Serialize for Role {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
impl<'de> Deserialize<'de> for Role {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Role::try_from(s.as_str()).map_err(serde::de::Error::custom)
}
}
/// Formats the role as its lowercase wire name: `"user"`, `"assistant"` or `"system"`.
///
/// Implementing `Display` (instead of `ToString` directly) keeps `role.to_string()` working
/// through the standard library's blanket `impl<T: Display> ToString for T`, and also lets
/// the role be used with `format!`/`write!` without an intermediate `String`.
impl std::fmt::Display for Role {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
        };
        f.write_str(name)
    }
}
/// Parses a role from its lowercase wire name.
///
/// # Errors
///
/// Returns an error with the message `"Invalid Role"` for any input other than exactly
/// `"user"`, `"assistant"` or `"system"` (the comparison is case-sensitive).
impl TryFrom<&str> for Role {
    type Error = Box<dyn Error>;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let role = match value {
            "system" => Role::System,
            "assistant" => Role::Assistant,
            "user" => Role::User,
            _ => return Err(Box::<dyn Error>::from("Invalid Role")),
        };
        Ok(role)
    }
}
/// Configures and creates a [`Chat`] session.
///
/// Start with `ChatBuilder::new`, chain the setters, and finish with `build`.
pub struct ChatBuilder {
    // System message placed in front of every request.
    system: ChatMessage,
    // Optional request parameters forwarded with every request.
    chat_parameters: ChatParameters,
    // Credential handed to the request builder.
    api_key: String,
    // Model the session talks to.
    model: crate::ChatModel,
    // Number of previous question/answer exchanges sent as context (default 5).
    len: usize,
}
impl ChatBuilder {
pub fn new(model: crate::ChatModel, api_key: String) -> Self {
let default_msg = ChatMessage {
role: Role::System,
..Default::default()
};
ChatBuilder {
model,
api_key,
system: default_msg,
chat_parameters: ChatParameters::default(),
len: 5,
}
}
pub fn len(mut self, len: usize) -> Self {
self.len = len;
self
}
pub fn system(mut self, system: ChatMessage) -> Self {
self.system = system;
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.chat_parameters.temperature = Some(temperature);
self
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.chat_parameters.max_tokens = Some(max_tokens);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.chat_parameters.top_p = Some(top_p);
self
}
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
self.chat_parameters.presence_penalty = Some(presence_penalty);
self
}
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
self.chat_parameters.frequency_penalty = Some(frequency_penalty);
self
}
pub fn user(mut self, user: String) -> Self {
self.chat_parameters.user = Some(user);
self
}
pub fn build(self) -> Chat {
Chat::new(
self.system,
self.model,
self.len,
self.api_key,
self.chat_parameters,
)
}
}
/// Optional request parameters forwarded with every request a [`Chat`] sends.
///
/// Every field is `None` by default and is left out of the serialized output while unset.
#[doc(hidden)]
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct ChatParameters {
    /// Sent as `temperature` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Sent as `max_tokens` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sent as `top_p` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Sent as `presence_penalty` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Sent as `frequency_penalty` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Sent as `user` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
/// A chat session that keeps a bounded conversation history.
///
/// Questions are queued with `ask` and sent one at a time by `get_response`.
pub struct Chat {
    // System message placed in front of every request (not stored in `messages`).
    system: ChatMessage,
    // Optional request parameters forwarded with every request.
    chat_parameters: ChatParameters,
    // Credential handed to the request builder.
    api_key: String,
    // Model the session talks to.
    model: crate::ChatModel,
    // Cap on the number of stored history messages (`2 * len + 2`, derived in `Chat::new`).
    len: usize,
    // Stored history, oldest first: alternating user questions and replies.
    messages: Mutex<VecDeque<ChatMessage>>,
    // User messages queued by `ask` and consumed front-first by `get_response`.
    message_queue: Mutex<VecDeque<ChatMessage>>,
}
impl Chat {
    /// Creates a chat session; `ChatBuilder::build` is the public way to get here.
    ///
    /// `len` is the number of previous question/answer exchanges sent as context with each
    /// new question. It is stored as a cap on the kept history (`2 * len + 2` messages; see
    /// `get_response` for how the cap is applied). The arithmetic saturates so that an
    /// oversized `len` cannot overflow.
    fn new<T: ToString>(
        system: ChatMessage,
        model: crate::ChatModel,
        len: usize,
        api_key: T,
        chat_parameters: ChatParameters,
    ) -> Self {
        Self {
            system,
            chat_parameters,
            api_key: api_key.to_string(),
            model,
            len: len.saturating_mul(2).saturating_add(2),
            messages: Mutex::new(VecDeque::new()),
            message_queue: Mutex::new(VecDeque::new()),
        }
    }

    /// Returns a copy of the conversation.
    ///
    /// The system message comes first, followed by the stored history from oldest to newest.
    pub async fn get_messages(&self) -> Vec<ChatMessage> {
        let history = self.messages.lock().await;
        let mut out = Vec::with_capacity(history.len() + 1);
        out.push(self.system.clone());
        out.extend(history.iter().cloned());
        out
    }

    /// Queues `message` as a user message to be sent by the next `get_response` call.
    ///
    /// Currently always returns `Ok(())`.
    pub async fn ask(&self, message: &str) -> Result<(), Box<dyn Error>> {
        let msg = ChatMessage {
            role: Role::User,
            content: Some(message.to_string()),
        };
        self.message_queue.lock().await.push_back(msg);
        Ok(())
    }

    /// Sends the oldest queued user message and returns the first choice of the reply.
    ///
    /// The request contains the system message, the stored history and the queued message.
    /// Once the history has reached its cap, the oldest exchange (two messages) is left out.
    /// `user`, when `Some`, is passed to the request builder's `user` setter after the
    /// session's parameters are applied.
    ///
    /// The stored history is only modified after a reply has been received. On success, the
    /// oldest exchange is evicted (if the cap was reached), and then the user message and the
    /// reply are appended. On failure, the history is left exactly as it was.
    ///
    /// # Errors
    ///
    /// - `"No message to send"` if nothing has been queued with `ask`.
    /// - Any error returned while sending the request.
    /// - `"Response contained no choices"` if the reply's `choices` list is empty.
    ///
    /// In the last two cases the dequeued user message is consumed and not re-queued.
    pub async fn get_response(&self, user: Option<String>) -> Result<ChatMessage, Box<dyn Error>> {
        let msg = self
            .message_queue
            .lock()
            .await
            .pop_front()
            .ok_or("No message to send")?;

        // The history lock is held until the reply is recorded, so a concurrent call waits
        // here instead of interleaving its exchange with this one.
        let mut messages = self.messages.lock().await;

        // Build the outgoing list from a trimmed view of the history rather than trimming
        // the history itself, so a failed request cannot lose the oldest exchange.
        let evict = if messages.len() >= self.len { 2 } else { 0 };
        let mut to_send: VecDeque<ChatMessage> = messages.iter().skip(evict).cloned().collect();
        to_send.push_front(self.system.clone());
        to_send.push_back(msg.clone());

        let builder = crate::RequestBuilder::new(self.model.clone(), self.api_key.clone())
            .messages(to_send.into())
            .chat_parameters(self.chat_parameters.clone());
        let builder = match user {
            Some(user) => builder.user(user),
            None => builder,
        };
        let req = builder.build_chat();
        let resp = match req.send().await {
            Ok(resp) => resp,
            Err(e) => return Err(e.into()),
        };

        // Indexing `choices[0]` would panic on an empty list; report it as an error instead.
        let reply = match resp.choices.first() {
            Some(choice) => choice.message.clone(),
            None => return Err("Response contained no choices".into()),
        };

        // Commit: apply the eviction used above, then record the completed exchange.
        for _ in 0..evict {
            messages.pop_front();
        }
        messages.push_back(msg);
        messages.push_back(reply.clone());
        Ok(reply)
    }
}