rig/providers/
deepseek.rs

//! DeepSeek API client and Rig integration.
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::deepseek;
//!
//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```
11
12use crate::client::{CompletionClient, ProviderClient};
13use crate::json_utils::merge;
14use crate::providers::openai;
15use crate::providers::openai::send_compatible_streaming_request;
16use crate::streaming::StreamingCompletionResponse;
17use crate::{
18    completion::{self, CompletionError, CompletionModel, CompletionRequest},
19    impl_conversion_traits, json_utils, message, OneOrMany,
20};
21use reqwest::Client as HttpClient;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
// ================================================================
// Main DeepSeek Client
// ================================================================

/// Root URL of the public DeepSeek HTTP API; the default target of [`Client::new`].
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
29
30#[derive(Clone, Debug)]
31pub struct Client {
32    pub base_url: String,
33    http_client: HttpClient,
34}
35
36impl Client {
37    // Create a new DeepSeek client from an API key.
38    pub fn new(api_key: &str) -> Self {
39        Self::from_url(api_key, DEEPSEEK_API_BASE_URL)
40    }
41
42    // Handy for advanced usage, e.g. letting user override base_url or set timeouts:
43    pub fn from_url(api_key: &str, base_url: &str) -> Self {
44        // Possibly configure a custom HTTP client here if needed.
45        Self {
46            base_url: base_url.to_string(),
47            http_client: reqwest::Client::builder()
48                .default_headers({
49                    let mut headers = reqwest::header::HeaderMap::new();
50                    headers.insert(
51                        "Authorization",
52                        format!("Bearer {api_key}")
53                            .parse()
54                            .expect("Bearer token should parse"),
55                    );
56                    headers
57                })
58                .build()
59                .expect("DeepSeek reqwest client should build"),
60        }
61    }
62
63    fn post(&self, path: &str) -> reqwest::RequestBuilder {
64        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
65        self.http_client.post(url)
66    }
67}
68
69impl ProviderClient for Client {
70    // If you prefer the environment variable approach:
71    fn from_env() -> Self {
72        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
73        Self::new(&api_key)
74    }
75}
76
77impl CompletionClient for Client {
78    type CompletionModel = DeepSeekCompletionModel;
79
80    /// Creates a DeepSeek completion model with the given `model_name`.
81    fn completion_model(&self, model_name: &str) -> DeepSeekCompletionModel {
82        DeepSeekCompletionModel {
83            client: self.clone(),
84            model: model_name.to_string(),
85        }
86    }
87}
88
89impl_conversion_traits!(
90    AsEmbeddings,
91    AsTranscription,
92    AsImageGeneration,
93    AsAudioGeneration for Client
94);
95
96#[derive(Debug, Deserialize)]
97struct ApiErrorResponse {
98    message: String,
99}
100
101#[derive(Debug, Deserialize)]
102#[serde(untagged)]
103enum ApiResponse<T> {
104    Ok(T),
105    Err(ApiErrorResponse),
106}
107
108impl From<ApiErrorResponse> for CompletionError {
109    fn from(err: ApiErrorResponse) -> Self {
110        CompletionError::ProviderError(err.message)
111    }
112}
113
114/// The response shape from the DeepSeek API
115#[derive(Clone, Debug, Serialize, Deserialize)]
116pub struct CompletionResponse {
117    // We'll match the JSON:
118    pub choices: Vec<Choice>,
119    // you may want usage or other fields
120}
121
122#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
123pub struct Choice {
124    pub index: usize,
125    pub message: Message,
126    pub logprobs: Option<serde_json::Value>,
127    pub finish_reason: String,
128}
129
130#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
131#[serde(tag = "role", rename_all = "lowercase")]
132pub enum Message {
133    System {
134        content: String,
135        #[serde(skip_serializing_if = "Option::is_none")]
136        name: Option<String>,
137    },
138    User {
139        content: String,
140        #[serde(skip_serializing_if = "Option::is_none")]
141        name: Option<String>,
142    },
143    Assistant {
144        content: String,
145        #[serde(skip_serializing_if = "Option::is_none")]
146        name: Option<String>,
147        #[serde(
148            default,
149            deserialize_with = "json_utils::null_or_vec",
150            skip_serializing_if = "Vec::is_empty"
151        )]
152        tool_calls: Vec<ToolCall>,
153    },
154    #[serde(rename = "tool")]
155    ToolResult {
156        tool_call_id: String,
157        content: String,
158    },
159}
160
161impl Message {
162    pub fn system(content: &str) -> Self {
163        Message::System {
164            content: content.to_owned(),
165            name: None,
166        }
167    }
168}
169
170impl From<message::ToolResult> for Message {
171    fn from(tool_result: message::ToolResult) -> Self {
172        let content = match tool_result.content.first() {
173            message::ToolResultContent::Text(text) => text.text,
174            message::ToolResultContent::Image(_) => String::from("[Image]"),
175        };
176
177        Message::ToolResult {
178            tool_call_id: tool_result.id,
179            content,
180        }
181    }
182}
183
184impl From<message::ToolCall> for ToolCall {
185    fn from(tool_call: message::ToolCall) -> Self {
186        Self {
187            id: tool_call.id,
188            // TODO: update index when we have it
189            index: 0,
190            r#type: ToolType::Function,
191            function: Function {
192                name: tool_call.function.name,
193                arguments: tool_call.function.arguments,
194            },
195        }
196    }
197}
198
199impl TryFrom<message::Message> for Vec<Message> {
200    type Error = message::MessageError;
201
202    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
203        match message {
204            message::Message::User { content } => {
205                // extract tool results
206                let mut messages = vec![];
207
208                let tool_results = content
209                    .clone()
210                    .into_iter()
211                    .filter_map(|content| match content {
212                        message::UserContent::ToolResult(tool_result) => {
213                            Some(Message::from(tool_result))
214                        }
215                        _ => None,
216                    })
217                    .collect::<Vec<_>>();
218
219                messages.extend(tool_results);
220
221                // extract text results
222                let text_messages = content
223                    .into_iter()
224                    .filter_map(|content| match content {
225                        message::UserContent::Text(text) => Some(Message::User {
226                            content: text.text,
227                            name: None,
228                        }),
229                        _ => None,
230                    })
231                    .collect::<Vec<_>>();
232                messages.extend(text_messages);
233
234                Ok(messages)
235            }
236            message::Message::Assistant { content } => {
237                let mut messages: Vec<Message> = vec![];
238
239                // extract tool calls
240                let tool_calls = content
241                    .clone()
242                    .into_iter()
243                    .filter_map(|content| match content {
244                        message::AssistantContent::ToolCall(tool_call) => {
245                            Some(ToolCall::from(tool_call))
246                        }
247                        _ => None,
248                    })
249                    .collect::<Vec<_>>();
250
251                // if we have tool calls, we add a new Assistant message with them
252                if !tool_calls.is_empty() {
253                    messages.push(Message::Assistant {
254                        content: "".to_string(),
255                        name: None,
256                        tool_calls,
257                    });
258                }
259
260                // extract text
261                let text_content = content
262                    .into_iter()
263                    .filter_map(|content| match content {
264                        message::AssistantContent::Text(text) => Some(Message::Assistant {
265                            content: text.text,
266                            name: None,
267                            tool_calls: vec![],
268                        }),
269                        _ => None,
270                    })
271                    .collect::<Vec<_>>();
272
273                messages.extend(text_content);
274
275                Ok(messages)
276            }
277        }
278    }
279}
280
281#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
282pub struct ToolCall {
283    pub id: String,
284    pub index: usize,
285    #[serde(default)]
286    pub r#type: ToolType,
287    pub function: Function,
288}
289
290#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
291pub struct Function {
292    pub name: String,
293    #[serde(with = "json_utils::stringified_json")]
294    pub arguments: serde_json::Value,
295}
296
297#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
298#[serde(rename_all = "lowercase")]
299pub enum ToolType {
300    #[default]
301    Function,
302}
303
304#[derive(Clone, Debug, Deserialize, Serialize)]
305pub struct ToolDefinition {
306    pub r#type: String,
307    pub function: completion::ToolDefinition,
308}
309
310impl From<crate::completion::ToolDefinition> for ToolDefinition {
311    fn from(tool: crate::completion::ToolDefinition) -> Self {
312        Self {
313            r#type: "function".into(),
314            function: tool,
315        }
316    }
317}
318
319impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
320    type Error = CompletionError;
321
322    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
323        let choice = response.choices.first().ok_or_else(|| {
324            CompletionError::ResponseError("Response contained no choices".to_owned())
325        })?;
326        let content = match &choice.message {
327            Message::Assistant {
328                content,
329                tool_calls,
330                ..
331            } => {
332                let mut content = if content.trim().is_empty() {
333                    vec![]
334                } else {
335                    vec![completion::AssistantContent::text(content)]
336                };
337
338                content.extend(
339                    tool_calls
340                        .iter()
341                        .map(|call| {
342                            completion::AssistantContent::tool_call(
343                                &call.id,
344                                &call.function.name,
345                                call.function.arguments.clone(),
346                            )
347                        })
348                        .collect::<Vec<_>>(),
349                );
350                Ok(content)
351            }
352            _ => Err(CompletionError::ResponseError(
353                "Response did not contain a valid message or tool call".into(),
354            )),
355        }?;
356
357        let choice = OneOrMany::many(content).map_err(|_| {
358            CompletionError::ResponseError(
359                "Response contained no message or tool call (empty)".to_owned(),
360            )
361        })?;
362
363        Ok(completion::CompletionResponse {
364            choice,
365            raw_response: response,
366        })
367    }
368}
369
370/// The struct implementing the `CompletionModel` trait
371#[derive(Clone)]
372pub struct DeepSeekCompletionModel {
373    pub client: Client,
374    pub model: String,
375}
376
377impl DeepSeekCompletionModel {
378    fn create_completion_request(
379        &self,
380        completion_request: CompletionRequest,
381    ) -> Result<serde_json::Value, CompletionError> {
382        // Build up the order of messages (context, chat_history, prompt)
383        let mut partial_history = vec![];
384        if let Some(docs) = completion_request.normalized_documents() {
385            partial_history.push(docs);
386        }
387        partial_history.extend(completion_request.chat_history);
388
389        // Initialize full history with preamble (or empty if non-existent)
390        let mut full_history: Vec<Message> = completion_request
391            .preamble
392            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
393
394        // Convert and extend the rest of the history
395        full_history.extend(
396            partial_history
397                .into_iter()
398                .map(message::Message::try_into)
399                .collect::<Result<Vec<Vec<Message>>, _>>()?
400                .into_iter()
401                .flatten()
402                .collect::<Vec<_>>(),
403        );
404
405        let request = if completion_request.tools.is_empty() {
406            json!({
407                "model": self.model,
408                "messages": full_history,
409                "temperature": completion_request.temperature,
410            })
411        } else {
412            json!({
413                "model": self.model,
414                "messages": full_history,
415                "temperature": completion_request.temperature,
416                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
417                "tool_choice": "auto",
418            })
419        };
420
421        let request = if let Some(params) = completion_request.additional_params {
422            json_utils::merge(request, params)
423        } else {
424            request
425        };
426
427        Ok(request)
428    }
429}
430
431impl CompletionModel for DeepSeekCompletionModel {
432    type Response = CompletionResponse;
433    type StreamingResponse = openai::StreamingCompletionResponse;
434
435    #[cfg_attr(feature = "worker", worker::send)]
436    async fn completion(
437        &self,
438        completion_request: CompletionRequest,
439    ) -> Result<
440        completion::CompletionResponse<CompletionResponse>,
441        crate::completion::CompletionError,
442    > {
443        let request = self.create_completion_request(completion_request)?;
444
445        let response = self
446            .client
447            .post("/chat/completions")
448            .json(&request)
449            .send()
450            .await?;
451
452        if response.status().is_success() {
453            let t = response.text().await?;
454            tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
455
456            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
457                ApiResponse::Ok(response) => response.try_into(),
458                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
459            }
460        } else {
461            Err(CompletionError::ProviderError(response.text().await?))
462        }
463    }
464
465    #[cfg_attr(feature = "worker", worker::send)]
466    async fn stream(
467        &self,
468        completion_request: CompletionRequest,
469    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
470        let mut request = self.create_completion_request(completion_request)?;
471
472        request = merge(
473            request,
474            json!({"stream": true, "stream_options": {"include_usage": true}}),
475        );
476
477        let builder = self.client.post("/v1/chat/completions").json(&request);
478        send_compatible_streaming_request(builder).await
479    }
480}
481
// ================================================================
// DeepSeek Completion API
// ================================================================

/// Model name for DeepSeek's `deepseek-chat` completion model.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name for DeepSeek's `deepseek-reasoner` completion model.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
490
// Tests
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserialize `data`, panicking with the JSON path of the failure if any.
    fn deserialize_or_panic<T: serde::de::DeserializeOwned>(data: &str) -> T {
        let jd = &mut serde_json::Deserializer::from_str(data);
        serde_path_to_error::deserialize(jd)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    /// Text of an assistant message; any other role fails the test.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content,
            _ => panic!("Expected assistant message"),
        }
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_eq!(assistant_text(&choices[0].message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{"choices":[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]}"#;

        let response: CompletionResponse = deserialize_or_panic(data);
        assert_eq!(assistant_text(&response.choices[0].message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response: CompletionResponse = deserialize_or_panic(data);
        assert_eq!(
            assistant_text(&response.choices[0].message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
        };
        let expected_choice = Choice {
            index: 0,
            finish_reason: "tool_calls".to_string(),
            logprobs: None,
            message: Message::Assistant {
                content: String::new(),
                name: None,
                tool_calls: vec![expected_call],
            },
        };

        let parsed: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
        assert_eq!(parsed, expected_choice);
    }
}