mod history_message;
extern crate toml;
use std::fs::File;
use std::collections::VecDeque;
use std::error::Error;
use std::io::Read;
use regex::Regex;
use reqwest;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use futures::stream::StreamExt;
// Matches a literal `\uXXXX` escape (a backslash, `u`, then four hex digits) that may be
// left in streamed model text; used by `SSEInvokeModel::convert_unicode_emojis`.
// Compiled once on first use instead of on every call.
lazy_static::lazy_static! {
static ref UNICODE_REGEX: regex::Regex = regex::Regex::new(r"\\u[0-9a-fA-F]{4}").unwrap();
}
/// One model-settings entry from the TOML config file.
///
/// Despite the name, this is configuration, not an API response. `sse_invoke_request_method`
/// reads the first `glm-4` entry and needs every field except `assistant_role`.
#[derive(Serialize, Deserialize, Debug)]
struct AiResponse {
/// Model identifier, sent as `model` in the request body.
language_model: Option<String>,
/// Role name used for the system message.
system_role: Option<String>,
/// System prompt text (trimmed before it is sent).
system_content: Option<String>,
/// Role name used for the user message.
user_role: Option<String>,
/// Not read by the code in this file.
assistant_role: Option<String>,
/// Sent as `max_tokens` in the request body.
max_tokens: Option<f64>,
/// Sent as `temperature` in the request body.
temp_float: Option<f64>,
/// Sent as `top_p` in the request body.
top_p_float: Option<f64>,
}
/// Top-level layout of the TOML config file: one list of settings entries per model family.
///
/// `sse_read_config` selects `ai_config_glm3` for `"glm-3"` and `ai_config_glm4` for `"glm-4"`.
#[derive(Serialize, Deserialize, Debug)]
struct AiConfig {
/// Entries used for the `"glm-3"` family.
ai_config_glm3: Vec<AiResponse>,
/// Entries used for the `"glm-4"` family.
ai_config_glm4: Vec<AiResponse>,
}
fn sse_read_config(file_path: &str, glm: &str) -> Result<String, Box<dyn Error>> {
let mut file = File::open(file_path)?;
let mut file_content = String::new();
file.read_to_string(&mut file_content)?;
let config: AiConfig = toml::from_str(&file_content)?;
let response = match glm {
"glm-3" => config.ai_config_glm3,
"glm-4" => config.ai_config_glm4,
_ => return Err(Box::from("Invalid glm")),
};
let json_string = serde_json::to_string(&response)?;
Ok(json_string)
}
/// Builds the user-message text sent to the model from the stored conversation history.
pub struct MessageProcessor {
/// Handle used to load the history text (`load_history_from_file`).
messages: history_message::HistoryMessage,
}
impl MessageProcessor {
pub fn new() -> Self {
MessageProcessor {
messages: history_message::HistoryMessage::new(),
}
}
pub fn set_input_message(&self) -> Option<String> {
let message = self.messages.load_history_from_file();
if !message.is_empty() {
Some(message)
} else {
None
}
}
pub fn last_messages(&self, role: &str, messages: &str) -> String {
let input_message = self.set_input_message().unwrap_or_default();
let mut input: Value = serde_json::from_str(&input_message).unwrap_or_default();
input["role"] = Value::String(role.to_string());
input["content"] = Value::String(messages.to_string());
let texts = serde_json::to_string(&input).unwrap_or_default();
let regex = Regex::new(r",(\s*})").expect("Failed to create regex pattern");
let user_messages = input_message.clone() + &texts.clone();
let result = regex.replace_all(&user_messages, "");
result.to_string()
}
}
/// State for one streaming (server-sent events) chat request.
#[derive(Debug, Serialize, Deserialize)]
pub struct SSEInvokeModel {
/// Initialized empty in `new`; not otherwise read or written by the code in this file.
get_message: String,
/// Raw SSE response body collected by `sse_invoke_request_method`.
ai_response_data: String,
}
impl SSEInvokeModel {
    /// Byte length of one `\uXXXX` escape as matched by `UNICODE_REGEX`.
    const UNICODE_ESCAPE_LEN: usize = 6;

    /// Creates a model with empty message buffers.
    pub fn new() -> Self {
        SSEInvokeModel {
            get_message: String::new(),
            ai_response_data: String::new(),
        }
    }

    /// Sends `input` to the chat endpoint at `default_url` and returns the assembled reply.
    ///
    /// The request uses the `glm-4` settings from the TOML file at `user_config` and `token`
    /// as Bearer auth. The streamed SSE reply is processed by `process_sse_message`, which
    /// also records a non-empty exchange in the history file.
    ///
    /// # Errors
    /// Propagates any config, HTTP, or stream error from `sse_invoke_request_method`.
    pub async fn sse_request(
        token: String,
        input: String,
        user_config: &str,
        default_url: String,
    ) -> Result<String, Box<dyn Error>> {
        let mut sse_invoke_model = Self::new();
        // The method returns the same raw SSE text it stores in `ai_response_data`.
        let response_message = sse_invoke_model
            .sse_invoke_request_method(token, input.clone(), user_config, default_url)
            .await?;
        Ok(sse_invoke_model.process_sse_message(&response_message, &input))
    }

