// rig/providers/deepseek.rs

//! DeepSeek API client and Rig integration.
//!
//! # Example
//! ```
//! use rig::providers::deepseek;
//!
//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```

12use crate::json_utils::merge;
13use crate::providers::openai::send_compatible_streaming_request;
14use crate::streaming::{StreamingCompletionModel, StreamingResult};
15use crate::{
16    completion::{self, CompletionError, CompletionModel, CompletionRequest},
17    extractor::ExtractorBuilder,
18    json_utils, message, OneOrMany,
19};
20use reqwest::Client as HttpClient;
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
// ================================================================
// Main DeepSeek Client
// ================================================================
/// Default base URL of the DeepSeek API, used by [`Client::new`].
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";

/// HTTP client for the DeepSeek API.
///
/// Holds the base URL that request paths are joined onto and a `reqwest`
/// client preconfigured with an `Authorization: Bearer <api key>` default
/// header (see [`Client::from_url`]).
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to.
    pub base_url: String,
    http_client: HttpClient,
}

36impl Client {
37    // Create a new DeepSeek client from an API key.
38    pub fn new(api_key: &str) -> Self {
39        Self::from_url(api_key, DEEPSEEK_API_BASE_URL)
40    }
41
42    // If you prefer the environment variable approach:
43    pub fn from_env() -> Self {
44        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
45        Self::new(&api_key)
46    }
47
48    // Handy for advanced usage, e.g. letting user override base_url or set timeouts:
49    pub fn from_url(api_key: &str, base_url: &str) -> Self {
50        // Possibly configure a custom HTTP client here if needed.
51        Self {
52            base_url: base_url.to_string(),
53            http_client: reqwest::Client::builder()
54                .default_headers({
55                    let mut headers = reqwest::header::HeaderMap::new();
56                    headers.insert(
57                        "Authorization",
58                        format!("Bearer {}", api_key)
59                            .parse()
60                            .expect("Bearer token should parse"),
61                    );
62                    headers
63                })
64                .build()
65                .expect("DeepSeek reqwest client should build"),
66        }
67    }
68
69    fn post(&self, path: &str) -> reqwest::RequestBuilder {
70        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
71        self.http_client.post(url)
72    }
73
74    /// Creates a DeepSeek completion model with the given `model_name`.
75    pub fn completion_model(&self, model_name: &str) -> DeepSeekCompletionModel {
76        DeepSeekCompletionModel {
77            client: self.clone(),
78            model: model_name.to_string(),
79        }
80    }
81
82    /// Optionally add an agent() convenience:
83    pub fn agent(&self, model_name: &str) -> crate::agent::AgentBuilder<DeepSeekCompletionModel> {
84        crate::agent::AgentBuilder::new(self.completion_model(model_name))
85    }
86
87    /// Create an extractor builder with the given completion model.
88    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
89        &self,
90        model: &str,
91    ) -> ExtractorBuilder<T, DeepSeekCompletionModel> {
92        ExtractorBuilder::new(self.completion_model(model))
93    }
94}
95
/// Error payload returned by the DeepSeek API.
// NOTE(review): this expects a top-level `message` field; confirm against the
// actual DeepSeek error body (OpenAI-compatible APIs often nest it under `error`).
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// `#[serde(untagged)]` tries `Ok(T)` first and falls back to `Err` when the
/// body does not deserialize as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

108impl From<ApiErrorResponse> for CompletionError {
109    fn from(err: ApiErrorResponse) -> Self {
110        CompletionError::ProviderError(err.message)
111    }
112}
113
/// The response shape from the DeepSeek `/chat/completions` endpoint.
///
/// Only `choices` is modelled; other top-level fields (`id`, `model`,
/// `usage`, ...) are ignored during deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Generated choices; the conversion into Rig's response uses the first one.
    pub choices: Vec<Choice>,
    // NOTE(review): `usage` and other metadata are not captured yet.
}

/// One generated completion candidate inside a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Index of this choice as reported by the API.
    pub index: usize,
    /// The generated message.
    pub message: Message,
    /// Log-probability data kept as raw JSON; `None` when the API sends `null`.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped, e.g. `"stop"` or `"tool_calls"`.
    pub finish_reason: String,
}

