use anyhow::{anyhow, Result};
use async_trait::async_trait;
use log::info;
use reqwest::{header, Client};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use crate::completions::ThinkingLevel;
use crate::constants::DEEPSEEK_API_URL;
use crate::domain::{DeepSeekAPICompletionsResponse, RateLimit};
use crate::llm_models::{LLMModel, LLMTools};
use crate::utils::map_to_range_f32;
/// DeepSeek models supported by this provider.
///
/// Each variant is mapped to its API model identifier by the `LLMModel`
/// implementation in this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeepSeekModels {
    /// Sent to the API as `deepseek-chat`.
    DeepSeekChat,
    /// Sent to the API as `deepseek-reasoner`.
    DeepSeekReasoner,
}
#[async_trait(?Send)]
impl LLMModel for DeepSeekModels {
    /// Returns the model identifier placed in the `model` field of the request body.
    fn as_str(&self) -> &str {
        match self {
            DeepSeekModels::DeepSeekChat => "deepseek-chat",
            DeepSeekModels::DeepSeekReasoner => "deepseek-reasoner",
        }
    }

    /// Parses a model identifier, ignoring ASCII/Unicode case.
    ///
    /// Returns `None` for anything other than `deepseek-chat` or `deepseek-reasoner`.
    fn try_from_str(name: &str) -> Option<Self> {
        match name.to_lowercase().as_str() {
            "deepseek-chat" => Some(DeepSeekModels::DeepSeekChat),
            "deepseek-reasoner" => Some(DeepSeekModels::DeepSeekReasoner),
            _ => None,
        }
    }

    /// Default `max_tokens` for each model (currently 8,192 for both).
    ///
    /// Kept as an exhaustive `match` so adding a variant forces choosing a value.
    fn default_max_tokens(&self) -> usize {
        match self {
            DeepSeekModels::DeepSeekChat => 8_192,
            DeepSeekModels::DeepSeekReasoner => 8_192,
        }
    }

    /// Returns the request URL; every model uses `DEEPSEEK_API_URL`.
    fn get_endpoint(&self) -> String {
        DEEPSEEK_API_URL.to_string()
    }

    /// Builds the JSON request body.
    ///
    /// The body holds two messages:
    /// - a `system` message carrying the trait's base instructions;
    /// - a `user` message wrapping `instructions` and the expected output
    ///   `json_schema` in XML-style tags.
    ///
    /// `max_tokens` and `temperature` are copied into the body as-is.
    /// `_tools` and `_thinking_level` are accepted for trait compatibility but are
    /// not used by this implementation.
    fn get_body(
        &self,
        instructions: &str,
        json_schema: &Value,
        function_call: bool,
        max_tokens: &usize,
        temperature: &f32,
        _tools: Option<&[LLMTools]>,
        _thinking_level: Option<&ThinkingLevel>,
    ) -> serde_json::Value {
        let base_instructions = self.get_base_instructions(Some(function_call));
        let system_message = json!({
            "role": "system",
            "content": base_instructions,
        });
        // The trailing `\` escapes drop the newline-following indentation, so the
        // prompt contains no leading whitespace on its inner lines.
        let user_message = json!({
            "role": "user",
            "content": format!(
                "<instructions>\n\
                 {instructions}\n\
                 </instructions>\n\
                 <output json schema>\n\
                 {json_schema}\n\
                 </output json schema>"
            ),
        });
        json!({
            "model": self.as_str(),
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": [system_message, user_message],
        })
    }

    /// POSTs `body` as JSON to the endpoint with bearer-token auth and returns
    /// the raw response text.
    ///
    /// The HTTP status is only logged (when `debug` is set). Non-2xx responses
    /// are still returned as `Ok(text)`, so API errors surface when the caller
    /// parses the body. `_version` and `_tools` are unused.
    ///
    /// # Errors
    /// Returns an error if the request cannot be sent or the response body
    /// cannot be read.
    async fn call_api(
        &self,
        api_key: &str,
        _version: Option<String>,
        body: &serde_json::Value,
        debug: bool,
        _tools: Option<&[LLMTools]>,
    ) -> Result<String> {
        let model_url = self.get_endpoint();
        if debug {
            info!("[debug] DeepSeek API URL: {:#?}", model_url);
        }
        let client = Client::new();
        let response = client
            .post(model_url)
            .header(header::CONTENT_TYPE, "application/json")
            .bearer_auth(api_key)
            .json(body)
            .send()
            .await?;
        let response_status = response.status();
        let response_text = response.text().await?;
        if debug {
            info!(
                "[debug] DeepSeek API response: [{}] {:#?}",
                &response_status, &response_text
            );
        }
        Ok(response_text)
    }

    /// Extracts the content of the first assistant message that actually has
    /// content, passed through `sanitize_json_response`.
    ///
    /// Choices without a message, with a non-assistant role, or with no content
    /// are skipped. An assistant message with empty content therefore no longer
    /// hides a later one that has content.
    ///
    /// # Errors
    /// - If `response_text` does not deserialize into
    ///   `DeepSeekAPICompletionsResponse`.
    /// - If no assistant message with content is found.
    fn get_data(&self, response_text: &str, _function_call: bool) -> Result<String> {
        let completions_response: DeepSeekAPICompletionsResponse =
            serde_json::from_str(response_text)?;
        completions_response
            .choices
            .iter()
            .filter_map(|choice| choice.message.as_ref())
            // Borrowing comparison: avoids allocating "assistant" for every message.
            .filter(|message| message.role.as_deref() == Some("assistant"))
            .find_map(|message| {
                message
                    .content
                    .as_ref()
                    .map(|content| self.sanitize_json_response(content))
            })
            .ok_or_else(|| anyhow!("Assistant role content not found"))
    }

    /// Rate limits reported for DeepSeek.
    ///
    /// Both limits are set to 100,000,000. This is presumably a stand-in for
    /// "effectively unlimited" — NOTE(review): confirm against how `RateLimit`
    /// is consumed.
    fn get_rate_limit(&self) -> RateLimit {
        RateLimit {
            tpm: 100_000_000,
            rpm: 100_000_000,
        }
    }

    /// Converts a relative temperature to an absolute one using
    /// `map_to_range_f32` with bounds 0.0 and 1.5.
    ///
    /// The expected range of `relative_temp` is defined by `map_to_range_f32`
    /// (not visible here).
    fn get_normalized_temperature(&self, relative_temp: u32) -> f32 {
        let min = 0.0f32;
        let max = 1.5f32;
        map_to_range_f32(min, max, relative_temp)
    }
}