rig/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::deepseek;
6//!
7//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
8//!
9//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
10//! ```
11
12use crate::client::{CompletionClient, ProviderClient};
13use crate::json_utils::merge;
14use crate::message::Document;
15use crate::providers::openai;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19    OneOrMany,
20    completion::{self, CompletionError, CompletionModel, CompletionRequest},
21    impl_conversion_traits, json_utils, message,
22};
23use reqwest::Client as HttpClient;
24use serde::{Deserialize, Serialize};
25use serde_json::json;
26
27// ================================================================
28// Main DeepSeek Client
29// ================================================================
/// Default base URL of the DeepSeek API; this is what [`Client::new`] passes to
/// [`Client::from_url`].
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
31
/// Client for the DeepSeek HTTP API.
///
/// Bundles the base URL, the API key and the `reqwest` client used to issue
/// requests.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to.
    pub base_url: String,
    // API key; attached as a bearer token by `post` and redacted in `Debug` output.
    api_key: String,
    // Underlying HTTP client; replaceable through `with_custom_client`.
    http_client: HttpClient,
}
38
39impl std::fmt::Debug for Client {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct("Client")
42            .field("base_url", &self.base_url)
43            .field("http_client", &self.http_client)
44            .field("api_key", &"<REDACTED>")
45            .finish()
46    }
47}
48
49impl Client {
50    // Create a new DeepSeek client from an API key.
51    pub fn new(api_key: &str) -> Self {
52        Self::from_url(api_key, DEEPSEEK_API_BASE_URL)
53    }
54
55    /// Set your own URL.
56    /// Useful if you need to access an alternative website that supports the DeepSeek API specification.
57    pub fn from_url(api_key: &str, base_url: &str) -> Self {
58        Self {
59            base_url: base_url.to_string(),
60            api_key: api_key.to_string(),
61            http_client: reqwest::Client::builder()
62                .build()
63                .expect("DeepSeek reqwest client should build"),
64        }
65    }
66
67    /// Use your own `reqwest::Client`.
68    /// The required headers will be automatically attached upon trying to make a request.
69    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
70        self.http_client = client;
71
72        self
73    }
74
75    fn post(&self, path: &str) -> reqwest::RequestBuilder {
76        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
77        self.http_client.post(url).bearer_auth(&self.api_key)
78    }
79}
80
81impl ProviderClient for Client {
82    // If you prefer the environment variable approach:
83    fn from_env() -> Self {
84        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
85        Self::new(&api_key)
86    }
87
88    fn from_val(input: crate::client::ProviderValue) -> Self {
89        let crate::client::ProviderValue::Simple(api_key) = input else {
90            panic!("Incorrect provider value type")
91        };
92        Self::new(&api_key)
93    }
94}
95
96impl CompletionClient for Client {
97    type CompletionModel = DeepSeekCompletionModel;
98
99    /// Creates a DeepSeek completion model with the given `model_name`.
100    fn completion_model(&self, model_name: &str) -> DeepSeekCompletionModel {
101        DeepSeekCompletionModel {
102            client: self.clone(),
103            model: model_name.to_string(),
104        }
105    }
106}
107
// Generates the listed `As*` conversion trait impls for `Client`.
// NOTE(review): presumably these mark embeddings, transcription, image and
// audio generation as unsupported by this provider — confirm against the
// macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
114
/// Error payload deserialized from a response body; only the `message` field
/// is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Error text; forwarded as `CompletionError::ProviderError`.
    message: String,
}
119
/// Body of a response that is either the expected payload `T` or an error
/// object.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
126
127impl From<ApiErrorResponse> for CompletionError {
128    fn from(err: ApiErrorResponse) -> Self {
129        CompletionError::ProviderError(err.message)
130    }
131}
132
/// The response shape from the DeepSeek API.
///
/// Only `choices` and `usage` are modelled; both must be present for
/// deserialization to succeed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Generated choices; the conversion into the generic completion response
    /// reads only the first one.
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
    // Other top-level response fields are not modelled here.
}
141
/// Token usage section of a DeepSeek completion response.
///
/// The five counter fields are required when deserializing; the two
/// `*_details` fields are optional and are omitted from serialized output
/// when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    pub completion_tokens: u32,
    pub prompt_tokens: u32,
    // NOTE(review): presumably the prompt tokens served from / missed by the
    // provider's context cache — confirm against the DeepSeek API reference.
    pub prompt_cache_hit_tokens: u32,
    pub prompt_cache_miss_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
154
/// Optional breakdown attached to [`Usage::completion_tokens_details`].
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
160
/// Optional breakdown attached to [`Usage::prompt_tokens_details`].
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
166
/// One entry of the `choices` array in a completion response.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    pub index: usize,
    /// The message carried by this choice.
    pub message: Message,
    /// Kept as raw JSON; not interpreted by this module.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
