rig/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::deepseek;
6//!
7//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
8//!
9//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
10//! ```
11
12use crate::client::{CompletionClient, ProviderClient};
13use crate::json_utils::merge;
14use crate::message::Document;
15use crate::providers::openai;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19    OneOrMany,
20    completion::{self, CompletionError, CompletionModel, CompletionRequest},
21    impl_conversion_traits, json_utils, message,
22};
23use reqwest::Client as HttpClient;
24use serde::{Deserialize, Serialize};
25use serde_json::json;
26
27// ================================================================
28// Main DeepSeek Client
29// ================================================================
/// Base URL used by [`Client::new`] when the caller does not supply one.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
31
32#[derive(Clone)]
33pub struct Client {
34    pub base_url: String,
35    api_key: String,
36    http_client: HttpClient,
37}
38
39impl std::fmt::Debug for Client {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct("Client")
42            .field("base_url", &self.base_url)
43            .field("http_client", &self.http_client)
44            .field("api_key", &"<REDACTED>")
45            .finish()
46    }
47}
48
49impl Client {
50    // Create a new DeepSeek client from an API key.
51    pub fn new(api_key: &str) -> Self {
52        Self::from_url(api_key, DEEPSEEK_API_BASE_URL)
53    }
54
55    /// Set your own URL.
56    /// Useful if you need to access an alternative website that supports the DeepSeek API specification.
57    pub fn from_url(api_key: &str, base_url: &str) -> Self {
58        Self {
59            base_url: base_url.to_string(),
60            api_key: api_key.to_string(),
61            http_client: reqwest::Client::builder()
62                .build()
63                .expect("DeepSeek reqwest client should build"),
64        }
65    }
66
67    /// Use your own `reqwest::Client`.
68    /// The required headers will be automatically attached upon trying to make a request.
69    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
70        self.http_client = client;
71
72        self
73    }
74
75    fn post(&self, path: &str) -> reqwest::RequestBuilder {
76        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
77        self.http_client.post(url).bearer_auth(&self.api_key)
78    }
79}
80
81impl ProviderClient for Client {
82    // If you prefer the environment variable approach:
83    fn from_env() -> Self {
84        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
85        Self::new(&api_key)
86    }
87}
88
89impl CompletionClient for Client {
90    type CompletionModel = DeepSeekCompletionModel;
91
92    /// Creates a DeepSeek completion model with the given `model_name`.
93    fn completion_model(&self, model_name: &str) -> DeepSeekCompletionModel {
94        DeepSeekCompletionModel {
95            client: self.clone(),
96            model: model_name.to_string(),
97        }
98    }
99}
100
101impl_conversion_traits!(
102    AsEmbeddings,
103    AsTranscription,
104    AsImageGeneration,
105    AsAudioGeneration for Client
106);
107
108#[derive(Debug, Deserialize)]
109struct ApiErrorResponse {
110    message: String,
111}
112
113#[derive(Debug, Deserialize)]
114#[serde(untagged)]
115enum ApiResponse<T> {
116    Ok(T),
117    Err(ApiErrorResponse),
118}
119
120impl From<ApiErrorResponse> for CompletionError {
121    fn from(err: ApiErrorResponse) -> Self {
122        CompletionError::ProviderError(err.message)
123    }
124}
125
126/// The response shape from the DeepSeek API
127#[derive(Clone, Debug, Serialize, Deserialize)]
128pub struct CompletionResponse {
129    // We'll match the JSON:
130    pub choices: Vec<Choice>,
131    // you may want usage or other fields
132}
133
134#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
135pub struct Choice {
136    pub index: usize,
137    pub message: Message,
138    pub logprobs: Option<serde_json::Value>,
139    pub finish_reason: String,
140}
141
142#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
143#[serde(tag = "role", rename_all = "lowercase")]
144pub enum Message {
145    System {
146        content: String,
147        #[serde(skip_serializing_if = "Option::is_none")]
148        name: Option<String>,
149    },
150    User {
151        content: String,
152        #[serde(skip_serializing_if = "Option::is_none")]
153        name: Option<String>,
154    },
155    Assistant {
156        content: String,
157        #[serde(skip_serializing_if = "Option::is_none")]
158        name: Option<String>,
159        #[serde(
160            default,
161            deserialize_with = "json_utils::null_or_vec",
162            skip_serializing_if = "Vec::is_empty"
163        )]
164        tool_calls: Vec<ToolCall>,
165    },
166    #[serde(rename = "tool")]
167    ToolResult {
168        tool_call_id: String,
169        content: String,
170    },
171}
172
173impl Message {
174    pub fn system(content: &str) -> Self {
175        Message::System {
176            content: content.to_owned(),
177            name: None,
178        }
179    }
180}
181
182impl From<message::ToolResult> for Message {
183    fn from(tool_result: message::ToolResult) -> Self {
184        let content = match tool_result.content.first() {
185            message::ToolResultContent::Text(text) => text.text,
186            message::ToolResultContent::Image(_) => String::from("[Image]"),
187        };
188
189        Message::ToolResult {
190            tool_call_id: tool_result.id,
191            content,
192        }
193    }
194}
195
196impl From<message::ToolCall> for ToolCall {
197    fn from(tool_call: message::ToolCall) -> Self {
198        Self {
199            id: tool_call.id,
200            // TODO: update index when we have it
201            index: 0,
202            r#type: ToolType::Function,
203            function: Function {
204                name: tool_call.function.name,
205                arguments: tool_call.function.arguments,
206            },
207        }
208    }
209}
210
211impl TryFrom<message::Message> for Vec<Message> {
212    type Error = message::MessageError;
213
214    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
215        match message {
216            message::Message::User { content } => {
217                // extract tool results
218                let mut messages = vec![];
219
220                let tool_results = content
221                    .clone()
222                    .into_iter()
223                    .filter_map(|content| match content {
224                        message::UserContent::ToolResult(tool_result) => {
225                            Some(Message::from(tool_result))
226                        }
227                        _ => None,
228                    })
229                    .collect::<Vec<_>>();
230
231                messages.extend(tool_results);
232
233                // extract text results
234                let text_messages = content
235                    .into_iter()
236                    .filter_map(|content| match content {
237                        message::UserContent::Text(text) => Some(Message::User {
238                            content: text.text,
239                            name: None,
240                        }),
241                        message::UserContent::Document(Document { data, .. }) => {
242                            Some(Message::User {
243                                content: data,
244                                name: None,
245                            })
246                        }
247                        _ => None,
248                    })
249                    .collect::<Vec<_>>();
250                messages.extend(text_messages);
251
252                Ok(messages)
253            }
254            message::Message::Assistant { content, .. } => {
255                let mut messages: Vec<Message> = vec![];
256
257                // extract tool calls
258                let tool_calls = content
259                    .clone()
260                    .into_iter()
261                    .filter_map(|content| match content {
262                        message::AssistantContent::ToolCall(tool_call) => {
263                            Some(ToolCall::from(tool_call))
264                        }
265                        _ => None,
266                    })
267                    .collect::<Vec<_>>();
268
269                // if we have tool calls, we add a new Assistant message with them
270                if !tool_calls.is_empty() {
271                    messages.push(Message::Assistant {
272                        content: "".to_string(),
273                        name: None,
274                        tool_calls,
275                    });
276                }
277
278                // extract text
279                let text_content = content
280                    .into_iter()
281                    .filter_map(|content| match content {
282                        message::AssistantContent::Text(text) => Some(Message::Assistant {
283                            content: text.text,
284                            name: None,
285                            tool_calls: vec![],
286                        }),
287                        _ => None,
288                    })
289                    .collect::<Vec<_>>();
290
291                messages.extend(text_content);
292
293                Ok(messages)
294            }
295        }
296    }
297}
298
299#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
300pub struct ToolCall {
301    pub id: String,
302    pub index: usize,
303    #[serde(default)]
304    pub r#type: ToolType,
305    pub function: Function,
306}
307
308#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
309pub struct Function {
310    pub name: String,
311    #[serde(with = "json_utils::stringified_json")]
312    pub arguments: serde_json::Value,
313}
314
315#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
316#[serde(rename_all = "lowercase")]
317pub enum ToolType {
318    #[default]
319    Function,
320}
321
322#[derive(Clone, Debug, Deserialize, Serialize)]
323pub struct ToolDefinition {
324    pub r#type: String,
325    pub function: completion::ToolDefinition,
326}
327
328impl From<crate::completion::ToolDefinition> for ToolDefinition {
329    fn from(tool: crate::completion::ToolDefinition) -> Self {
330        Self {
331            r#type: "function".into(),
332            function: tool,
333        }
334    }
335}
336
337impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
338    type Error = CompletionError;
339
340    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
341        let choice = response.choices.first().ok_or_else(|| {
342            CompletionError::ResponseError("Response contained no choices".to_owned())
343        })?;
344        let content = match &choice.message {
345            Message::Assistant {
346                content,
347                tool_calls,
348                ..
349            } => {
350                let mut content = if content.trim().is_empty() {
351                    vec![]
352                } else {
353                    vec![completion::AssistantContent::text(content)]
354                };
355
356                content.extend(
357                    tool_calls
358                        .iter()
359                        .map(|call| {
360                            completion::AssistantContent::tool_call(
361                                &call.id,
362                                &call.function.name,
363                                call.function.arguments.clone(),
364                            )
365                        })
366                        .collect::<Vec<_>>(),
367                );
368                Ok(content)
369            }
370            _ => Err(CompletionError::ResponseError(
371                "Response did not contain a valid message or tool call".into(),
372            )),
373        }?;
374
375        let choice = OneOrMany::many(content).map_err(|_| {
376            CompletionError::ResponseError(
377                "Response contained no message or tool call (empty)".to_owned(),
378            )
379        })?;
380
381        Ok(completion::CompletionResponse {
382            choice,
383            raw_response: response,
384        })
385    }
386}
387
388/// The struct implementing the `CompletionModel` trait
389#[derive(Clone)]
390pub struct DeepSeekCompletionModel {
391    pub client: Client,
392    pub model: String,
393}
394
395impl DeepSeekCompletionModel {
396    fn create_completion_request(
397        &self,
398        completion_request: CompletionRequest,
399    ) -> Result<serde_json::Value, CompletionError> {
400        // Build up the order of messages (context, chat_history, prompt)
401        let mut partial_history = vec![];
402
403        if let Some(docs) = completion_request.normalized_documents() {
404            partial_history.push(docs);
405        }
406
407        partial_history.extend(completion_request.chat_history);
408
409        // Initialize full history with preamble (or empty if non-existent)
410        let mut full_history: Vec<Message> = completion_request
411            .preamble
412            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
413
414        // Convert and extend the rest of the history
415        full_history.extend(
416            partial_history
417                .into_iter()
418                .map(message::Message::try_into)
419                .collect::<Result<Vec<Vec<Message>>, _>>()?
420                .into_iter()
421                .flatten()
422                .collect::<Vec<_>>(),
423        );
424
425        let request = if completion_request.tools.is_empty() {
426            json!({
427                "model": self.model,
428                "messages": full_history,
429                "temperature": completion_request.temperature,
430            })
431        } else {
432            json!({
433                "model": self.model,
434                "messages": full_history,
435                "temperature": completion_request.temperature,
436                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
437                "tool_choice": "auto",
438            })
439        };
440
441        let request = if let Some(params) = completion_request.additional_params {
442            json_utils::merge(request, params)
443        } else {
444            request
445        };
446
447        Ok(request)
448    }
449}
450
451impl CompletionModel for DeepSeekCompletionModel {
452    type Response = CompletionResponse;
453    type StreamingResponse = openai::StreamingCompletionResponse;
454
455    #[cfg_attr(feature = "worker", worker::send)]
456    async fn completion(
457        &self,
458        completion_request: CompletionRequest,
459    ) -> Result<
460        completion::CompletionResponse<CompletionResponse>,
461        crate::completion::CompletionError,
462    > {
463        let request = self.create_completion_request(completion_request)?;
464
465        tracing::debug!("DeepSeek completion request: {request:?}");
466
467        let response = self
468            .client
469            .post("/chat/completions")
470            .json(&request)
471            .send()
472            .await?;
473
474        if response.status().is_success() {
475            let t = response.text().await?;
476            tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
477
478            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
479                ApiResponse::Ok(response) => response.try_into(),
480                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
481            }
482        } else {
483            Err(CompletionError::ProviderError(response.text().await?))
484        }
485    }
486
487    #[cfg_attr(feature = "worker", worker::send)]
488    async fn stream(
489        &self,
490        completion_request: CompletionRequest,
491    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
492        let mut request = self.create_completion_request(completion_request)?;
493
494        request = merge(
495            request,
496            json!({"stream": true, "stream_options": {"include_usage": true}}),
497        );
498
499        let builder = self.client.post("/v1/chat/completions").json(&request);
500        send_compatible_streaming_request(builder).await
501    }
502}
503
504// ================================================================
505// DeepSeek Completion API
506// ================================================================
507
/// Model identifier for the `deepseek-chat` completion model.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier for the `deepseek-reasoner` completion model.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
512
513// Tests
#[cfg(test)]
mod tests {

    use super::*;

    /// Asserts that `message` is an assistant message whose text equals `expected`.
    fn assert_assistant_content(message: &Message, expected: &str) {
        match message {
            Message::Assistant { content, .. } => assert_eq!(content, expected),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Deserializes a full response, panicking with the failing JSON path on error.
    fn parse_response(data: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(data);
        let parsed: Result<CompletionResponse, _> =
            serde_path_to_error::deserialize(&mut deserializer);
        match parsed {
            Ok(response) => response,
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    /// A bare array of choices deserializes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_assistant_content(&choices.first().unwrap().message, "Hello, world!");
    }

    /// A minimal response object containing only `choices` deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{"choices":[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]}"#;

        let response = parse_response(data);
        assert_assistant_content(&response.choices.first().unwrap().message, "Hello, world!");
    }

    /// A full example payload, including fields this module does not model, deserializes.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        assert_assistant_content(
            &response.choices.first().unwrap().message,
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄",
        );
    }

    /// A tool-call choice deserializes, with the stringified arguments parsed into JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}