#![allow(clippy::too_many_lines)]
use std::collections::HashMap;
use std::env;
use std::option::Option;
use std::sync::Arc;
use std::time::Duration;
use chrono::NaiveDateTime;
use eventsource_stream::Eventsource;
use futures::future;
use futures::stream::{Stream, StreamExt};
use lazy_static::lazy_static;
use reqwest::cookie;
use reqwest::header::{self, HeaderMap};
use serde_json::json;
use uuid::Uuid;
use crate::auth::Authenticator;
use crate::config::{Authentication, Config};
use crate::error::GPTError;
use crate::model::{MessageData, Role};
lazy_static! {
    /// Root URL of the ChatGPT backend, guaranteed to end with `/`.
    ///
    /// Read once from the `CHATGPT_BASE_URL` environment variable. When the variable is
    /// unset, not valid Unicode, or empty, it falls back to `https://bypass.duti.tech/api/`.
    /// Request URLs are built by appending a path directly
    /// (e.g. `format!("{}conversation", *BASE_URL)`). A missing trailing slash is therefore
    /// added here; otherwise those URLs would come out as `.../apiconversation`.
    static ref BASE_URL: String = {
        let mut base_url = env::var("CHATGPT_BASE_URL")
            .ok()
            .filter(|url| !url.is_empty())
            .unwrap_or_else(|| "https://bypass.duti.tech/api/".to_string());
        if !base_url.ends_with('/') {
            base_url.push('/');
        }
        base_url
    };
}
/// Client for the ChatGPT conversation backend.
///
/// Holds the authenticated HTTP session plus the local conversation state: which
/// conversation is being continued and which message the next prompt replies to.
pub struct Chatbot {
    /// Configuration passed to `Chatbot::new`. After an email/password login, `__login`
    /// replaces the authentication with the resulting session token.
    pub config: Config,
    /// Default headers (including `Authorization: Bearer …`) sent with every request;
    /// rebuilt by `__refresh_headers`.
    headers: HeaderMap,
    /// HTTP client with a cookie store, routed through `config.proxy` when one is set.
    session: reqwest::Client,
    /// Conversation to continue; `None` means the next prompt starts a new one.
    /// Updated from the ids received while a response streams in.
    pub conversation_id: Option<Uuid>,
    /// Message the next prompt replies to. Updated from the ids received while a
    /// response streams in. `ask_stream` panics if this is set while
    /// `conversation_id` is `None`.
    pub parent_id: Option<Uuid>,
    /// Per-request timeout for `ask_stream`; `None` means 360 seconds.
    timeout: Option<Duration>,
    /// Cache of conversation id -> that conversation's `current_node` message id,
    /// filled by `__map_conversations`.
    conversation_mapping: HashMap<Uuid, Uuid>,
    /// One entry per `ask_stream` call: the conversation id the prompt was sent with.
    /// Popped by `rollback_conversation`.
    conversation_id_prev_queue: Vec<Option<Uuid>>,
    /// One entry per `ask_stream` call, parallel to `conversation_id_prev_queue`: the
    /// parent message id to restore on rollback.
    parent_id_prev_queue: Vec<Option<Uuid>>,
}
impl Chatbot {
pub async fn new(config: Config) -> Self {
let cookie_store = Arc::new(cookie::Jar::default());
let mut client = reqwest::Client::builder()
.cookie_store(true)
.cookie_provider(cookie_store);
if let Some(proxy) = config.proxy.clone() {
client = client.proxy(proxy);
}
let session = client.build().unwrap();
let mut chatbot = Self {
config,
headers: HeaderMap::new(),
session,
conversation_id: None,
parent_id: None,
timeout: None,
conversation_mapping: HashMap::new(),
conversation_id_prev_queue: Vec::new(),
parent_id_prev_queue: Vec::new(),
};
chatbot.__login().await;
chatbot
}
pub const fn with_conversation_id(mut self, conversation_id: Uuid) -> Self {
self.conversation_id = Some(conversation_id);
self
}
pub const fn with_parent_id(mut self, parent_id: Uuid) -> Self {
self.parent_id = Some(parent_id);
self
}
pub const fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
fn __refresh_headers(&mut self, access_token: &str) {
self.headers = [
(header::ACCEPT, "text/event-stream".parse().unwrap()),
(header::AUTHORIZATION, format!("Bearer {access_token}").parse().unwrap()),
(header::CONTENT_TYPE, "application/json".parse().unwrap()),
("X-Openai-Assistant-App-Id".parse().unwrap(), "".parse().unwrap()),
(header::CONNECTION, "close".parse().unwrap()),
(header::ACCEPT_LANGUAGE, "en-US,en;q=0.9".parse().unwrap()),
(header::REFERER, "https://chat.openai.com/chat".parse().unwrap()),
]
.into_iter()
.collect();
}
async fn __login(&mut self) {
match &self.config.authentication {
Authentication::EmailPassword { email, password } => {
log::debug!("Using authentiator to get access token");
let mut auth = Authenticator::new(Some(email), Some(password), self.config.proxy.clone());
auth.begin().await;
auth.get_access_token().await;
self.config.authentication = Authentication::SessionToken {
session_token: auth.session_token.unwrap(),
};
self.__refresh_headers(&auth.access_token.unwrap());
},
Authentication::SessionToken { session_token } => {
log::debug!("Using session token");
let mut auth = Authenticator::new(None, None, self.config.proxy.clone());
auth.session_token = Some(session_token.clone());
auth.get_access_token().await;
self.__refresh_headers(&auth.access_token.unwrap());
},
Authentication::AccessToken { access_token } => {
self.__refresh_headers(&access_token.clone());
},
}
}
pub async fn ask(&mut self, prompt: &str) -> String {
let mut stream = self.ask_stream(prompt).await;
let mut response = String::new();
while let Some(Ok(message)) = stream.next().await {
response = message;
}
response
}
pub async fn ask_stream(&mut self, prompt: &str) -> impl Stream<Item = Result<String, GPTError>> + '_ {
if self.parent_id.is_some() && self.conversation_id.is_none() {
log::debug!("conversation_id must be set once parent_id is set");
panic!();
}
let conversation_id = self.conversation_id;
let mut parent_id = self.parent_id;
if conversation_id.is_none() && parent_id.is_none() {
parent_id = Some(Uuid::new_v4());
log::debug!("New conversation, setting parent_id to new UUID4: {:?}", parent_id);
}
if conversation_id.is_some() && parent_id.is_none() {
if !self.conversation_mapping.contains_key(&conversation_id.unwrap()) {
log::debug!(
"Conversation ID {} not found in conversation mapping, mapping conversations",
conversation_id.unwrap()
);
self.__map_conversations().await;
}
log::debug!(
"Conversation ID {} found in conversation mapping, setting parent_id to {}",
conversation_id.unwrap(),
self.conversation_mapping[&conversation_id.unwrap()]
);
parent_id = Some(self.conversation_mapping[&conversation_id.unwrap()]);
}
let data = json!({
"action": "next",
"messages": [
{
"id": uuid::Uuid::new_v4().to_string(),
"role": "user",
"content": {
"content_type": "text",
"parts": [prompt]
}
}
],
"conversation_id": conversation_id,
"parent_message_id": parent_id,
"model": if self.config.paid {
"text-davinci-002-render-paid"
} else {
"text-davinci-002-render-sha"
}
});
log::debug!("Sending the payload");
log::debug!("{}", serde_json::to_string_pretty(&data).unwrap());
self.conversation_id_prev_queue
.push(serde_json::from_value(data["conversation_id"].clone()).unwrap());
self.parent_id_prev_queue
.push(serde_json::from_value(data["parent_message_id"].clone()).unwrap());
let response = self
.session
.post(format!("{}conversation", *BASE_URL))
.timeout(self.timeout.unwrap_or(Duration::from_secs(360)))
.headers(self.headers.clone())
.json(&data)
.send()
.await
.unwrap()
.bytes_stream()
.eventsource();
response
.map(|ev| -> Result<Option<String>, GPTError> {
let ev = ev.unwrap();
match ev.event.as_str() {
"open" => {
log::debug!("opened");
Ok(None)
},
"message" if ev.data == "[DONE]" => {
log::debug!("📨 {}", ev.data);
Ok(None)
},
"message" => {
log::debug!("📨 {}", ev.data);
match serde_json::from_str::<MessageData>(&ev.data) {
Ok(data) => {
log::debug!("📨 Chunk: {:?}", data.message.content);
let message_data = data.message.content.parts[0].clone();
let conversation_id = data.conversation_id;
let parent_id = data.message.id;
log::debug!("Received message: {:?}", ev);
log::debug!("Received conversation_id: {}", conversation_id);
log::debug!("Received parent_id: {}", parent_id);
self.parent_id = Some(parent_id);
self.conversation_id = Some(conversation_id);
match data.message.author.role {
Role::User | Role::System => Ok(None),
Role::Assistant => Ok(Some(message_data)),
}
},
Err(_) if NaiveDateTime::parse_from_str(&ev.data, "%Y-%m-%d %H:%M:%S%.f").is_ok() => {
log::debug!("📨 Chunk (date): {}", ev.data);
Ok(None)
},
Err(e) => {
log::error!("Bad message: {}", ev.data);
log::error!("Error: {}", e);
Err(GPTError::BadMessage(ev.data))
},
}
},
_ => {
log::debug!("📨 {}", ev.data);
Ok(None)
},
}
})
.filter(move |ev| {
future::ready(match ev {
Ok(Some(s)) if s.is_empty() => false,
Ok(Some(_)) | Err(_) => true,
Ok(None) => false,
})
})
.map(|ev| match ev {
Ok(Some(message)) => Ok(message),
Ok(None) => unreachable!(),
Err(e) => Err(e),
})
}
pub async fn get_conversations(&mut self, offset: i32, limit: i32) -> Vec<serde_json::Value> {
let url = format!("{}conversations?offset={}&limit={}", *BASE_URL, offset, limit);
let response = self
.session
.get(&url)
.headers(self.headers.clone())
.send()
.await
.unwrap();
let data: serde_json::Value = response.json().await.unwrap();
let items: Vec<serde_json::Value> = data["items"].as_array().unwrap().clone();
items
}
pub async fn get_msg_history(&mut self, convo_id: &str) -> serde_json::Value {
let url = format!("{}conversation/{}", *BASE_URL, convo_id);
let response = self
.session
.get(&url)
.headers(self.headers.clone())
.send()
.await
.unwrap();
let data: serde_json::Value = response.json().await.unwrap();
data
}
pub async fn gen_title(&mut self, convo_id: &str, message_id: &str) {
let url = format!("{}conversation/gen_title/{}", *BASE_URL, convo_id);
let response = self
.session
.post(&url)
.headers(self.headers.clone())
.json(&json!({
"message_id": message_id,
"model": "text-davinci-002-render"
}))
.send()
.await
.unwrap();
log::debug!("Title generated: {:?}", response);
}
pub async fn change_title(&mut self, convo_id: &str, title: &str) {
let url = format!("{}conversation/{}", *BASE_URL, convo_id);
let response = self
.session
.patch(&url)
.headers(self.headers.clone())
.json(&json!({ "title": title }))
.send()
.await
.unwrap();
log::debug!("Title changed: {:?}", response);
}
pub async fn delete_conversation(&mut self, convo_id: &str) {
let url = format!("{}conversation/{}", *BASE_URL, convo_id);
let response = self
.session
.patch(&url)
.headers(self.headers.clone())
.json(&json!({
"is_visible": false
}))
.send()
.await
.unwrap();
log::debug!("Conversation deleted: {:?}", response);
}
pub async fn clear_conversations(&mut self) {
let url = format!("{}conversations", *BASE_URL);
let response = self
.session
.patch(&url)
.headers(self.headers.clone())
.json(&json!({
"is_visible": false
}))
.send()
.await
.unwrap();
log::debug!("Conversations cleared: {:?}", response);
}
async fn __map_conversations(&mut self) {
let conversations = self.get_conversations(0, 100).await;
let mut histories = Vec::new();
for x in &conversations {
let y = self.get_msg_history(x["id"].as_str().unwrap()).await;
histories.push(y);
}
for (x, y) in conversations.iter().zip(histories.iter()) {
self.conversation_mapping.insert(
x["id"].as_str().unwrap().to_string().parse().unwrap(),
y["current_node"].as_str().unwrap().to_string().parse().unwrap(),
);
}
}
pub fn reset_chat(&mut self) {
self.conversation_id = None;
self.parent_id = Some(Uuid::new_v4());
}
pub fn rollback_conversation(&mut self, num: usize) {
for _ in 0..num {
self.conversation_id = self.conversation_id_prev_queue.pop().unwrap_or(None);
self.parent_id = self.parent_id_prev_queue.pop().unwrap_or(None);
}
}
}
#[cfg(test)]
mod tests {
    use chrono::NaiveDateTime;

    /// The event stream may carry bare timestamp payloads. `ask_stream` recognises and
    /// skips them by parsing with this format, so a typical sample must be accepted.
    #[test]
    fn test_datetime_messsage() {
        const SAMPLE: &str = "2023-02-28 15:51:57.358127";
        let parsed = NaiveDateTime::parse_from_str(SAMPLE, "%Y-%m-%d %H:%M:%S%.f");
        assert!(parsed.is_ok(), "failed to parse {SAMPLE:?}: {parsed:?}");
    }
}