chatgpt_api 0.3.0

Use ChatGPT easily from Rust. Or from the command line.
Documentation
#![allow(clippy::too_many_lines)]

use std::collections::HashMap;
use std::env;
use std::option::Option;
use std::sync::Arc;
use std::time::Duration;

use chrono::NaiveDateTime;
use eventsource_stream::Eventsource;
use futures::future;
use futures::stream::{Stream, StreamExt};
use lazy_static::lazy_static;
use reqwest::cookie;
use reqwest::header::{self, HeaderMap};
use serde_json::json;
use uuid::Uuid;

use crate::auth::Authenticator;
use crate::config::{Authentication, Config};
use crate::error::GPTError;
use crate::model::{MessageData, Role};

// Base URL that every request in this file is built from. It is initialised
// lazily from the `CHATGPT_BASE_URL` environment variable and falls back to the
// default below when that variable cannot be read.
// Endpoint paths are appended directly (e.g. `format!("{}conversation", *BASE_URL)`),
// so an override has to keep the trailing slash.
lazy_static! {
    static ref BASE_URL: String =
        env::var("CHATGPT_BASE_URL").unwrap_or_else(|_| "https://bypass.duti.tech/api/".to_string());
}

/// Client state for talking to the conversation API behind `BASE_URL`.
///
/// Holds the HTTP session, the request headers produced at login, and the ids
/// that place the next prompt inside a conversation.
pub struct Chatbot {
    /// Configuration supplied to [`Chatbot::new`]; `__login` may replace its
    /// `authentication` value with a session token.
    pub config: Config,
    /// Headers attached to every request; rebuilt by `__refresh_headers`.
    headers: HeaderMap,
    /// HTTP client used for all requests.
    session: reqwest::Client,
    /// Sent as `conversation_id` with the next prompt; `None` until one is set
    /// or received from the server.
    pub conversation_id: Option<Uuid>,
    /// Sent as `parent_message_id` with the next prompt.
    pub parent_id: Option<Uuid>,
    /// Timeout for the prompt request; `ask_stream` uses 360 seconds when unset.
    timeout: Option<Duration>,
    /// Conversation id -> that conversation's `current_node`, filled by
    /// `__map_conversations`.
    conversation_mapping: HashMap<Uuid, Uuid>,
    /// Conversation ids recorded before each prompt; popped by `rollback_conversation`.
    conversation_id_prev_queue: Vec<Option<Uuid>>,
    /// Parent ids recorded before each prompt; popped by `rollback_conversation`.
    parent_id_prev_queue: Vec<Option<Uuid>>,
}

impl Chatbot {
    /// Creates a chatbot from `config` and logs in before returning it.
    ///
    /// An HTTP client with a cookie jar is built (using `config.proxy` when one
    /// is set), all conversation state starts out empty, and `__login` is
    /// awaited to populate the request headers.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built.
    pub async fn new(config: Config) -> Self {
        let jar = Arc::new(cookie::Jar::default());
        let builder = reqwest::Client::builder().cookie_store(true).cookie_provider(jar);

        // Only go through a proxy when the configuration provides one.
        let builder = match config.proxy.clone() {
            Some(proxy) => builder.proxy(proxy),
            None => builder,
        };
        let session = builder.build().unwrap();

        let mut bot = Self {
            config,
            headers: HeaderMap::new(),
            session,
            conversation_id: None,
            parent_id: None,
            timeout: None,
            conversation_mapping: HashMap::new(),
            conversation_id_prev_queue: Vec::new(),
            parent_id_prev_queue: Vec::new(),
        };
        bot.__login().await;
        bot
    }

    /// Builder-style setter: stores `conversation_id` as the conversation the
    /// next prompt is sent to and returns the updated chatbot.
    pub const fn with_conversation_id(mut self, conversation_id: Uuid) -> Self {
        self.conversation_id = Some(conversation_id);
        self
    }

    /// Builder-style setter: stores `parent_id` as the parent message of the
    /// next prompt and returns the updated chatbot.
    ///
    /// Note that `ask_stream` panics when a parent id is set while no
    /// conversation id is, so pair this with `with_conversation_id`.
    pub const fn with_parent_id(mut self, parent_id: Uuid) -> Self {
        self.parent_id = Some(parent_id);
        self
    }