/// A chat message in DeepSeek's wire format.
///
/// Serialized as an object tagged by `role`: `"system"`, `"user"`,
/// `"assistant"`, or `"tool"` (for [`Message::ToolResult`]).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System instructions, e.g. the request preamble.
    System {
        content: String,
        /// Optional participant name; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// A user turn.
    User {
        content: String,
        /// Optional participant name; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// A model turn: text plus any tool calls it requested.
    Assistant {
        content: String,
        /// Optional participant name; omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Requested tool calls. Defaults to empty when the field is missing,
        /// is read through `json_utils::null_or_vec` (presumably mapping `null`
        /// to an empty list), and is omitted from the JSON when empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// Output of a tool, linked to the originating call via `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}

161impl Message {
162    pub fn system(content: &str) -> Self {
163        Message::System {
164            content: content.to_owned(),
165            name: None,
166        }
167    }
168}
169
170impl From<message::ToolResult> for Message {
171    fn from(tool_result: message::ToolResult) -> Self {
172        let content = match tool_result.content.first() {
173            message::ToolResultContent::Text(text) => text.text,
174            message::ToolResultContent::Image(_) => String::from("[Image]"),
175        };
176
177        Message::ToolResult {
178            tool_call_id: tool_result.id,
179            content,
180        }
181    }
182}
183
184impl From<message::ToolCall> for ToolCall {
185    fn from(tool_call: message::ToolCall) -> Self {
186        Self {
187            id: tool_call.id,
188            // TODO: update index when we have it
189            index: 0,
190            r#type: ToolType::Function,
191            function: Function {
192                name: tool_call.function.name,
193                arguments: tool_call.function.arguments,
194            },
195        }
196    }
197}
198
199impl TryFrom<message::Message> for Vec<Message> {
200    type Error = message::MessageError;
201
202    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
203        match message {
204            message::Message::User { content } => {
205                // extract tool results
206                let mut messages = vec![];
207
208                let tool_results = content
209                    .clone()
210                    .into_iter()
211                    .filter_map(|content| match content {
212                        message::UserContent::ToolResult(tool_result) => {
213                            Some(Message::from(tool_result))
214                        }
215                        _ => None,
216                    })
217                    .collect::<Vec<_>>();
218
219                messages.extend(tool_results);
220
221                // extract text results
222                let text_messages = content
223                    .into_iter()
224                    .filter_map(|content| match content {
225                        message::UserContent::Text(text) => Some(Message::User {
226                            content: text.text,
227                            name: None,
228                        }),
229                        _ => None,
230                    })
231                    .collect::<Vec<_>>();
232                messages.extend(text_messages);
233
234                Ok(messages)
235            }
236            message::Message::Assistant { content } => {
237                let mut messages: Vec<Message> = vec![];
238
239                // extract tool calls
240                let tool_calls = content
241                    .clone()
242                    .into_iter()
243                    .filter_map(|content| match content {
244                        message::AssistantContent::ToolCall(tool_call) => {
245                            Some(ToolCall::from(tool_call))
246                        }
247                        _ => None,
248                    })
249                    .collect::<Vec<_>>();
250
251                // if we have tool calls, we add a new Assistant message with them
252                if !tool_calls.is_empty() {
253                    messages.push(Message::Assistant {
254                        content: "".to_string(),
255                        name: None,
256                        tool_calls,
257                    });
258                }
259
260                // extract text
261                let text_content = content
262                    .into_iter()
263                    .filter_map(|content| match content {
264                        message::AssistantContent::Text(text) => Some(Message::Assistant {
265                            content: text.text,
266                            name: None,
267                            tool_calls: vec![],
268                        }),
269                        _ => None,
270                    })
271                    .collect::<Vec<_>>();
272
273                messages.extend(text_content);
274
275                Ok(messages)
276            }
277        }
278    }
279}
280
/// A tool invocation requested by the model.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Call id; tool results refer back to it through `tool_call_id`.
    pub id: String,
    /// Index of this call as reported by the API.
    pub index: usize,
    /// Tool kind; defaults to [`ToolType::Function`] when absent.
    #[serde(default)]
    pub r#type: ToolType,
    /// The function to invoke.
    pub function: Function,
}

/// Name and arguments of the function a [`ToolCall`] invokes.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// Arguments as JSON. On the wire they travel as a JSON-encoded *string*
    /// (e.g. `"{\"x\":2,\"y\":5}"`), converted via `json_utils::stringified_json`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

/// Kind of tool; serialized in lowercase (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// A tool definition in DeepSeek's request format: a Rig tool definition
/// wrapped with a `type` discriminator (`"function"` when built via `From`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}

310impl From<crate::completion::ToolDefinition> for ToolDefinition {
311    fn from(tool: crate::completion::ToolDefinition) -> Self {
312        Self {
313            r#type: "function".into(),
314            function: tool,
315        }
316    }
317}
318
319impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
320    type Error = CompletionError;
321
322    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
323        let choice = response.choices.first().ok_or_else(|| {
324            CompletionError::ResponseError("Response contained no choices".to_owned())
325        })?;
326        let content = match &choice.message {
327            Message::Assistant {
328                content,
329                tool_calls,
330                ..
331            } => {
332                let mut content = if content.trim().is_empty() {
333                    vec![]
334                } else {
335                    vec![completion::AssistantContent::text(content)]
336                };
337
338                content.extend(
339                    tool_calls
340                        .iter()
341                        .map(|call| {
342                            completion::AssistantContent::tool_call(
343                                &call.id,
344                                &call.function.name,
345                                call.function.arguments.clone(),
346                            )
347                        })
348                        .collect::<Vec<_>>(),
349                );
350                Ok(content)
351            }
352            _ => Err(CompletionError::ResponseError(
353                "Response did not contain a valid message or tool call".into(),
354            )),
355        }?;
356
357        let choice = OneOrMany::many(content).map_err(|_| {
358            CompletionError::ResponseError(
359                "Response contained no message or tool call (empty)".to_owned(),
360            )
361        })?;
362
363        Ok(completion::CompletionResponse {
364            choice,
365            raw_response: response,
366        })
367    }
368}
369
/// A DeepSeek chat model bound to a [`Client`]; implements [`CompletionModel`]
/// and [`StreamingCompletionModel`].
#[derive(Clone)]
pub struct DeepSeekCompletionModel {
    /// Client used to send requests.
    pub client: Client,
    /// Model identifier sent as the request's `model` field, e.g. [`DEEPSEEK_CHAT`].
    pub model: String,
}