    /// Builds the JSON body for a streaming chat-completion request.
    ///
    /// The body holds a system message and a user message. The user message content is the
    /// history-augmented text from `MessageProcessor::last_messages`.
    ///
    /// NOTE(review): after serialization, runs of four backslashes are collapsed to two and
    /// then every remaining backslash pair is deleted. As a result, literal backslashes in
    /// the prompt or history (serialized as `\\`) never reach the server. This presumably
    /// undoes double-escaping of the stored history; confirm before changing it.
    async fn generate_sse_json_request_body(
        language_model: &str,
        system_role: &str,
        system_content: &str,
        user_role: &str,
        user_input: &str,
        max_token: f64,
        temp_float: f64,
        top_p_float: f64,
    ) -> Result<String, Box<dyn Error>> {
        let message_process = MessageProcessor::new();
        let messages = json!([
            {"role": system_role, "content": system_content},
            {"role": user_role, "content": message_process.last_messages(user_role, user_input)}
        ]);
        let json_request_body = json!({
            "model": language_model,
            "messages": messages,
            "stream": true,
            "do_sample": true,
            "max_tokens": max_token,
            "temperature": temp_float,
            "top_p": top_p_float
        });
        let json_string = serde_json::to_string(&json_request_body)?;
        let result = json_string
            .replace(r"\\\\", r"\\")
            .replace(r"\\", r"")
            .trim()
            .to_string();
        Ok(result)
    }

    /// POSTs a streaming chat request and collects the raw SSE body.
    ///
    /// The request is built from the first `glm-4` entry of the TOML config at
    /// `user_config`. Collection stops when `data: [DONE]` is seen or the stream ends. The
    /// raw body is stored in `ai_response_data` and also returned.
    ///
    /// # Errors
    /// Returns a message when:
    /// - the config cannot be read, or its first entry lacks a required key or has a
    ///   mistyped one;
    /// - the request fails or the server returns a non-success status;
    /// - the response stream reports an error.
    pub async fn sse_invoke_request_method(
        &mut self,
        token: String,
        user_input: String,
        user_config: &str,
        default_url: String,
    ) -> Result<String, String> {
        // Marker the server sends as its final SSE event.
        const DONE_MARKER: &[u8] = b"data: [DONE]";

        let json_string = sse_read_config(user_config, "glm-4")
            .map_err(|err| format!("Error reading config file: {}", err))?;
        let json_value: Value = serde_json::from_str(&json_string)
            .map_err(|err| format!("Failed to parse Toml to JSON: {}", err))?;

        // Only the first entry is used. Indexing past the end (or into a non-array) yields
        // `Value::Null`, which the lookups below report as a missing key instead of panicking.
        let entry = &json_value[0];
        let text_field = |key: &str, label: &str| -> Result<String, String> {
            entry[key]
                .as_str()
                .map(str::to_string)
                .ok_or_else(|| format!("Failed to get {}", label))
        };
        let number_field = |key: &str, label: &str| -> Result<f64, String> {
            entry[key]
                .as_f64()
                .ok_or_else(|| format!("Failed to get {}", label))
        };
        let language_model = text_field("language_model", "language_model")?;
        let system_role = text_field("system_role", "system_role")?;
        let system_content = text_field("system_content", "system_content")?
            .trim()
            .to_string();
        let user_role = text_field("user_role", "user_role")?;
        let max_token = number_field("max_tokens", "max_token")?;
        let temp_float = number_field("temp_float", "temp_float")?;
        let top_p_float = number_field("top_p_float", "top_p_float")?;

        let json_content = Self::generate_sse_json_request_body(
            &language_model,
            &system_role,
            &system_content,
            &user_role,
            &user_input,
            max_token,
            temp_float,
            top_p_float,
        )
        .await
        .map_err(|err| err.to_string())?;

        let request_result = reqwest::Client::new()
            .post(&default_url)
            .header("Cache-Control", "no-cache")
            .header("Connection", "keep-alive")
            .header("Accept", "text/event-stream")
            .header("Content-Type", "application/json;charset=UTF-8")
            .header("Authorization", format!("Bearer {}", token))
            .body(json_content)
            .send()
            .await
            .map_err(|err| format!("HTTP request failure: {}", err))?;
        if !request_result.status().is_success() {
            return Err(format!("Server returned an error: {}", request_result.status()));
        }

        let mut response_body = request_result.bytes_stream();
        // Collect raw bytes and decode once at the end. Decoding each chunk separately would
        // turn a multi-byte UTF-8 character split across two chunks into U+FFFD.
        let mut raw_body: Vec<u8> = Vec::new();
        while let Some(chunk) = response_body.next().await {
            let bytes = match chunk {
                Ok(bytes) => bytes,
                Err(e) => {
                    self.ai_response_data = String::from_utf8_lossy(&raw_body).into_owned();
                    return Err(format!("Error receiving SSE event: {}", e));
                }
            };
            // Rescan a marker-length overlap so a marker split across chunks is still found,
            // without rescanning the whole buffer after every chunk.
            let scan_from = raw_body.len().saturating_sub(DONE_MARKER.len() - 1);
            raw_body.extend_from_slice(&bytes);
            if raw_body[scan_from..]
                .windows(DONE_MARKER.len())
                .any(|window| window == DONE_MARKER)
            {
                break;
            }
        }
        let sse_data = String::from_utf8_lossy(&raw_body).into_owned();
        self.ai_response_data = sse_data.clone();
        Ok(sse_data)
    }