    /// Builder-style setter: stores `timeout` as the timeout applied to the
    /// prompt request in `ask_stream` (which otherwise uses 360 seconds) and
    /// returns the updated chatbot.
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    fn __refresh_headers(&mut self, access_token: &str) {
        self.headers = [
            (header::ACCEPT, "text/event-stream".parse().unwrap()),
            (header::AUTHORIZATION, format!("Bearer {access_token}").parse().unwrap()),
            (header::CONTENT_TYPE, "application/json".parse().unwrap()),
            ("X-Openai-Assistant-App-Id".parse().unwrap(), "".parse().unwrap()),
            (header::CONNECTION, "close".parse().unwrap()),
            (header::ACCEPT_LANGUAGE, "en-US,en;q=0.9".parse().unwrap()),
            (header::REFERER, "https://chat.openai.com/chat".parse().unwrap()),
        ]
        .into_iter()
        .collect();
    }

    /// Obtains an access token according to `config.authentication` and
    /// installs it into the request headers.
    ///
    /// * `AccessToken` - the token is used directly.
    /// * `SessionToken` - an `Authenticator` seeded with the session token
    ///   fetches the access token.
    /// * `EmailPassword` - an `Authenticator` is run with the credentials; the
    ///   resulting session token replaces the stored authentication.
    ///
    /// # Panics
    ///
    /// Panics if the authenticator ends up without the session or access
    /// token that is needed.
    async fn __login(&mut self) {
        match &self.config.authentication {
            Authentication::AccessToken { access_token } => {
                // Copy the token out so `self` can be borrowed mutably below.
                let token = access_token.clone();
                self.__refresh_headers(&token);
            },
            Authentication::SessionToken { session_token } => {
                log::debug!("Using session token");
                let mut authenticator = Authenticator::new(None, None, self.config.proxy.clone());
                authenticator.session_token = Some(session_token.clone());
                authenticator.get_access_token().await;
                let access_token = authenticator.access_token.unwrap();
                self.__refresh_headers(&access_token);
            },
            Authentication::EmailPassword { email, password } => {
                log::debug!("Using authentiator to get access token");
                let mut authenticator = Authenticator::new(Some(email), Some(password), self.config.proxy.clone());
                authenticator.begin().await;
                authenticator.get_access_token().await;
                // Remember the session token instead of the credentials.
                let session_token = authenticator.session_token.unwrap();
                self.config.authentication = Authentication::SessionToken { session_token };
                let access_token = authenticator.access_token.unwrap();
                self.__refresh_headers(&access_token);
            },
        }
    }

    /// Sends `prompt` and returns the last message yielded by [`Chatbot::ask_stream`].
    ///
    /// Reading stops at the end of the stream or at the first error item; the
    /// most recent successful message is returned, or an empty string when
    /// there was none.
    pub async fn ask(&mut self, prompt: &str) -> String {
        let mut chunks = self.ask_stream(prompt).await;
        let mut latest = String::new();
        loop {
            match chunks.next().await {
                Some(Ok(text)) => latest = text,
                // Both exhaustion and an error item end the read.
                _ => break,
            }
        }
        latest
    }

    pub async fn ask_stream(&mut self, prompt: &str) -> impl Stream<Item = Result<String, GPTError>> + '_ {
        if self.parent_id.is_some() && self.conversation_id.is_none() {
            log::debug!("conversation_id must be set once parent_id is set");
            panic!();
        }

        let conversation_id = self.conversation_id;
        let mut parent_id = self.parent_id;
        if conversation_id.is_none() && parent_id.is_none() {
            // New conversation
            parent_id = Some(Uuid::new_v4());
            log::debug!("New conversation, setting parent_id to new UUID4: {:?}", parent_id);
        }

        if conversation_id.is_some() && parent_id.is_none() {
            if !self.conversation_mapping.contains_key(&conversation_id.unwrap()) {
                log::debug!(
                    "Conversation ID {} not found in conversation mapping, mapping conversations",
                    conversation_id.unwrap()
                );
                self.__map_conversations().await;
            }
            log::debug!(
                "Conversation ID {} found in conversation mapping, setting parent_id to {}",
                conversation_id.unwrap(),
                self.conversation_mapping[&conversation_id.unwrap()]
            );
            parent_id = Some(self.conversation_mapping[&conversation_id.unwrap()]);
        }

