use crate::llm::traits::AiProvider;
use crate::llm::types::{ChatCompletionParams, ProviderExchange, ProviderResponse, TokenUsage};
use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::env;
use std::time::{SystemTime, UNIX_EPOCH};
/// Pricing table: (model-name substring, cache-hit input, cache-miss input,
/// output) — all rates in USD per 1M tokens.
const PRICING: &[(&str, f64, f64, f64)] = &[
    ("deepseek-chat", 0.07, 0.27, 1.10),
    ("deepseek-reasoner", 0.14, 0.55, 2.19),
];

/// Returns `(cache_hit, cache_miss, output)` rates for `model`, matched by
/// substring against the pricing table. Unknown models fall back to the
/// deepseek-chat rates.
fn get_model_pricing(model: &str) -> (f64, f64, f64) {
    PRICING
        .iter()
        .find(|(name, _, _, _)| model.contains(*name))
        .map(|&(_, cache_hit, cache_miss, output)| (cache_hit, cache_miss, output))
        .unwrap_or((0.07, 0.27, 1.10))
}

/// Computes the request cost in USD, billing prompt-cache hits at the
/// discounted rate. Always returns `Some`; the `Option` is part of the
/// callers' expected interface.
fn calculate_cost_with_cache(
    model: &str,
    regular_input_tokens: u64,
    cache_hit_tokens: u64,
    completion_tokens: u64,
) -> Option<f64> {
    let (cache_hit_price, cache_miss_price, output_price) = get_model_pricing(model);
    // Rates are per 1M tokens, so scale each count down before multiplying.
    let miss_cost = (regular_input_tokens as f64 / 1_000_000.0) * cache_miss_price;
    let hit_cost = (cache_hit_tokens as f64 / 1_000_000.0) * cache_hit_price;
    let out_cost = (completion_tokens as f64 / 1_000_000.0) * output_price;
    Some(miss_cost + hit_cost + out_cost)
}

/// Cost with no cache hits: every prompt token is billed at the miss rate.
fn calculate_cost(model: &str, prompt_tokens: u64, completion_tokens: u64) -> Option<f64> {
    calculate_cost_with_cache(model, prompt_tokens, 0, completion_tokens)
}
/// AI provider backed by the DeepSeek chat-completions HTTP API.
#[derive(Debug, Clone)]
pub struct DeepSeekProvider {
    // Shared reqwest client, reused across requests (connection pooling).
    client: Client,
}
impl Default for DeepSeekProvider {
    /// Equivalent to [`DeepSeekProvider::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl DeepSeekProvider {
    /// Creates a provider with a fresh HTTP client.
    pub fn new() -> Self {
        Self {
            client: Client::new(),
        }
    }
}
/// Environment variable the API key is read from.
const DEEPSEEK_API_KEY_ENV: &str = "DEEPSEEK_API_KEY";
/// Request body for `POST /chat/completions` (OpenAI-compatible shape).
/// Optional fields are omitted from the JSON entirely when `None`.
#[derive(Serialize, Debug)]
struct DeepSeekRequest {
    model: String,
    messages: Vec<DeepSeekMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    // Arbitrary JSON (e.g. {"type": "json_object"}) to request JSON output.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<serde_json::Value>,
}
/// One chat message; `role` is the OpenAI-style role string
/// (e.g. "system"/"user"/"assistant" — mirrors whatever the caller supplies).
#[derive(Serialize, Deserialize, Debug, Clone)]
struct DeepSeekMessage {
    role: String,
    content: String,
}
/// Top-level chat-completion response. `usage` can be absent, so it is optional.
#[derive(Serialize, Deserialize, Debug)]
struct DeepSeekResponse {
    choices: Vec<DeepSeekChoice>,
    usage: Option<DeepSeekUsage>,
}
/// One completion candidate; only the first choice is consumed downstream.
#[derive(Serialize, Deserialize, Debug)]
struct DeepSeekChoice {
    message: DeepSeekMessage,
    finish_reason: Option<String>,
}
/// Token accounting returned by DeepSeek. The cache and reasoning fields use
/// `#[serde(default)]` so responses that omit them still deserialize (they
/// default to 0 / `None`).
#[derive(Serialize, Deserialize, Debug)]
struct DeepSeekUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
    // Prompt tokens served from DeepSeek's prompt cache (billed cheaper).
    #[serde(default)]
    prompt_cache_hit_tokens: u64,
    #[serde(default)]
    prompt_cache_miss_tokens: u64,
    #[serde(default)]
    completion_tokens_details: Option<DeepSeekCompletionTokensDetails>,
}
/// Breakdown of completion tokens; currently only the reasoning-token count
/// (emitted by deepseek-reasoner) is tracked.
#[derive(Serialize, Deserialize, Debug, Default)]
struct DeepSeekCompletionTokensDetails {
    #[serde(default)]
    reasoning_tokens: u64,
}
#[async_trait::async_trait]
impl AiProvider for DeepSeekProvider {
    /// Provider identifier recorded in exchange logs.
    fn name(&self) -> &str {
        "deepseek"
    }