    /// Assembles the assistant reply from a raw SSE body.
    ///
    /// Takes `choices[0].delta.content` from every line (with a `data: ` prefix stripped)
    /// up to `[DONE]` and concatenates the fragments. Each fragment is cleaned first:
    /// - `\uXXXX` escapes are decoded and double quotes removed;
    /// - literal backslash-`n` sequences become newlines;
    /// - remaining backslashes are removed.
    ///
    /// A non-empty result is appended to the history file after `user_message`. Lines that
    /// are not JSON, or not JSON objects, are reported on stdout and skipped.
    fn process_sse_message(&mut self, response_data: &str, user_message: &str) -> String {
        let mut char_queue = VecDeque::new();
        let payloads = response_data
            .lines()
            .map(|line| line.trim_start_matches("data: "))
            .filter(|line| !line.is_empty());
        for json_message in payloads {
            if json_message.trim() == "[DONE]" {
                break;
            }
            let json_element = match serde_json::from_str::<Value>(json_message) {
                Ok(value) => value,
                Err(_) => {
                    println!("Error reading JSON: {}", json_message);
                    continue;
                }
            };
            let json_response = match json_element.as_object() {
                Some(object) => object,
                None => {
                    println!("Invalid JSON format: {:?}", json_element);
                    continue;
                }
            };
            if let Some(content) = Self::delta_content(json_response) {
                let fragment = self
                    .convert_unicode_emojis(content)
                    .replace("\"", "")
                    .replace("\\n\\n", "\n")
                    .replace("\\nn", "\n")
                    .replace("\\\\n", "\n")
                    .replace("\\\\nn", "\n")
                    .replace("\\", "");
                char_queue.extend(fragment.chars());
            }
        }
        let queue_result: String = char_queue.into_iter().collect();
        if !queue_result.is_empty() {
            let history = history_message::HistoryMessage::new();
            history.add_history_to_file("user", user_message);
            history.add_history_to_file("assistant", queue_result.as_str());
        }
        queue_result
    }

    /// Returns `choices[0].delta.content` when every step has the expected JSON type.
    fn delta_content(json_response: &serde_json::Map<String, Value>) -> Option<&str> {
        json_response
            .get("choices")?
            .as_array()?
            .first()?
            .as_object()?
            .get("delta")?
            .as_object()?
            .get("content")?
            .as_str()
    }

    /// Replaces literal `\uXXXX` escapes in `input` with the characters they encode.
    ///
    /// Back-to-back escapes are decoded together as UTF-16, so a surrogate pair such as
    /// `\uD83D\uDE00` becomes one character. An unpaired surrogate, which encodes no
    /// character, is kept as its original escape text instead of causing a panic.
    fn convert_unicode_emojis(&self, input: &str) -> String {
        let mut output = String::with_capacity(input.len());
        // Code units of the current run of adjacent escapes; the run spans input[run_start..consumed].
        let mut run_units: Vec<u16> = Vec::new();
        let mut run_start = 0;
        // Byte offset just past the last input byte already written to `output` or held in the run.
        let mut consumed = 0;
        for found in UNICODE_REGEX.find_iter(input) {
            if run_units.is_empty() || found.start() != consumed {
                Self::push_escape_run(&mut output, &input[run_start..consumed], &run_units);
                run_units.clear();
                output.push_str(&input[consumed..found.start()]);
                run_start = found.start();
            }
            // The regex guarantees exactly four ASCII hex digits after `\u`.
            let hex = &found.as_str()[2..];
            run_units.push(u16::from_str_radix(hex, 16).expect("Failed to parse Unicode escape"));
            consumed = found.end();
        }
        Self::push_escape_run(&mut output, &input[run_start..consumed], &run_units);
        output.push_str(&input[consumed..]);
        output
    }

    /// Appends the decoded form of `units` to `output`.
    ///
    /// `run_text` must hold one escape per unit. Any unit that is an unpaired surrogate is
    /// written as its escape text copied from `run_text`.
    fn push_escape_run(output: &mut String, run_text: &str, units: &[u16]) {
        let mut unit_index = 0;
        for decoded in char::decode_utf16(units.iter().copied()) {
            match decoded {
                Ok(ch) => {
                    output.push(ch);
                    unit_index += ch.len_utf16();
                }
                Err(_) => {
                    // An unpaired surrogate consumes exactly one unit.
                    let start = unit_index * Self::UNICODE_ESCAPE_LEN;
                    output.push_str(&run_text[start..start + Self::UNICODE_ESCAPE_LEN]);
                    unit_index += 1;
                }
            }
        }
    }
}