174
/// A chat message in the DeepSeek wire format.
///
/// The enum is internally tagged by the `role` field, using the lowercase
/// variant name (`system`, `user`, `assistant`); `ToolResult` is renamed to
/// `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        content: String,
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: String,
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        content: String,
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Defaults to an empty list when the field is absent and is omitted
        // from serialized output when empty.
        // NOTE(review): `json_utils::null_or_vec` presumably maps an explicit
        // JSON `null` to an empty list as well — confirm in `json_utils`.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    // Result of a tool invocation, sent back with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
205
206impl Message {
207    pub fn system(content: &str) -> Self {
208        Message::System {
209            content: content.to_owned(),
210            name: None,
211        }
212    }
213}
214
215impl From<message::ToolResult> for Message {
216    fn from(tool_result: message::ToolResult) -> Self {
217        let content = match tool_result.content.first() {
218            message::ToolResultContent::Text(text) => text.text,
219            message::ToolResultContent::Image(_) => String::from("[Image]"),
220        };
221
222        Message::ToolResult {
223            tool_call_id: tool_result.id,
224            content,
225        }
226    }
227}
228
229impl From<message::ToolCall> for ToolCall {
230    fn from(tool_call: message::ToolCall) -> Self {
231        Self {
232            id: tool_call.id,
233            // TODO: update index when we have it
234            index: 0,
235            r#type: ToolType::Function,
236            function: Function {
237                name: tool_call.function.name,
238                arguments: tool_call.function.arguments,
239            },
240        }
241    }
242}
243
244impl TryFrom<message::Message> for Vec<Message> {
245    type Error = message::MessageError;
246
247    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
248        match message {
249            message::Message::User { content } => {
250                // extract tool results
251                let mut messages = vec![];
252
253                let tool_results = content
254                    .clone()
255                    .into_iter()
256                    .filter_map(|content| match content {
257                        message::UserContent::ToolResult(tool_result) => {
258                            Some(Message::from(tool_result))
259                        }
260                        _ => None,
261                    })
262                    .collect::<Vec<_>>();
263
264                messages.extend(tool_results);
265
266                // extract text results
267                let text_messages = content
268                    .into_iter()
269                    .filter_map(|content| match content {
270                        message::UserContent::Text(text) => Some(Message::User {
271                            content: text.text,
272                            name: None,
273                        }),
274                        message::UserContent::Document(Document { data, .. }) => {
275                            Some(Message::User {
276                                content: data,
277                                name: None,
278                            })
279                        }
280                        _ => None,
281                    })
282                    .collect::<Vec<_>>();
283                messages.extend(text_messages);
284
285                Ok(messages)
286            }
287            message::Message::Assistant { content, .. } => {
288                let mut messages: Vec<Message> = vec![];
289
290                // extract tool calls
291                let tool_calls = content
292                    .clone()
293                    .into_iter()
294                    .filter_map(|content| match content {
295                        message::AssistantContent::ToolCall(tool_call) => {
296                            Some(ToolCall::from(tool_call))
297                        }
298                        _ => None,
299                    })
300                    .collect::<Vec<_>>();
301
302                // if we have tool calls, we add a new Assistant message with them
303                if !tool_calls.is_empty() {
304                    messages.push(Message::Assistant {
305                        content: "".to_string(),
306                        name: None,
307                        tool_calls,
308                    });
309                }
310
311                // extract text
312                let text_content = content
313                    .into_iter()
314                    .filter_map(|content| match content {
315                        message::AssistantContent::Text(text) => Some(Message::Assistant {
316                            content: text.text,
317                            name: None,
318                            tool_calls: vec![],
319                        }),
320                        _ => None,
321                    })
322                    .collect::<Vec<_>>();
323
324                messages.extend(text_content);
325
326                Ok(messages)
327            }
328        }
329    }
330}
331
/// A tool call as it appears in an assistant message's `tool_calls` array.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    pub index: usize,
    /// Falls back to [`ToolType::Function`] when the field is absent.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
340
/// Name and arguments of the function requested by a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// Held as a JSON value; (de)serialized through
    /// `json_utils::stringified_json`. The module's tests show the wire form
    /// is a string containing JSON.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
