use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::Client;
use std::env;
use serde_derive::{Deserialize, Serialize};
use crate::common::*;
use crate::gpt::GptMessage as DeepseekMessage;
use crate::functions::*;
/// Request body for a Deepseek completion call.
///
/// `call_deepseek_completion` serializes this struct to JSON and POSTs it.
#[derive(Debug, Serialize, Clone)]
pub struct DeepseekCompletion {
    /// Name of the model to use.
    pub model: String,
    /// Tool (function) definitions offered to the model; omitted from the
    /// serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<FunctionCall>>,
    /// Messages making up the conversation, in order.
    pub messages: Vec<DeepseekMessage>,
    /// `temperature` value forwarded in the request.
    pub temperature: f32,
    /// `max_tokens` value forwarded in the request.
    pub max_tokens: usize,
}
impl DeepseekCompletion {
pub fn new(messages: Vec<DeepseekMessage>, temperature: f32, max_tokens: usize, _is_json: bool) -> Self {
let model: String = env::var("DEEPSEEK_MODEL").expect("DEEPSEEK_MODEL not found in enviroment variables");
DeepseekCompletion {
model,
tools: None,
messages,
temperature,
max_tokens,
}
}
pub fn set_model(&mut self, model: &str) {
self.model = model.into();
}
pub fn set_tools(&mut self, tools: Option<Vec<FunctionCall>>) {
self.tools = tools;
}
pub fn set_max_tokens(&mut self, max_tokens: usize) {
self.max_tokens = max_tokens;
}
pub fn add_message(&mut self, message: &DeepseekMessage) {
self.messages.push(message.clone());
}
pub fn add_messages(&mut self, messages: &[DeepseekMessage]) {
messages.iter().for_each(|m| self.messages.push(m.clone()));
}
}
impl Default for DeepseekCompletion {
fn default() -> Self {
let model: String = env::var("DEEPSEEK_MODEL").expect("DEEPSEEK_MODEL not found in enviroment variables");
DeepseekCompletion {
model,
tools: None,
messages: Vec::new(),
temperature: 0.2,
max_tokens: 4096
}
}
}
impl LlmCompletion for DeepseekCompletion {
    /// Replaces the temperature setting.
    fn set_temperature(&mut self, temperature: f32) {
        self.temperature = temperature;
    }

    /// Appends the message built by `DeepseekMessage::text(role, text)`.
    fn add_text(&mut self, role: &str, text: &str) {
        let message = DeepseekMessage::text(role, text);
        self.messages.push(message);
    }

    /// Appends the single message built by `DeepseekMessage::many_text(role, texts)`.
    fn add_many_text(&mut self, role: &str, texts: &[String]) {
        let message = DeepseekMessage::many_text(role, texts);
        self.messages.push(message);
    }

    /// Appends the messages built by `DeepseekMessage::system(system_prompt)`.
    fn add_system(&mut self, system_prompt: &str) {
        let mut system_messages = DeepseekMessage::system(system_prompt);
        self.messages.append(&mut system_messages);
    }

    /// Appends the messages built by `DeepseekMessage::multi_part_system(system_prompts)`.
    fn add_multi_part_system(&mut self, system_prompts: &[String]) {
        let mut system_messages = DeepseekMessage::multi_part_system(system_prompts);
        self.messages.append(&mut system_messages);
    }

    /// Appends the messages built by `DeepseekMessage::systems(system_prompts)`.
    fn add_systems(&mut self, system_prompts: &[String]) {
        let mut system_messages = DeepseekMessage::systems(system_prompts);
        self.messages.append(&mut system_messages);
    }

    /// Discards the current conversation and replaces it with the result of
    /// `DeepseekMessage::dialogue(prompts, has_system)`.
    fn dialogue(&mut self, prompts: &[String], has_system: bool) {
        self.messages = DeepseekMessage::dialogue(prompts, has_system);
    }

    /// Keeps only the first `len` messages.
    fn truncate_messages(&mut self, len: usize) {
        self.messages.truncate(len);
    }

    /// Returns the `Debug` rendering of this request.
    fn debug(&self) -> String where Self: std::fmt::Debug {
        format!("{:?}", self)
    }

