rig/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::deepseek;
6//!
7//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
8//!
9//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
10//! ```
11use crate::{
12    completion::{self, CompletionError, CompletionModel, CompletionRequest},
13    extractor::ExtractorBuilder,
14    json_utils, message, OneOrMany,
15};
16use reqwest::Client as HttpClient;
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use serde_json::json;
20
21// ================================================================
22// Main DeepSeek Client
23// ================================================================
/// Default base URL of the DeepSeek API, used by [`Client::new`] (and therefore by
/// [`Client::from_env`]). Use [`Client::from_url`] to target a different endpoint.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
25
/// HTTP client for the DeepSeek API.
///
/// Holds the base URL and a reqwest client preconfigured with the bearer-token
/// `Authorization` header. A clone of it is stored in every model created by
/// [`Client::completion_model`].
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to (e.g. `https://api.deepseek.com`).
    pub base_url: String,
    /// reqwest client carrying the default `Authorization` header.
    http_client: HttpClient,
}
31
32impl Client {
33    // Create a new DeepSeek client from an API key.
34    pub fn new(api_key: &str) -> Self {
35        Self::from_url(api_key, DEEPSEEK_API_BASE_URL)
36    }
37
38    // If you prefer the environment variable approach:
39    pub fn from_env() -> Self {
40        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
41        Self::new(&api_key)
42    }
43
44    // Handy for advanced usage, e.g. letting user override base_url or set timeouts:
45    pub fn from_url(api_key: &str, base_url: &str) -> Self {
46        // Possibly configure a custom HTTP client here if needed.
47        Self {
48            base_url: base_url.to_string(),
49            http_client: reqwest::Client::builder()
50                .default_headers({
51                    let mut headers = reqwest::header::HeaderMap::new();
52                    headers.insert(
53                        "Authorization",
54                        format!("Bearer {}", api_key)
55                            .parse()
56                            .expect("Bearer token should parse"),
57                    );
58                    headers
59                })
60                .build()
61                .expect("DeepSeek reqwest client should build"),
62        }
63    }
64
65    fn post(&self, path: &str) -> reqwest::RequestBuilder {
66        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
67        self.http_client.post(url)
68    }
69
70    /// Creates a DeepSeek completion model with the given `model_name`.
71    pub fn completion_model(&self, model_name: &str) -> DeepSeekCompletionModel {
72        DeepSeekCompletionModel {
73            client: self.clone(),
74            model: model_name.to_string(),
75        }
76    }
77
78    /// Optionally add an agent() convenience:
79    pub fn agent(&self, model_name: &str) -> crate::agent::AgentBuilder<DeepSeekCompletionModel> {
80        crate::agent::AgentBuilder::new(self.completion_model(model_name))
81    }
82
83    /// Create an extractor builder with the given completion model.
84    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
85        &self,
86        model: &str,
87    ) -> ExtractorBuilder<T, DeepSeekCompletionModel> {
88        ExtractorBuilder::new(self.completion_model(model))
89    }
90}
91
/// Error body with a top-level `message` string; the fallback variant of [`ApiResponse`].
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error text, surfaced as `CompletionError::ProviderError`.
    message: String,
}
96
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde tries the variants in declaration order: the body is
/// parsed as `Ok(T)` first and only falls back to `Err` if that fails. Keep `Ok` first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
103
104impl From<ApiErrorResponse> for CompletionError {
105    fn from(err: ApiErrorResponse) -> Self {
106        CompletionError::ProviderError(err.message)
107    }
108}
109
/// The response body of DeepSeek's chat-completions endpoint.
///
/// Only `choices` is modelled; other top-level fields in the JSON (such as `id`,
/// `model` or `usage`) are ignored during deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Candidate completions; the conversion to Rig's response type uses only the first.
    pub choices: Vec<Choice>,
    // NOTE(review): token usage is not captured; add a `usage` field if callers need it.
}
117
/// One candidate completion in a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Position of this choice in `choices`.
    pub index: usize,
    /// The generated message (expected to be an assistant message).
    pub message: Message,
    /// Log-probability data, kept as raw JSON (`null` when not requested).
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped, e.g. `"stop"` or `"tool_calls"`.
    pub finish_reason: String,
}
125
126#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
127#[serde(tag = "role", rename_all = "lowercase")]
128pub enum Message {
129    System {
130        content: String,
131        #[serde(skip_serializing_if = "Option::is_none")]
132        name: Option<String>,
133    },
134    User {
135        content: String,
136        #[serde(skip_serializing_if = "Option::is_none")]
137        name: Option<String>,
138    },
139    Assistant {
140        content: String,
141        #[serde(skip_serializing_if = "Option::is_none")]
142        name: Option<String>,
143        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
144        tool_calls: Vec<ToolCall>,
145    },
146    #[serde(rename = "Tool")]
147    ToolResult {
148        tool_call_id: String,
149        content: String,
150    },
151}
152
153impl Message {
154    pub fn system(content: &str) -> Self {
155        Message::System {
156            content: content.to_owned(),
157            name: None,
158        }
159    }
160}
161
162impl From<message::ToolResult> for Message {
163    fn from(tool_result: message::ToolResult) -> Self {
164        let content = match tool_result.content.first() {
165            message::ToolResultContent::Text(text) => text.text,
166            message::ToolResultContent::Image(_) => String::from("[Image]"),
167        };
168
169        Message::ToolResult {
170            tool_call_id: tool_result.id,
171            content,
172        }
173    }
174}
175
176impl From<message::ToolCall> for ToolCall {
177    fn from(tool_call: message::ToolCall) -> Self {
178        Self {
179            id: tool_call.id,
180            // TODO: update index when we have it
181            index: 0,
182            r#type: ToolType::Function,
183            function: Function {
184                name: tool_call.function.name,
185                arguments: tool_call.function.arguments,
186            },
187        }
188    }
189}
190
191impl TryFrom<message::Message> for Vec<Message> {
192    type Error = message::MessageError;
193
194    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
195        match message {
196            message::Message::User { content } => {
197                // extract tool results
198                let mut messages = vec![];
199
200                let tool_results = content
201                    .clone()
202                    .into_iter()
203                    .filter_map(|content| match content {
204                        message::UserContent::ToolResult(tool_result) => {
205                            Some(Message::from(tool_result))
206                        }
207                        _ => None,
208                    })
209                    .collect::<Vec<_>>();
210
211                messages.extend(tool_results);
212
213                // extract text results
214                let text_messages = content
215                    .into_iter()
216                    .filter_map(|content| match content {
217                        message::UserContent::Text(text) => Some(Message::User {
218                            content: text.text,
219                            name: None,
220                        }),
221                        _ => None,
222                    })
223                    .collect::<Vec<_>>();
224                messages.extend(text_messages);
225
226                Ok(messages)
227            }
228            message::Message::Assistant { content } => {
229                let mut messages: Vec<Message> = vec![];
230
231                // extract tool calls
232                let tool_calls = content
233                    .clone()
234                    .into_iter()
235                    .filter_map(|content| match content {
236                        message::AssistantContent::ToolCall(tool_call) => {
237                            Some(ToolCall::from(tool_call))
238                        }
239                        _ => None,
240                    })
241                    .collect::<Vec<_>>();
242
243                // if we have tool calls, we add a new Assistant message with them
244                if !tool_calls.is_empty() {
245                    messages.push(Message::Assistant {
246                        content: "".to_string(),
247                        name: None,
248                        tool_calls,
249                    });
250                }
251
252                // extract text
253                let text_content = content
254                    .into_iter()
255                    .filter_map(|content| match content {
256                        message::AssistantContent::Text(text) => Some(Message::Assistant {
257                            content: text.text,
258                            name: None,
259                            tool_calls: vec![],
260                        }),
261                        _ => None,
262                    })
263                    .collect::<Vec<_>>();
264
265                messages.extend(text_content);
266
267                Ok(messages)
268            }
269        }
270    }
271}
272
/// A tool invocation as it appears in an assistant message's `tool_calls` array.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Provider-assigned call id; a later `tool` message echoes it as `tool_call_id`.
    pub id: String,
    /// Index of the call as reported in the JSON.
    pub index: usize,
    /// Call type; defaults to [`ToolType::Function`] when absent from the JSON.
    #[serde(default)]
    pub r#type: ToolType,
    /// Name and arguments of the function to invoke.
    pub function: Function,
}
281
/// The function part of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the tool to invoke.
    pub name: String,
    /// Call arguments. On the wire they are a JSON *string* holding encoded JSON
    /// (e.g. `"{\"x\":2,\"y\":5}"`); `json_utils::stringified_json` converts between that
    /// string and a structured value (presumably re-encoding it on serialization).
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
288
/// Kind of tool call; serialized in lowercase. `function` is the only kind and the default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
295
/// An entry of the request's `tools` array: `{"type": ..., "function": <definition>}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; set to `"function"` when converted from a Rig tool definition.
    pub r#type: String,
    /// The Rig tool definition (name, description, parameter schema).
    pub function: completion::ToolDefinition,
}
301
302impl From<crate::completion::ToolDefinition> for ToolDefinition {
303    fn from(tool: crate::completion::ToolDefinition) -> Self {
304        Self {
305            r#type: "function".into(),
306            function: tool,
307        }
308    }
309}
310
311impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
312    type Error = CompletionError;
313
314    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
315        let choice = response.choices.first().ok_or_else(|| {
316            CompletionError::ResponseError("Response contained no choices".to_owned())
317        })?;
318        let content = match &choice.message {
319            Message::Assistant {
320                content,
321                tool_calls,
322                ..
323            } => {
324                let mut content = if content.trim().is_empty() {
325                    vec![]
326                } else {
327                    vec![completion::AssistantContent::text(content)]
328                };
329
330                content.extend(
331                    tool_calls
332                        .iter()
333                        .map(|call| {
334                            completion::AssistantContent::tool_call(
335                                &call.function.name,
336                                &call.function.name,
337                                call.function.arguments.clone(),
338                            )
339                        })
340                        .collect::<Vec<_>>(),
341                );
342                Ok(content)
343            }
344            _ => Err(CompletionError::ResponseError(
345                "Response did not contain a valid message or tool call".into(),
346            )),
347        }?;
348
349        let choice = OneOrMany::many(content).map_err(|_| {
350            CompletionError::ResponseError(
351                "Response contained no message or tool call (empty)".to_owned(),
352            )
353        })?;
354
355        Ok(completion::CompletionResponse {
356            choice,
357            raw_response: response,
358        })
359    }
360}
361
/// A DeepSeek chat model bound to a [`Client`]; implements Rig's [`CompletionModel`] trait.
#[derive(Clone)]
pub struct DeepSeekCompletionModel {
    /// Client used to send requests (base URL and auth header).
    pub client: Client,
    /// Model name sent as the request's `model` field, e.g. [`DEEPSEEK_CHAT`].
    pub model: String,
}
368
369impl CompletionModel for DeepSeekCompletionModel {
370    type Response = CompletionResponse;
371
372    #[cfg_attr(feature = "worker", worker::send)]
373    async fn completion(
374        &self,
375        completion_request: CompletionRequest,
376    ) -> Result<
377        completion::CompletionResponse<CompletionResponse>,
378        crate::completion::CompletionError,
379    > {
380        // Add preamble to chat history (if available)
381        let mut full_history: Vec<Message> = match &completion_request.preamble {
382            Some(preamble) => vec![Message::system(preamble)],
383            None => vec![],
384        };
385
386        // Convert prompt to user message
387        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
388
389        // Convert existing chat history
390        let chat_history: Vec<Message> = completion_request
391            .chat_history
392            .into_iter()
393            .map(|message| message.try_into())
394            .collect::<Result<Vec<Vec<Message>>, _>>()?
395            .into_iter()
396            .flatten()
397            .collect();
398
399        // Combine all messages into a single history
400        full_history.extend(chat_history);
401        full_history.extend(prompt);
402
403        let request = if completion_request.tools.is_empty() {
404            json!({
405                "model": self.model,
406                "messages": full_history,
407                "temperature": completion_request.temperature,
408            })
409        } else {
410            json!({
411                "model": self.model,
412                "messages": full_history,
413                "temperature": completion_request.temperature,
414                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
415                "tool_choice": "auto",
416            })
417        };
418
419        let response = self
420            .client
421            .post("/chat/completions")
422            .json(
423                &if let Some(params) = completion_request.additional_params {
424                    json_utils::merge(request, params)
425                } else {
426                    request
427                },
428            )
429            .send()
430            .await?;
431
432        if response.status().is_success() {
433            let t = response.text().await?;
434            tracing::debug!(target: "rig", "OpenAI completion error: {}", t);
435
436            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
437                ApiResponse::Ok(response) => response.try_into(),
438                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
439            }
440        } else {
441            Err(CompletionError::ProviderError(response.text().await?))
442        }
443    }
444}
445
446// ================================================================
447// DeepSeek Completion API
448// ================================================================
449
/// Model name of DeepSeek's general chat model (`deepseek-chat`);
/// pass it to [`Client::completion_model`].
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name of DeepSeek's reasoning model (`deepseek-reasoner`);
/// pass it to [`Client::completion_model`].
// NOTE(review): `Message::Assistant` only captures `content`/`tool_calls`; any extra
// reasoning output this model returns is dropped during deserialization — confirm if needed.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
454
455// Tests
// Deserialization tests for the DeepSeek wire types, using response fixtures
// shaped like real API payloads.
#[cfg(test)]
mod tests {

    use super::*;

    // A bare `choices` array deserializes into `Vec<Choice>` with an assistant message.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    // A minimal `{"choices": [...]}` body deserializes into `CompletionResponse`.
    // serde_path_to_error reports the JSON path of any failure.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{"choices":[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]}"#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A full response body: extra top-level fields (id, usage, system_fingerprint, ...)
    // are ignored, and non-ASCII content survives intact.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A tool-call choice: the stringified `arguments` are decoded into structured JSON
    // and every `ToolCall` field is populated from the payload.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}