347
/// Kind of a tool call; serialized in lowercase (`function`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// The only kind modelled here, and the default.
    #[default]
    Function,
}
354
/// Tool definition as sent in the request's `tools` array: the generic
/// definition wrapped together with a `type` field.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Set to `"function"` by the `From<completion::ToolDefinition>` conversion.
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
360
361impl From<crate::completion::ToolDefinition> for ToolDefinition {
362    fn from(tool: crate::completion::ToolDefinition) -> Self {
363        Self {
364            r#type: "function".into(),
365            function: tool,
366        }
367    }
368}
369
370impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
371    type Error = CompletionError;
372
373    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
374        let choice = response.choices.first().ok_or_else(|| {
375            CompletionError::ResponseError("Response contained no choices".to_owned())
376        })?;
377        let content = match &choice.message {
378            Message::Assistant {
379                content,
380                tool_calls,
381                ..
382            } => {
383                let mut content = if content.trim().is_empty() {
384                    vec![]
385                } else {
386                    vec![completion::AssistantContent::text(content)]
387                };
388
389                content.extend(
390                    tool_calls
391                        .iter()
392                        .map(|call| {
393                            completion::AssistantContent::tool_call(
394                                &call.id,
395                                &call.function.name,
396                                call.function.arguments.clone(),
397                            )
398                        })
399                        .collect::<Vec<_>>(),
400                );
401                Ok(content)
402            }
403            _ => Err(CompletionError::ResponseError(
404                "Response did not contain a valid message or tool call".into(),
405            )),
406        }?;
407
408        let choice = OneOrMany::many(content).map_err(|_| {
409            CompletionError::ResponseError(
410                "Response contained no message or tool call (empty)".to_owned(),
411            )
412        })?;
413
414        let usage = completion::Usage {
415            input_tokens: response.usage.prompt_tokens as u64,
416            output_tokens: response.usage.completion_tokens as u64,
417            total_tokens: response.usage.total_tokens as u64,
418        };
419
420        Ok(completion::CompletionResponse {
421            choice,
422            usage,
423            raw_response: response,
424        })
425    }
426}
427
/// The struct implementing the `CompletionModel` trait for DeepSeek.
///
/// Created by `Client::completion_model`.
#[derive(Clone)]
pub struct DeepSeekCompletionModel {
    /// Client used to send the requests.
    pub client: Client,
    /// Model name sent in the `model` field of each request.
    pub model: String,
}
434
435impl DeepSeekCompletionModel {
436    fn create_completion_request(
437        &self,
438        completion_request: CompletionRequest,
439    ) -> Result<serde_json::Value, CompletionError> {
440        // Build up the order of messages (context, chat_history, prompt)
441        let mut partial_history = vec![];
442
443        if let Some(docs) = completion_request.normalized_documents() {
444            partial_history.push(docs);
445        }
446
447        partial_history.extend(completion_request.chat_history);
448
449        // Initialize full history with preamble (or empty if non-existent)
450        let mut full_history: Vec<Message> = completion_request
451            .preamble
452            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
453
454        // Convert and extend the rest of the history
455        full_history.extend(
456            partial_history
457                .into_iter()
458                .map(message::Message::try_into)
459                .collect::<Result<Vec<Vec<Message>>, _>>()?
460                .into_iter()
461                .flatten()
462                .collect::<Vec<_>>(),
463        );
464
465        let request = if completion_request.tools.is_empty() {
466            json!({
467                "model": self.model,
468                "messages": full_history,
469                "temperature": completion_request.temperature,
470            })
471        } else {
472            json!({
473                "model": self.model,
474                "messages": full_history,
475                "temperature": completion_request.temperature,
476                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
477                "tool_choice": "auto",
478            })
479        };
480
481        let request = if let Some(params) = completion_request.additional_params {
482            json_utils::merge(request, params)
483        } else {
484            request
485        };
486
487        Ok(request)
488    }
489}
490
491impl CompletionModel for DeepSeekCompletionModel {
492    type Response = CompletionResponse;
493    type StreamingResponse = openai::StreamingCompletionResponse;
494
495    #[cfg_attr(feature = "worker", worker::send)]
496    async fn completion(
497        &self,
498        completion_request: CompletionRequest,
499    ) -> Result<
500        completion::CompletionResponse<CompletionResponse>,
501        crate::completion::CompletionError,
502    > {
503        let request = self.create_completion_request(completion_request)?;
504
505        tracing::debug!("DeepSeek completion request: {request:?}");
506
507        let response = self
508            .client
509            .post("/chat/completions")
510            .json(&request)
511            .send()
512            .await?;
513
514        if response.status().is_success() {
515            let t = response.text().await?;
516            tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
517
518            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
519                ApiResponse::Ok(response) => response.try_into(),
520                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
521            }
522        } else {
523            Err(CompletionError::ProviderError(response.text().await?))
524        }
525    }
526
527    #[cfg_attr(feature = "worker", worker::send)]
528    async fn stream(
529        &self,
530        completion_request: CompletionRequest,
531    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
532        let mut request = self.create_completion_request(completion_request)?;
533
534        request = merge(
535            request,
536            json!({"stream": true, "stream_options": {"include_usage": true}}),
537        );
538
539        let builder = self.client.post("/v1/chat/completions").json(&request);
540        send_compatible_streaming_request(builder).await
541    }
542}
543
544// ================================================================
545// DeepSeek Completion API
546// ================================================================
547
/// `deepseek-chat` completion model name, suitable for `Client::completion_model`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// `deepseek-reasoner` completion model name, suitable for `Client::completion_model`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
552
553// Tests
#[cfg(test)]
mod tests {

    use super::*;

    /// Returns the text of an assistant message; panics for any other role.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content.as_str(),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Deserializes a full completion response, reporting the failing JSON
    /// path on error.
    fn parse_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        let parsed: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        parsed.unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    /// A bare array of choices deserializes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);

        let first = choices.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    /// A minimal response with only `choices` and `usage` deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    /// A complete example response, including unmodelled top-level fields,
    /// deserializes.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(
            assistant_text(&first.message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    /// A choice carrying a tool call deserializes, with its stringified
    /// arguments parsed into a JSON value.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };
        let expected_choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
        };

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        assert_eq!(choice, expected_choice);
    }
}