    /// Calls the model named by the `DEEPSEEK_MODEL` environment variable.
    ///
    /// See `call_model_function` for how the arguments are interpreted.
    ///
    /// # Panics
    ///
    /// Panics if `DEEPSEEK_MODEL` cannot be read from the environment.
    async fn call(system: &str, user: &[String], temperature: f32, _is_json: bool, is_chat: bool) -> Result<LlmReturn, Box<dyn std::error::Error + Send>> {
        let model = env::var("DEEPSEEK_MODEL")
            .expect("DEEPSEEK_MODEL not found in enviroment variables");
        Self::call_model(model.as_str(), system, user, temperature, _is_json, is_chat).await
    }

    /// Calls `model` without offering any functions to it.
    async fn call_model(model: &str, system: &str, user: &[String], temperature: f32, _is_json: bool, is_chat: bool) -> Result<LlmReturn, Box<dyn std::error::Error + Send>> {
        let no_functions = None;
        Self::call_model_function(model, system, user, temperature, _is_json, is_chat, no_functions).await
    }

    /// Builds a request from plain strings and sends it via
    /// `call_deepseek_completion`.
    ///
    /// * `system` - becomes a leading `system` message unless it is empty.
    /// * `user` - one message per entry. When `is_chat` is true the roles
    ///   alternate, starting with `user` (even indices) followed by
    ///   `assistant` (odd indices); otherwise every entry is a `user` message.
    /// * `function` - converted with `FunctionCall::functions`; the request
    ///   carries no `tools` when the conversion yields an empty list.
    /// * `_is_json` - not used here.
    ///
    /// `max_tokens` is fixed at 4096.
    async fn call_model_function(model: &str, system: &str, user: &[String], temperature: f32, _is_json: bool, is_chat: bool, function: Option<Vec<Function>>) -> Result<LlmReturn, Box<dyn std::error::Error + Send>> {
        let mut messages = Vec::new();
        if !system.is_empty() {
            messages.push(DeepseekMessage { role: "system".into(), content: system.into() });
        }
        for (index, content) in user.iter().enumerate() {
            // Only odd entries of a chat transcript are the assistant's turns.
            let assistant_turn = is_chat && index % 2 == 1;
            let role = if assistant_turn { "assistant" } else { "user" };
            messages.push(DeepseekMessage { role: role.into(), content: content.clone() });
        }

        let tools = FunctionCall::functions(function);
        let request = DeepseekCompletion {
            model: model.into(),
            tools: (!tools.is_empty()).then_some(tools),
            messages,
            temperature,
            max_tokens: 4096,
        };
        call_deepseek_completion(&request).await
    }
}
/// The fields of a successful Deepseek reply that this module reads.
///
/// `call_deepseek_completion` deserializes the reply body into this struct.
#[derive(Debug, Deserialize)]
pub struct DeepseekResponse {
    /// `id` field of the reply.
    pub id: String,
    /// `created` field of the reply (presumably a Unix timestamp; confirm
    /// against the API documentation).
    pub created: usize,
    /// `model` field of the reply.
    pub model: String,
    /// Candidate completions; may be absent.
    pub choices: Option<Vec<DeepseekChoice>>,
    /// Token accounting for the call.
    pub usage: Usage,
}
/// One entry of the `choices` array in a [`DeepseekResponse`].
#[derive(Debug, Deserialize)]
pub struct DeepseekChoice {
    /// The message carried by this choice.
    pub message: DeepseekMessage,
    /// `finish_reason` field of the choice, exactly as received.
    pub finish_reason: String,
}
/// Token counts taken from the `usage` object of a reply.
#[derive(Debug, Deserialize, Clone)]
pub struct Usage {
    /// `prompt_tokens` field.
    pub prompt_tokens: usize,
    /// `completion_tokens` field.
    pub completion_tokens: usize,
    /// `total_tokens` field.
    pub total_tokens: usize,
}
impl Usage {
pub fn new() -> Self {
Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }
}
pub fn to_triple(&self) -> (usize, usize, usize) {
(self.prompt_tokens, self.completion_tokens, self.total_tokens)
}
}
impl Default for Usage {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for Usage {
    /// Writes the counters as `<prompt> + <completion> = <total>`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let (prompt, completion, total) = self.to_triple();
        write!(f, "{} + {} = {}", prompt, completion, total)
    }
}
/// Sends `messages` with temperature `0.2` and `max_tokens` of `4096`.
///
/// Thin wrapper around `call_deepseek_all`; errors and panics are the same.
pub async fn call_deepseek(messages: Vec<DeepseekMessage>) -> Result<LlmReturn, Box<dyn std::error::Error + Send>> {
    const TEMPERATURE: f32 = 0.2;
    const MAX_TOKENS: usize = 4096;
    call_deepseek_all(messages, TEMPERATURE, MAX_TOKENS).await
}
/// Sends `messages` with the given `temperature` and `max_tokens` of `4096`.
///
/// Thin wrapper around `call_deepseek_all`; errors and panics are the same.
pub async fn call_deepseek_temperature(messages: Vec<DeepseekMessage>, temperature: f32) -> Result<LlmReturn, Box<dyn std::error::Error + Send>> {
    const MAX_TOKENS: usize = 4096;
    call_deepseek_all(messages, temperature, MAX_TOKENS).await
}
/// Sends `messages` with temperature `0.2` and the given `max_tokens`.
///
/// Thin wrapper around `call_deepseek_all`; errors and panics are the same.
pub async fn call_deepseek_max_tokens(messages: Vec<DeepseekMessage>, max_tokens: usize) -> Result<LlmReturn, Box<dyn std::error::Error + Send>> {
    const TEMPERATURE: f32 = 0.2;
    call_deepseek_all(messages, TEMPERATURE, max_tokens).await
}
/// Builds a tool-less request from `messages`, `temperature` and `max_tokens`
/// and sends it with `call_deepseek_completion`.
///
/// # Panics
///
/// Panics if `DEEPSEEK_MODEL` cannot be read from the environment (it is read
/// by `DeepseekCompletion::new`).
pub async fn call_deepseek_all(messages: Vec<DeepseekMessage>, temperature: f32, max_tokens: usize) -> Result<LlmReturn, Box<dyn std::error::Error + Send>> {
    // The trailing `false` is the constructor's unused `_is_json` flag.
    let request = DeepseekCompletion::new(messages, temperature, max_tokens, false);
    call_deepseek_completion(&request).await
}
/// Sends `deepseek_completion` to the Deepseek API and converts the reply into
/// an `LlmReturn`.
///
/// The request is POSTed as JSON to the URL held in the `DEEPSEEK_URL`
/// environment variable, using the client from `get_deepseek_client`. The
/// reply body is then classified with plain substring checks, in this order:
///
/// 1. Contains `"error"`: returned as `LlmType::DEEPSEEK_ERROR` with zero
///    token counts. The text is the decoded `LlmError` message when the body
///    has that shape, otherwise the raw body.
/// 2. Contains `"arguments":`: treated as a tool call and returned as
///    `LlmType::DEEPSEEK_TOOLS`, carrying the JSON-serialized function calls,
///    the finish reason and the token counts extracted from the body.
/// 3. Anything else: decoded as a `DeepseekResponse` and returned as
///    `LlmType::DEEPSEEK`. Only the first choice is used. Lines starting with
///    a Markdown code fence are dropped, every remaining line is terminated
///    with `\n`, and the finish reason is upper-cased. A missing or empty
///    `choices` array yields the text `None` with finish reason `ERROR`.
///
/// # Errors
///
/// Returns `Err` when the client cannot be built, when the HTTP exchange
/// fails, or when a non-error reply cannot be decoded (invalid JSON, or
/// missing / non-numeric usage fields in a tool-call reply).
///
/// # Panics
///
/// Panics if `DEEPSEEK_URL` cannot be read from the environment.
pub async fn call_deepseek_completion(deepseek_completion: &DeepseekCompletion) -> Result<LlmReturn, Box<dyn std::error::Error + Send>> {
    // Boxes any sendable error into the error type this function returns.
    fn boxed<E: std::error::Error + Send + 'static>(e: E) -> Box<dyn std::error::Error + Send> {
        Box::new(e)
    }

    // Error for a reply that is valid JSON but lacks something we need.
    fn invalid_reply(detail: String) -> Box<dyn std::error::Error + Send> {
        Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, detail))
    }

    // Elapsed seconds at millisecond resolution. The duration is sampled once
    // so the whole-second and millisecond parts belong to the same instant.
    fn seconds_since(start: std::time::Instant) -> f64 {
        let elapsed = start.elapsed();
        elapsed.as_secs() as f64 + f64::from(elapsed.subsec_millis()) / 1000.0
    }

    let start = std::time::Instant::now();
    let url: String =
        env::var("DEEPSEEK_URL").expect("DEEPSEEK_URL not found in enviroment variables");
    let client = get_deepseek_client().await?;
    let body = client
        .post(url)
        .json(deepseek_completion)
        .send()
        .await
        .map_err(boxed)?
        .text()
        .await
        .map_err(boxed)?;
    let timing = seconds_since(start);

    // 1. Error replies.
    if body.contains("\"error\"") {
        let message = match serde_json::from_str::<LlmError>(&body) {
            Ok(decoded) => decoded.error.to_string(),
            Err(e) => {
                // Not every error payload has the `LlmError` shape. Hand back
                // the raw body so the API's own message is not hidden behind
                // a decode error.
                eprintln!("Error: {:?}", body);
                eprintln!("Could not decode Deepseek error payload: {e}");
                body.clone()
            }
        };
        return Ok(LlmReturn::new(LlmType::DEEPSEEK_ERROR, message.clone(), message, (0, 0, 0), timing, None, None));
    }

    // 2. Tool-call replies.
    if body.contains("\"arguments\":") {
        // Extraction paths understood by `get_functions`; `${name}` names the
        // key under which the extracted values are stored.
        let found = vec!["choices:message:tool_calls:function:arguments:${args}".to_string(),
            "choices:message:tool_calls:function:name:${func}".to_string(),
            "usage:prompt_tokens:${in}".to_string(),
            "usage:completion_tokens:${out}".to_string(),
            "usage:total_tokens:${total}".to_string(),
            "choices:finish_reason:${finish}".to_string()];
        let json = serde_json::from_str::<serde_json::Value>(&body).map_err(boxed)?;
        let h = get_functions(&json, &found);
        let funcs = unpack_functions(h.clone());
        let function_calls = serde_json::to_string(&funcs).map_err(boxed)?;

        // First extracted value for `key`, or an error when nothing was found.
        let first_value = |key: &str| -> Result<String, Box<dyn std::error::Error + Send>> {
            h.get(key)
                .and_then(|values| values.first())
                .map(|value| value.to_string())
                .ok_or_else(|| invalid_reply(format!("Deepseek tool reply has no value for '{key}'")))
        };
        let token_count = |key: &str| -> Result<usize, Box<dyn std::error::Error + Send>> {
            first_value(key)?.parse::<usize>().map_err(boxed)
        };

        let usage = (token_count("in")?, token_count("out")?, token_count("total")?);
        let finish = first_value("finish")?;
        return Ok(LlmReturn::new(LlmType::DEEPSEEK_TOOLS, function_calls, finish, usage, timing, None, None));
    }

    // 3. Ordinary text replies.
    let parsed = serde_json::from_str::<DeepseekResponse>(&body).map_err(boxed)?;
    let first_choice = parsed.choices.as_deref().and_then(|choices| {
        if choices.len() > 1 {
            eprintln!("There are {:?} choices available now. Code needs to change to reflect this.", choices.len());
        }
        // `first()` rather than `[0]`: an empty array must not panic.
        choices.first()
    });
    let (text, finish_reason) = match first_choice {
        Some(choice) => {
            let mut text = String::new();
            for line in choice.message.content.lines().filter(|l| !l.starts_with("```")) {
                text.push_str(line);
                text.push('\n');
            }
            (text, choice.finish_reason.to_uppercase())
        }
        None => ("None".to_string(), "ERROR".to_string()),
    };
    let usage: Triple = parsed.usage.to_triple();
    // Measured again so the reported time includes decoding the reply.
    let timing = seconds_since(start);
    Ok(LlmReturn::new(LlmType::DEEPSEEK, text, finish_reason, usage, timing, None, None))
}
/// Builds the HTTP client for Deepseek requests.
///
/// The client is obtained from `get_client` with an `Authorization: Bearer
/// <key>` default header, where the key is read from the `DEEPSEEK_API_KEY`
/// environment variable.
///
/// # Errors
///
/// Returns `Err` when the key cannot be used as a header value, or when
/// `get_client` fails.
///
/// # Panics
///
/// Panics if `DEEPSEEK_API_KEY` cannot be read from the environment.
async fn get_deepseek_client() -> Result<Client, Box<dyn std::error::Error + Send>> {
    let api_key: String =
        env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not found in enviroment variables");
    let bearer = match HeaderValue::from_str(&format!("Bearer {}", api_key)) {
        Ok(value) => value,
        Err(e) => return Err(Box::new(e)),
    };

    let mut headers: HeaderMap = HeaderMap::new();
    headers.insert("Authorization", bearer);
    get_client(headers).await
}
// Integration tests: every test here performs a real call and therefore needs
// the DEEPSEEK_MODEL, DEEPSEEK_URL and DEEPSEEK_API_KEY environment variables.
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends `content` with `call_deepseek`, prints the reply and fails the
    /// calling test, with the error in the panic message, if the call fails.
    async fn deepseek(content: Vec<DeepseekMessage>) {
        match call_deepseek(content).await {
            Ok(ret) => println!("{ret}"),
            Err(e) => panic!("{e}"),
        }
    }

    #[tokio::test]
    async fn test_call_deepseek_basic() {
        let messages: Vec<DeepseekMessage> = vec![DeepseekMessage { role: "user".into(), content: "What is the meaning of life?".into() }];
        deepseek(messages).await;
    }

    #[tokio::test]
    async fn test_call_deepseek_citation() {
        let messages =
            vec![DeepseekMessage::text("user", "Give citations for the General theory of Relativity.")];
        deepseek(messages).await;
    }

    #[tokio::test]
    async fn test_call_deepseek_poem() {
        let messages =
            vec![DeepseekMessage::text("user", "Write a creative poem about the interplay of artificial intelligence and the human spirit and provide citations")];
        deepseek(messages).await;
    }

    #[tokio::test]
    async fn test_call_deepseek_logic() {
        let messages =
            vec![DeepseekMessage::text("user", "How many brains does an octopus have, when they have been injured and lost a leg?")];
        deepseek(messages).await;
    }

    // Two-turn chat: the first reply is fed back as the assistant's turn.
    #[tokio::test]
    async fn test_call_deepseek_dialogue() {
        let system = "Use a Scottish accent to answer questions";
        let mut messages =
            vec!["How many brains does an octopus have, when they have been injured and lost a leg?".to_string()];
        let res = DeepseekCompletion::call(system, &messages, 0.2, false, true).await;
        println!("{res:?}");
        messages.push(res.unwrap().to_string());
        messages.push("Is a cuttle fish similar?".to_string());
        let res = DeepseekCompletion::call(system, &messages, 0.2, false, true).await;
        println!("{res:?}");
    }

    #[tokio::test]
    async fn test_call_deepseek_dialogue_model() {
        let model: String = std::env::var("DEEPSEEK_MODEL").expect("DEEPSEEK_MODEL not found in enviroment variables");
        let messages = vec!["Hello".to_string()];
        let res = DeepseekCompletion::call_model(&model, "", &messages, 0.2, false, true).await;
        println!("{res:?}");
    }

    #[tokio::test]
    async fn test_call_function_deepseek() {
        let model: String = std::env::var("DEEPSEEK_MODEL").expect("DEEPSEEK_MODEL not found in enviroment variables");
        let messages = vec!["The answer is (60 * 24) * 365.25".to_string()];
        // The raw string is sent verbatim, so its lines stay unindented.
        let func_def =
r#"
// Derive the value of the arithmetic expression
// expr: An arithmetic expression
fn arithmetic(expr)
"#;
        let functions = get_function_json("deepseek", &[func_def]);
        let res = DeepseekCompletion::call_model_function(&model, "", &messages, 0.2, false, true, functions).await;
        println!("{res:?}");
        let answer = call_actual_function(res.ok());
        println!("{answer:?}");
    }

    #[tokio::test]
    async fn test_call_function_common_deepseek() {
        let messages = vec!["a fruit that is blue with a sour tast".to_string()];
        // The raw strings are sent verbatim, so their lines stay unindented.
        let func_def =
r#"
// Derive the value of the arithmetic expression
// expr: An arithmetic expression
fn arithmetic(expr)
"#;
        let func_def2 =
r#"
// Find the color of an apple and it's taste pass them to this function.
// color: The color of an apple
// taste: The taste of an apple
fn apple(color, taste)
"#;
        let res = call_function_llm("deepseek", &messages, &[func_def, func_def2]).await;
        println!("{res:?}");
        let answer = call_actual_function(res.ok());
        println!("{answer:?}");
    }
}