Skip to main content

langhub_core/llms/
deepseek.rs

//! DeepSeek LLM provider: DeepSeek-V3 (`deepseek-chat`), DeepSeek-Coder (`deepseek-coder`) and DeepSeek-R1 (`deepseek-reasoner`).
2use crate::types::*;
3
4use super::{ChatMessage, LLM, LLMOptions, LLMResult};
5use serde_json::json;
6use std::future::Future;
7use std::pin::Pin;
8
/// Selects which DeepSeek model a [`DeepSeek`] client sends requests to.
///
/// The enum is fieldless, so it is cheap to copy and can be compared or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeepSeekModel {
    /// `deepseek-chat` — DeepSeek-V3.
    Chat,
    /// `deepseek-coder` — DeepSeek-Coder.
    Coder,
    /// `deepseek-reasoner` — DeepSeek-R1.
    Reasoner,
}
15
16impl DeepSeekModel {
17    fn as_str(&self) -> &'static str {
18        match self {
19            DeepSeekModel::Chat => "deepseek-chat",
20            DeepSeekModel::Coder => "deepseek-coder",
21            DeepSeekModel::Reasoner => "deepseek-reasoner",
22        }
23    }
24}
25
26impl From<DeepSeekModel> for String {
27    fn from(model: DeepSeekModel) -> Self {
28        model.as_str().to_string()
29    }
30}
31
32pub struct DeepSeek {
33    api_key: String,
34    model: DeepSeekModel,
35    base_url: String,
36    client: reqwest::Client,
37    default_options: LLMOptions,
38}
39
40impl DeepSeek {
41    pub fn new(api_key: String) -> Self {
42        Self {
43            api_key,
44            model: DeepSeekModel::Chat,
45            base_url: "https://api.deepseek.com/v1".to_string(),
46            client: reqwest::Client::new(),
47            default_options: LLMOptions::default(),
48        }
49    }
50
51    pub fn with_model(mut self, model: DeepSeekModel) -> Self {
52        self.model = model;
53        self
54    }
55
56    pub fn chat_model(self) -> Self {
57        self.with_model(DeepSeekModel::Chat)
58    }
59
60    pub fn coder_model(self) -> Self {
61        self.with_model(DeepSeekModel::Coder)
62    }
63
64    pub fn reasoner_model(self) -> Self {
65        self.with_model(DeepSeekModel::Reasoner)
66    }
67
68    pub fn with_temperature(mut self, temperature: f32) -> Self {
69        self.default_options.temperature = Some(temperature);
70        self
71    }
72
73    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
74        self.default_options.max_tokens = Some(max_tokens);
75        self
76    }
77
78    pub fn with_top_p(mut self, top_p: f32) -> Self {
79        self.default_options.top_p = Some(top_p);
80        self
81    }
82
83    pub fn with_base_url(mut self, base_url: &str) -> Self {
84        self.base_url = base_url.to_string();
85        self
86    }
87
88    async fn chat_completion(
89        &self,
90        messages: &[ChatMessage],
91        options: &LLMOptions,
92    ) -> Result<String> {
93        let model_name: String = self.model.clone().into();
94
95        let mut request_body = json!({
96            "model": model_name,
97            "messages": messages.iter().map(|m| json!({
98                "role": m.role,
99                "content": m.content,
100            })).collect::<Vec<_>>(),
101        });
102
103        if let Some(temp) = options.temperature.or(self.default_options.temperature) {
104            request_body["temperature"] = json!(temp);
105        }
106        if let Some(max_tokens) = options.max_tokens.or(self.default_options.max_tokens) {
107            request_body["max_tokens"] = json!(max_tokens);
108        }
109        if let Some(top_p) = options.top_p.or(self.default_options.top_p) {
110            request_body["top_p"] = json!(top_p);
111        }
112        if let Some(freq_penalty) = options
113            .frequency_penalty
114            .or(self.default_options.frequency_penalty)
115        {
116            request_body["frequency_penalty"] = json!(freq_penalty);
117        }
118        if let Some(pres_penalty) = options
119            .presence_penalty
120            .or(self.default_options.presence_penalty)
121        {
122            request_body["presence_penalty"] = json!(pres_penalty);
123        }
124        if let Some(stop) = options
125            .stop_sequences
126            .as_ref()
127            .or(self.default_options.stop_sequences.as_ref())
128        {
129            request_body["stop"] = json!(stop);
130        }
131
132        let response = self
133            .client
134            .post(format!("{}/chat/completions", self.base_url))
135            .header("Authorization", format!("Bearer {}", self.api_key))
136            .header("Content-Type", "application/json")
137            .json(&request_body)
138            .send()
139            .await
140            .map_err(|e| LangHubError::LLMError(format!("DeepSeek request error: {}", e)))?;
141
142        if !response.status().is_success() {
143            let status = response.status();
144            let error_text = response.text().await.unwrap_or_default();
145            return Err(LangHubError::LLMError(format!(
146                "DeepSeek API error ({}): {}",
147                status, error_text
148            )));
149        }
150
151        let json: serde_json::Value = response
152            .json()
153            .await
154            .map_err(|e| LangHubError::LLMError(format!("JSON parse error: {}", e)))?;
155
156        let text = json["choices"][0]["message"]["content"]
157            .as_str()
158            .ok_or_else(|| {
159                LangHubError::ParseError("Missing 'content' field in response".to_string())
160            })?
161            .to_string();
162
163        Ok(text)
164    }
165}
166
167impl LLM for DeepSeek {
168    fn generate(
169        &self,
170        prompt: &str,
171    ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
172        let prompt = prompt.to_string();
173        let options = self.default_options.clone();
174        Box::pin(async move {
175            let messages = vec![ChatMessage::user(&prompt)];
176            let text = self.chat_completion(&messages, &options).await?;
177            Ok(LLMResult {
178                text,
179                metadata: None,
180            })
181        })
182    }
183
184    fn generate_with_options(
185        &self,
186        prompt: &str,
187        options: LLMOptions,
188    ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
189        let prompt = prompt.to_string();
190        Box::pin(async move {
191            let messages = vec![ChatMessage::user(&prompt)];
192            let text = self.chat_completion(&messages, &options).await?;
193            Ok(LLMResult {
194                text,
195                metadata: None,
196            })
197        })
198    }
199
200    fn chat(
201        &self,
202        messages: Vec<ChatMessage>,
203    ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
204        Box::pin(async move {
205            let text = self
206                .chat_completion(&messages, &LLMOptions::default())
207                .await?;
208            Ok(LLMResult {
209                text,
210                metadata: None,
211            })
212        })
213    }
214
215    fn get_model_name(&self) -> &str {
216        match self.model {
217            DeepSeekModel::Chat => "deepseek-chat",
218            DeepSeekModel::Coder => "deepseek-coder",
219            DeepSeekModel::Reasoner => "deepseek-reasoner",
220        }
221    }
222
223    fn get_provider_name(&self) -> &str {
224        match self.model {
225            DeepSeekModel::Chat => "DeepSeek-V3",
226            DeepSeekModel::Coder => "DeepSeek-Coder",
227            DeepSeekModel::Reasoner => "DeepSeek-R1",
228        }
229    }
230
231    fn get_provider_enum(&self) -> ModelProvider {
232        ModelProvider::DeepSeek
233    }
234
235    fn supports_function_calling(&self) -> bool {
236        true
237    }
238
239    fn supports_json_mode(&self) -> bool {
240        true
241    }
242
243    fn max_context_length(&self) -> Option<usize> {
244        Some(0)
245    }
246}