    /// Only the two first-party DeepSeek chat models are supported
    /// (case-insensitive exact match).
    fn supports_model(&self, model: &str) -> bool {
        model.eq_ignore_ascii_case("deepseek-chat")
            || model.eq_ignore_ascii_case("deepseek-reasoner")
    }

    /// Reads the API key from `DEEPSEEK_API_KEY`.
    ///
    /// # Errors
    /// Returns an error if the environment variable is unset or not unicode.
    fn get_api_key(&self) -> Result<String> {
        env::var(DEEPSEEK_API_KEY_ENV).map_err(|_| {
            anyhow::anyhow!(
                "DeepSeek API key not found in environment variable: {}",
                DEEPSEEK_API_KEY_ENV
            )
        })
    }

    /// DeepSeek reports prompt-cache hit/miss counts on every response.
    fn supports_caching(&self, _model: &str) -> bool {
        true
    }

    fn supports_vision(&self, _model: &str) -> bool {
        false
    }

    fn supports_structured_output(&self, _model: &str) -> bool {
        true
    }

    // NOTE(review): assumes a 64K-token input budget for all models — confirm
    // this still matches the published context windows per model.
    fn get_max_input_tokens(&self, _model: &str) -> usize {
        64_000
    }

    /// Sends a non-streaming chat completion to the DeepSeek API.
    ///
    /// Builds the OpenAI-compatible request, records the full exchange
    /// (request + response JSON, timestamp, usage), prices the token usage
    /// with cache-hit discounts, and opportunistically parses JSON-looking
    /// content into `structured_output`.
    ///
    /// # Errors
    /// Fails if the API key is missing, the HTTP call fails, the API returns
    /// a non-success status, the body cannot be deserialized, or the response
    /// contains no choices.
    async fn chat_completion(&self, params: ChatCompletionParams) -> Result<ProviderResponse> {
        let api_key = self.get_api_key()?;

        let messages: Vec<DeepSeekMessage> = params
            .messages
            .iter()
            .map(|msg| DeepSeekMessage {
                role: msg.role.clone(),
                content: msg.content.clone(),
            })
            .collect();

        // DeepSeek only supports the OpenAI-style `json_object` mode, so a
        // JSON-Schema request is downgraded to plain JSON mode (both arms
        // were previously duplicated — merged here).
        let response_format = params.response_format.as_ref().map(|rf| match &rf.format {
            crate::llm::types::OutputFormat::Json
            | crate::llm::types::OutputFormat::JsonSchema => serde_json::json!({
                "type": "json_object"
            }),
        });

        let request = DeepSeekRequest {
            model: params.model.clone(),
            messages,
            temperature: Some(params.temperature),
            max_tokens: Some(params.max_tokens),
            stream: Some(false),
            response_format,
        };

        let response = self
            .client
            .post("https://api.deepseek.com/chat/completions")
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            // Best-effort body capture: an unreadable body still reports the status.
            let error_text = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!(
                "DeepSeek API error {}: {}",
                status,
                error_text
            ));
        }

        let deepseek_response: DeepSeekResponse = response.json().await?;
        // Serialize the full response for the exchange log before it is consumed.
        let response_for_exchange = serde_json::to_value(&deepseek_response)?;

        let choice = deepseek_response
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("No choices in DeepSeek response"))?;

        let token_usage = deepseek_response.usage.map(|usage| {
            // Tokens served from the prompt cache are billed at the cheaper
            // cache-hit rate; the remainder at the miss rate. With zero cache
            // hits this is exactly the plain cost, so no separate branch is
            // needed (the old `calculate_cost` path was redundant).
            let cache_hit_tokens = usage.prompt_cache_hit_tokens;
            let regular_input_tokens = usage.prompt_tokens.saturating_sub(cache_hit_tokens);
            let cost = calculate_cost_with_cache(
                &params.model,
                regular_input_tokens,
                cache_hit_tokens,
                usage.completion_tokens,
            );
            let reasoning_tokens = usage
                .completion_tokens_details
                .as_ref()
                .map(|details| details.reasoning_tokens)
                .unwrap_or(0);
            TokenUsage {
                prompt_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                reasoning_tokens,
                total_tokens: usage.total_tokens,
                cached_tokens: cache_hit_tokens,
                cost,
                request_time_ms: None,
            }
        });

        // Usage is computed first, so the exchange is built once (no
        // mutate-after-construct rebinding).
        let exchange = ProviderExchange {
            request: serde_json::to_value(&request)?,
            response: response_for_exchange,
            timestamp: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            usage: token_usage,
            provider: self.name().to_string(),
            rate_limit_headers: None,
        };

        // Opportunistic structured output: only attempt a parse when the body
        // looks like a JSON object or array.
        let content = choice.message.content;
        let trimmed = content.trim();
        let structured_output = if trimmed.starts_with('{') || trimmed.starts_with('[') {
            serde_json::from_str(&content).ok()
        } else {
            None
        };

        Ok(ProviderResponse {
            content,
            thinking: None,
            exchange,
            tool_calls: None,
            finish_reason: choice.finish_reason,
            structured_output,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing computed USD costs.
    const COST_EPSILON: f64 = 0.01;

    #[test]
    fn test_supports_model() {
        let provider = DeepSeekProvider::new();
        for supported in ["deepseek-chat", "deepseek-reasoner"] {
            assert!(provider.supports_model(supported));
        }
        for unsupported in ["gpt-4", "deepseek-coder"] {
            assert!(!provider.supports_model(unsupported));
        }
    }

    #[test]
    fn test_supports_model_case_insensitive() {
        let provider = DeepSeekProvider::new();
        let variants = [
            "DEEPSEEK-CHAT",
            "DEEPSEEK-REASONER",
            "DeepSeek-Chat",
            "DEEPSEEK-reasoner",
        ];
        for model in variants {
            assert!(provider.supports_model(model));
        }
    }

    #[test]
    fn test_calculate_cost() {
        // 1M input tokens at the miss rate plus 0.5M output tokens.
        let chat_cost = calculate_cost("deepseek-chat", 1_000_000, 500_000)
            .expect("chat cost should be computable");
        assert!((chat_cost - (0.27 + 0.5 * 1.10)).abs() < COST_EPSILON);

        let reasoner_cost = calculate_cost("deepseek-reasoner", 1_000_000, 500_000)
            .expect("reasoner cost should be computable");
        assert!((reasoner_cost - (0.55 + 0.5 * 2.19)).abs() < COST_EPSILON);
    }

    #[test]
    fn test_calculate_cost_with_cache() {
        let cached = calculate_cost_with_cache("deepseek-chat", 500_000, 500_000, 250_000)
            .expect("cached cost should be computable");
        let expected = (0.5 * 0.27) + (0.5 * 0.07) + (0.25 * 1.10);
        assert!((cached - expected).abs() < COST_EPSILON);

        // Cache hits must make the same workload cheaper than the uncached price.
        let uncached = calculate_cost("deepseek-chat", 1_000_000, 250_000)
            .expect("uncached cost should be computable");
        assert!(cached < uncached);
    }
}