377impl DeepSeekCompletionModel {
378    fn create_completion_request(
379        &self,
380        completion_request: CompletionRequest,
381    ) -> Result<serde_json::Value, CompletionError> {
382        // Build up the order of messages (context, chat_history, prompt)
383        let mut partial_history = vec![];
384        if let Some(docs) = completion_request.normalized_documents() {
385            partial_history.push(docs);
386        }
387        partial_history.extend(completion_request.chat_history);
388
389        // Initialize full history with preamble (or empty if non-existent)
390        let mut full_history: Vec<Message> = completion_request
391            .preamble
392            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
393
394        // Convert and extend the rest of the history
395        full_history.extend(
396            partial_history
397                .into_iter()
398                .map(message::Message::try_into)
399                .collect::<Result<Vec<Vec<Message>>, _>>()?
400                .into_iter()
401                .flatten()
402                .collect::<Vec<_>>(),
403        );
404
405        let request = if completion_request.tools.is_empty() {
406            json!({
407                "model": self.model,
408                "messages": full_history,
409                "temperature": completion_request.temperature,
410            })
411        } else {
412            json!({
413                "model": self.model,
414                "messages": full_history,
415                "temperature": completion_request.temperature,
416                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
417                "tool_choice": "auto",
418            })
419        };
420
421        let request = if let Some(params) = completion_request.additional_params {
422            json_utils::merge(request, params)
423        } else {
424            request
425        };
426
427        Ok(request)
428    }
429}
430
431impl CompletionModel for DeepSeekCompletionModel {
432    type Response = CompletionResponse;
433
434    #[cfg_attr(feature = "worker", worker::send)]
435    async fn completion(
436        &self,
437        completion_request: CompletionRequest,
438    ) -> Result<
439        completion::CompletionResponse<CompletionResponse>,
440        crate::completion::CompletionError,
441    > {
442        let request = self.create_completion_request(completion_request)?;
443
444        let response = self
445            .client
446            .post("/chat/completions")
447            .json(&request)
448            .send()
449            .await?;
450
451        if response.status().is_success() {
452            let t = response.text().await?;
453            tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
454
455            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
456                ApiResponse::Ok(response) => response.try_into(),
457                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
458            }
459        } else {
460            Err(CompletionError::ProviderError(response.text().await?))
461        }
462    }
463}
464
465impl StreamingCompletionModel for DeepSeekCompletionModel {
466    async fn stream(
467        &self,
468        completion_request: CompletionRequest,
469    ) -> Result<StreamingResult, CompletionError> {
470        let mut request = self.create_completion_request(completion_request)?;
471
472        request = merge(request, json!({"stream": true}));
473
474        let builder = self.client.post("/v1/chat/completions").json(&request);
475        send_compatible_streaming_request(builder).await
476    }
477}
478
// ================================================================
// DeepSeek Completion API
// ================================================================

/// `deepseek-chat` completion model; pass to [`Client::completion_model`].
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// `deepseek-reasoner` completion model; pass to [`Client::completion_model`].
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";

// Tests
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of an assistant message, panicking for any other role.
    fn assistant_text(message: &Message) -> &str {
        if let Message::Assistant { content, .. } = message {
            content
        } else {
            panic!("Expected assistant message")
        }
    }

    /// Deserializes a [`CompletionResponse`], reporting the failing JSON path.
    fn parse_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        serde_path_to_error::deserialize(jd)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_eq!(assistant_text(&choices[0].message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{"choices":[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]}"#;

        let response = parse_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(
            assistant_text(&first.message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
        };
        let expected_choice = Choice {
            index: 0,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
            logprobs: None,
            finish_reason: "tool_calls".to_string(),
        };

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
        assert_eq!(choice, expected_choice);
    }
}