        let data = json!({
            "action": "next",
            "messages": [
                {
                    "id": uuid::Uuid::new_v4().to_string(),
                    "role": "user",
                    "content": {
                        "content_type": "text",
                        "parts": [prompt]
                    }
                }
            ],
            "conversation_id": conversation_id,
            "parent_message_id": parent_id,
            "model": if self.config.paid {
                "text-davinci-002-render-paid"
            } else {
                "text-davinci-002-render-sha"
            }
        });

        log::debug!("Sending the payload");
        log::debug!("{}", serde_json::to_string_pretty(&data).unwrap());
        self.conversation_id_prev_queue
            .push(serde_json::from_value(data["conversation_id"].clone()).unwrap());
        self.parent_id_prev_queue
            .push(serde_json::from_value(data["parent_message_id"].clone()).unwrap());

        let response = self
            .session
            .post(format!("{}conversation", *BASE_URL))
            .timeout(self.timeout.unwrap_or(Duration::from_secs(360)))
            .headers(self.headers.clone())
            .json(&data)
            .send()
            .await
            .unwrap()
            .bytes_stream()
            .eventsource();

        response
            .map(|ev| -> Result<Option<String>, GPTError> {
                // let ev = ev.map_err(GPTError::ConnectionError)?;
                let ev = ev.unwrap();
                match ev.event.as_str() {
                    "open" => {
                        log::debug!("opened");
                        Ok(None)
                    },
                    "message" if ev.data == "[DONE]" => {
                        log::debug!("📨 {}", ev.data);
                        Ok(None)
                    },
                    "message" => {
                        log::debug!("📨 {}", ev.data);
                        match serde_json::from_str::<MessageData>(&ev.data) {
                            Ok(data) => {
                                log::debug!("📨 Chunk: {:?}", data.message.content);

                                let message_data = data.message.content.parts[0].clone();
                                let conversation_id = data.conversation_id;
                                let parent_id = data.message.id;

                                log::debug!("Received message: {:?}", ev);
                                log::debug!("Received conversation_id: {}", conversation_id);
                                log::debug!("Received parent_id: {}", parent_id);

                                self.parent_id = Some(parent_id);
                                self.conversation_id = Some(conversation_id);

                                match data.message.author.role {
                                    Role::User | Role::System => Ok(None),
                                    Role::Assistant => Ok(Some(message_data)),
                                }
                            },
                            Err(_) if NaiveDateTime::parse_from_str(&ev.data, "%Y-%m-%d %H:%M:%S%.f").is_ok() => {
                                log::debug!("📨 Chunk (date): {}", ev.data);
                                Ok(None)
                            },
                            Err(e) => {
                                log::error!("Bad message: {}", ev.data);
                                log::error!("Error: {}", e);
                                Err(GPTError::BadMessage(ev.data))
                            },
                        }
                    },
                    _ => {
                        log::debug!("📨 {}", ev.data);
                        Ok(None)
                    },
                }
            })
            .filter(move |ev| {
                future::ready(match ev {
                    Ok(Some(s)) if s.is_empty() => false,
                    Ok(Some(_)) | Err(_) => true,
                    Ok(None) => false,
                })
            })
            .map(|ev| match ev {
                Ok(Some(message)) => Ok(message),
                Ok(None) => unreachable!(),
                Err(e) => Err(e),
            })
    }

    /// Fetches the conversation list, passing `offset` and `limit` as query
    /// parameters, and returns the entries of the response's `items` array.
    ///
    /// # Panics
    ///
    /// Panics if the request fails, the body is not JSON, or `items` is not an array.
    pub async fn get_conversations(&mut self, offset: i32, limit: i32) -> Vec<serde_json::Value> {
        let endpoint = format!("{}conversations?offset={offset}&limit={limit}", *BASE_URL);
        let request = self.session.get(&endpoint).headers(self.headers.clone());
        let body: serde_json::Value = request.send().await.unwrap().json().await.unwrap();
        body["items"].as_array().unwrap().clone()
    }

    /// Fetches the conversation identified by `convo_id` and returns the
    /// response body as parsed JSON.
    ///
    /// # Panics
    ///
    /// Panics if the request fails or the body is not JSON.
    pub async fn get_msg_history(&mut self, convo_id: &str) -> serde_json::Value {
        let endpoint = format!("{}conversation/{convo_id}", *BASE_URL);
        let request = self.session.get(&endpoint).headers(self.headers.clone());
        let reply = request.send().await.unwrap();
        reply.json::<serde_json::Value>().await.unwrap()
    }

    /// Posts a title-generation request for conversation `convo_id`, naming
    /// `message_id` as the message to use, and logs the response.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be sent.
    pub async fn gen_title(&mut self, convo_id: &str, message_id: &str) {
        let endpoint = format!("{}conversation/gen_title/{convo_id}", *BASE_URL);
        let payload = json!({
            "message_id": message_id,
            "model": "text-davinci-002-render"
        });
        let request = self.session.post(&endpoint).headers(self.headers.clone()).json(&payload);
        let response = request.send().await.unwrap();
        log::debug!("Title generated: {:?}", response);
    }

    /// Patches conversation `convo_id` with the new `title` and logs the response.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be sent.
    pub async fn change_title(&mut self, convo_id: &str, title: &str) {
        let endpoint = format!("{}conversation/{convo_id}", *BASE_URL);
        let payload = json!({ "title": title });
        let request = self.session.patch(&endpoint).headers(self.headers.clone()).json(&payload);
        let response = request.send().await.unwrap();
        log::debug!("Title changed: {:?}", response);
    }

    /// Patches conversation `convo_id` with `is_visible: false` and logs the
    /// response.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be sent.
    pub async fn delete_conversation(&mut self, convo_id: &str) {
        let endpoint = format!("{}conversation/{convo_id}", *BASE_URL);
        let payload = json!({
            "is_visible": false
        });
        let request = self.session.patch(&endpoint).headers(self.headers.clone()).json(&payload);
        let response = request.send().await.unwrap();
        log::debug!("Conversation deleted: {:?}", response);
    }

    /// Patches the `conversations` collection with `is_visible: false` and
    /// logs the response.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be sent.
    pub async fn clear_conversations(&mut self) {
        let endpoint = format!("{}conversations", *BASE_URL);
        let payload = json!({
            "is_visible": false
        });
        let request = self.session.patch(&endpoint).headers(self.headers.clone()).json(&payload);
        let response = request.send().await.unwrap();
        log::debug!("Conversations cleared: {:?}", response);
    }

    async fn __map_conversations(&mut self) {
        let conversations = self.get_conversations(0, 100).await;
        let mut histories = Vec::new();
        for x in &conversations {
            let y = self.get_msg_history(x["id"].as_str().unwrap()).await;
            histories.push(y);
        }
        for (x, y) in conversations.iter().zip(histories.iter()) {
            self.conversation_mapping.insert(
                x["id"].as_str().unwrap().to_string().parse().unwrap(),
                y["current_node"].as_str().unwrap().to_string().parse().unwrap(),
            );
        }
    }

    pub fn reset_chat(&mut self) {
        self.conversation_id = None;
        self.parent_id = Some(Uuid::new_v4());
    }

    /// Undoes the last `num` prompts by restoring the ids recorded before each
    /// of them was sent.
    ///
    /// Each step pops one entry from both history queues; when a queue is
    /// already empty the corresponding id becomes `None`.
    pub fn rollback_conversation(&mut self, num: usize) {
        for _ in 0..num {
            self.conversation_id = self.conversation_id_prev_queue.pop().flatten();
            self.parent_id = self.parent_id_prev_queue.pop().flatten();

            // The first prompt of a new conversation is recorded as
            // (no conversation, generated parent). Restoring that pair verbatim
            // would make the next `ask_stream` panic, since it rejects a parent
            // id without a conversation id; the state before that prompt had
            // neither id set, so restore exactly that.
            if self.conversation_id.is_none() {
                self.parent_id = None;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use chrono::NaiveDateTime;

    /// The timestamp-only payloads skipped by `ask_stream` must parse with the
    /// format string used there.
    #[test]
    fn test_datetime_messsage() {
        let sample = "2023-02-28 15:51:57.358127";
        let parsed = NaiveDateTime::parse_from_str(sample, "%Y-%m-%d %H:%M:%S%.f");
        assert!(parsed.is_ok());
    }
}