// rig/providers/deepseek.rs
//! DeepSeek API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::deepseek;
//!
//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```

use bytes::Bytes;
use http::Request;
use serde::{Deserialize, Serialize};
use tracing::{Instrument, Level, enabled, info_span};

use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, ModelLister, Nothing, Provider,
    ProviderBuilder, ProviderClient,
};
use crate::completion::GetTokenUsage;
use crate::http_client::{self, HttpClientExt};
use crate::message::{Document, DocumentSourceKind};
use crate::model::{Model, ModelList, ModelListingError};
use crate::providers::internal::openai_chat_completions_compatible::{
    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
};
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    json_utils, message,
    wasm_compat::{WasmCompatSend, WasmCompatSync},
};

use super::openai::StreamingToolCall;
// ================================================================
// Main DeepSeek Client
// ================================================================

/// Base URL of the hosted DeepSeek API.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";

/// Zero-sized provider marker wiring DeepSeek into the generic client stack.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExt;
/// Zero-sized builder marker for [`DeepSeekExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExtBuilder;

/// DeepSeek authenticates with a standard `Authorization: Bearer <key>` header.
type DeepSeekApiKey = BearerAuth;
48
impl Provider for DeepSeekExt {
    type Builder = DeepSeekExtBuilder;
    // The balance endpoint is queried to verify that an API key is valid.
    const VERIFY_PATH: &'static str = "/user/balance";
}
53
/// DeepSeek exposes chat completions and model listing only; it has no
/// embeddings, transcription, image, or audio endpoints.
impl<H> Capabilities<H> for DeepSeekExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Capable<DeepSeekModelLister<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

impl DebugExt for DeepSeekExt {}
66
impl ProviderBuilder for DeepSeekExtBuilder {
    type Extension<H>
        = DeepSeekExt
    where
        H: HttpClientExt;
    type ApiKey = DeepSeekApiKey;

    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;

    /// The extension carries no state, so building it never fails.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(DeepSeekExt)
    }
}

/// DeepSeek client, generic over the underlying HTTP transport.
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
/// Builder for a DeepSeek [`Client`].
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
88
89impl ProviderClient for Client {
90    type Input = DeepSeekApiKey;
91    type Error = crate::client::ProviderClientError;
92
93    // If you prefer the environment variable approach:
94    fn from_env() -> Result<Self, Self::Error> {
95        let api_key = crate::client::required_env_var("DEEPSEEK_API_KEY")?;
96        let mut client_builder = Self::builder();
97        client_builder.headers_mut().insert(
98            http::header::CONTENT_TYPE,
99            http::HeaderValue::from_static("application/json"),
100        );
101        let client_builder = client_builder.api_key(&api_key);
102        client_builder.build().map_err(Into::into)
103    }
104
105    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
106        Self::new(input).map_err(Into::into)
107    }
108}
109
/// Error payload returned by the DeepSeek API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or an API error body; `untagged`
/// lets serde pick whichever variant matches the JSON shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

impl From<ApiErrorResponse> for CompletionError {
    fn from(err: ApiErrorResponse) -> Self {
        CompletionError::ProviderError(err.message)
    }
}
127
/// The response shape from the DeepSeek API
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    // We'll match the JSON:
    pub choices: Vec<Choice>,
    pub usage: Usage,
    // you may want other fields
}

/// Token accounting attached to every completion. DeepSeek reports
/// prompt-cache hit/miss counts in addition to the usual OpenAI-style fields.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    pub completion_tokens: u32,
    pub prompt_tokens: u32,
    pub prompt_cache_hit_tokens: u32,
    pub prompt_cache_miss_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
149
150impl GetTokenUsage for Usage {
151    fn token_usage(&self) -> Option<crate::completion::Usage> {
152        Some(crate::providers::internal::completion_usage(
153            self.prompt_tokens as u64,
154            self.completion_tokens as u64,
155            self.total_tokens as u64,
156            self.prompt_tokens_details
157                .as_ref()
158                .and_then(|details| details.cached_tokens)
159                .map(u64::from)
160                .unwrap_or(0),
161        ))
162    }
163}
164
/// Breakdown of completion tokens; `reasoning_tokens` is only populated by
/// reasoning-capable models.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}

/// Breakdown of prompt tokens, including how many were served from cache.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}

/// A single completion candidate within a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
184
/// Chat message in DeepSeek's wire format, tagged by the JSON `role` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "system"` — instructions steering the model.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "user"` — end-user input.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "assistant"` — model output, possibly carrying tool calls.
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // The API may send `tool_calls: null`; `null_or_vec` maps both null
        // and absent to an empty vec.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        /// only exists on `deepseek-reasoner` model at time of addition
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    /// `role: "tool"` — the result of a previously requested tool call.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
218
219impl Message {
220    pub fn system(content: &str) -> Self {
221        Message::System {
222            content: content.to_owned(),
223            name: None,
224        }
225    }
226}
227
228impl From<message::ToolResult> for Message {
229    fn from(tool_result: message::ToolResult) -> Self {
230        let content = match tool_result.content.first() {
231            message::ToolResultContent::Text(text) => text.text,
232            message::ToolResultContent::Image(_) => String::from("[Image]"),
233        };
234
235        Message::ToolResult {
236            tool_call_id: tool_result.id,
237            content,
238        }
239    }
240}
241
242impl From<message::ToolCall> for ToolCall {
243    fn from(tool_call: message::ToolCall) -> Self {
244        Self {
245            id: tool_call.id,
246            // TODO: update index when we have it
247            index: 0,
248            r#type: ToolType::Function,
249            function: Function {
250                name: tool_call.function.name,
251                arguments: tool_call.function.arguments,
252            },
253        }
254    }
255}
256
257impl TryFrom<message::Message> for Vec<Message> {
258    type Error = message::MessageError;
259
260    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
261        match message {
262            message::Message::System { content } => Ok(vec![Message::System {
263                content,
264                name: None,
265            }]),
266            message::Message::User { content } => {
267                // extract tool results
268                let mut messages = vec![];
269
270                let tool_results = content
271                    .clone()
272                    .into_iter()
273                    .filter_map(|content| match content {
274                        message::UserContent::ToolResult(tool_result) => {
275                            Some(Message::from(tool_result))
276                        }
277                        _ => None,
278                    })
279                    .collect::<Vec<_>>();
280
281                messages.extend(tool_results);
282
283                let text_content: String = content
284                    .into_iter()
285                    .filter_map(|content| match content {
286                        message::UserContent::Text(text) => Some(text.text),
287                        message::UserContent::Document(Document {
288                            data:
289                                DocumentSourceKind::Base64(content)
290                                | DocumentSourceKind::String(content),
291                            ..
292                        }) => Some(content),
293                        _ => None,
294                    })
295                    .collect::<Vec<_>>()
296                    .join("\n");
297
298                if !text_content.is_empty() {
299                    messages.push(Message::User {
300                        content: text_content,
301                        name: None,
302                    });
303                }
304
305                Ok(messages)
306            }
307            message::Message::Assistant { content, .. } => {
308                let mut text_content = String::new();
309                let mut reasoning_content = String::new();
310                let mut tool_calls = Vec::new();
311
312                for item in content.iter() {
313                    match item {
314                        message::AssistantContent::Text(text) => {
315                            text_content.push_str(text.text());
316                        }
317                        message::AssistantContent::Reasoning(reasoning) => {
318                            reasoning_content.push_str(&reasoning.display_text());
319                        }
320                        message::AssistantContent::ToolCall(tool_call) => {
321                            tool_calls.push(ToolCall::from(tool_call.clone()));
322                        }
323                        _ => {}
324                    }
325                }
326
327                let reasoning = if reasoning_content.is_empty() {
328                    None
329                } else {
330                    Some(reasoning_content)
331                };
332
333                Ok(vec![Message::Assistant {
334                    content: text_content,
335                    name: None,
336                    tool_calls,
337                    reasoning_content: reasoning,
338                }])
339            }
340        }
341    }
342}
343
/// A tool invocation in DeepSeek's wire format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    pub index: usize,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

/// Function name plus arguments; arguments travel on the wire as a JSON
/// string but are exposed here as a parsed `serde_json::Value`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

/// Tool kind; `function` is currently the only variant the API defines.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
366
/// Wrapper pairing a tool `type` tag with Rig's tool definition, as required
/// by the API's `tools` array.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}

impl From<crate::completion::ToolDefinition> for ToolDefinition {
    fn from(tool: crate::completion::ToolDefinition) -> Self {
        Self {
            r#type: "function".into(),
            function: tool,
        }
    }
}
381
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Flattens the first choice of a raw DeepSeek response into Rig's
    /// generic completion response (text, tool calls, reasoning + usage).
    ///
    /// # Errors
    /// Fails when the response has no choices, the first choice is not an
    /// assistant message, or the message carries no content at all.
    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        // Only the first choice is considered.
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;
        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                reasoning_content,
                ..
            } => {
                // Skip whitespace-only text so an empty string doesn't
                // produce an empty text item.
                let mut content = if content.trim().is_empty() {
                    vec![]
                } else {
                    vec![completion::AssistantContent::text(content)]
                };

                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );

                if let Some(reasoning_content) = reasoning_content {
                    content.push(completion::AssistantContent::reasoning(reasoning_content));
                }

                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        // `OneOrMany::many` rejects an empty vec, which maps to an error here.
        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        let usage = completion::Usage {
            input_tokens: response.usage.prompt_tokens as u64,
            output_tokens: response.usage.completion_tokens as u64,
            total_tokens: response.usage.total_tokens as u64,
            cached_input_tokens: response
                .usage
                .prompt_tokens_details
                .as_ref()
                .and_then(|d| d.cached_tokens)
                .map(|c| c as u64)
                .unwrap_or(0),
            cache_creation_input_tokens: 0,
        };

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
            message_id: None,
        })
    }
}
454
/// Serialized body for `POST /chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct DeepseekCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    // Reuses OpenRouter's OpenAI-compatible tool-choice representation.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    // Arbitrary extra fields (e.g. `stream`) flattened into the top-level
    // JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
468
469impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
470    type Error = CompletionError;
471
472    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
473        if req.output_schema.is_some() {
474            tracing::warn!("Structured outputs currently not supported for DeepSeek");
475        }
476        let model = req.model.clone().unwrap_or_else(|| model.to_string());
477        let mut full_history: Vec<Message> = match &req.preamble {
478            Some(preamble) => vec![Message::system(preamble)],
479            None => vec![],
480        };
481
482        if let Some(docs) = req.normalized_documents() {
483            let docs: Vec<Message> = docs.try_into()?;
484            full_history.extend(docs);
485        }
486
487        let chat_history: Vec<Message> = req
488            .chat_history
489            .clone()
490            .into_iter()
491            .map(|message| message.try_into())
492            .collect::<Result<Vec<Vec<Message>>, _>>()?
493            .into_iter()
494            .flatten()
495            .collect();
496
497        full_history.extend(chat_history);
498
499        let tool_choice = req
500            .tool_choice
501            .clone()
502            .map(crate::providers::openrouter::ToolChoice::try_from)
503            .transpose()?;
504
505        Ok(Self {
506            model: model.to_string(),
507            messages: full_history,
508            temperature: req.temperature,
509            tools: req
510                .tools
511                .clone()
512                .into_iter()
513                .map(ToolDefinition::from)
514                .collect::<Vec<_>>(),
515            tool_choice,
516            additional_params: req.additional_params,
517        })
518    }
519}
520
/// The struct implementing the `CompletionModel` trait
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    // Shared DeepSeek client used to issue requests.
    pub client: Client<T>,
    // Default model identifier, e.g. `deepseek-chat`.
    pub model: String,
}
527
528impl<T> completion::CompletionModel for CompletionModel<T>
529where
530    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
531{
532    type Response = CompletionResponse;
533    type StreamingResponse = StreamingCompletionResponse;
534
535    type Client = Client<T>;
536
537    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
538        Self {
539            client: client.clone(),
540            model: model.into().to_string(),
541        }
542    }
543
544    async fn completion(
545        &self,
546        completion_request: CompletionRequest,
547    ) -> Result<
548        completion::CompletionResponse<CompletionResponse>,
549        crate::completion::CompletionError,
550    > {
551        let span = if tracing::Span::current().is_disabled() {
552            info_span!(
553                target: "rig::completions",
554                "chat",
555                gen_ai.operation.name = "chat",
556                gen_ai.provider.name = "deepseek",
557                gen_ai.request.model = self.model,
558                gen_ai.system_instructions = tracing::field::Empty,
559                gen_ai.response.id = tracing::field::Empty,
560                gen_ai.response.model = tracing::field::Empty,
561                gen_ai.usage.output_tokens = tracing::field::Empty,
562                gen_ai.usage.input_tokens = tracing::field::Empty,
563                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
564            )
565        } else {
566            tracing::Span::current()
567        };
568
569        span.record("gen_ai.system_instructions", &completion_request.preamble);
570
571        let request =
572            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
573
574        if enabled!(Level::TRACE) {
575            tracing::trace!(target: "rig::completions",
576                "DeepSeek completion request: {}",
577                serde_json::to_string_pretty(&request)?
578            );
579        }
580
581        let body = serde_json::to_vec(&request)?;
582        let req = self
583            .client
584            .post("/chat/completions")?
585            .body(body)
586            .map_err(|e| CompletionError::HttpError(e.into()))?;
587
588        async move {
589            let response = self.client.send::<_, Bytes>(req).await?;
590            let status = response.status();
591            let response_body = response.into_body().into_future().await?.to_vec();
592
593            if status.is_success() {
594                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
595                    ApiResponse::Ok(response) => {
596                        let span = tracing::Span::current();
597                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
598                        span.record(
599                            "gen_ai.usage.output_tokens",
600                            response.usage.completion_tokens,
601                        );
602                        span.record(
603                            "gen_ai.usage.cache_read.input_tokens",
604                            response
605                                .usage
606                                .prompt_tokens_details
607                                .as_ref()
608                                .and_then(|d| d.cached_tokens)
609                                .unwrap_or(0),
610                        );
611                        if enabled!(Level::TRACE) {
612                            tracing::trace!(target: "rig::completions",
613                                "DeepSeek completion response: {}",
614                                serde_json::to_string_pretty(&response)?
615                            );
616                        }
617                        response.try_into()
618                    }
619                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
620                }
621            } else {
622                Err(CompletionError::ProviderError(
623                    String::from_utf8_lossy(&response_body).to_string(),
624                ))
625            }
626        }
627        .instrument(span)
628        .await
629    }
630
631    async fn stream(
632        &self,
633        completion_request: CompletionRequest,
634    ) -> Result<
635        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
636        CompletionError,
637    > {
638        let preamble = completion_request.preamble.clone();
639        let mut request =
640            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
641
642        let params = json_utils::merge(
643            request.additional_params.unwrap_or(serde_json::json!({})),
644            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
645        );
646
647        request.additional_params = Some(params);
648
649        if enabled!(Level::TRACE) {
650            tracing::trace!(target: "rig::completions",
651                "DeepSeek streaming completion request: {}",
652                serde_json::to_string_pretty(&request)?
653            );
654        }
655
656        let body = serde_json::to_vec(&request)?;
657
658        let req = self
659            .client
660            .post("/chat/completions")?
661            .body(body)
662            .map_err(|e| CompletionError::HttpError(e.into()))?;
663
664        let span = if tracing::Span::current().is_disabled() {
665            info_span!(
666                target: "rig::completions",
667                "chat_streaming",
668                gen_ai.operation.name = "chat_streaming",
669                gen_ai.provider.name = "deepseek",
670                gen_ai.request.model = self.model,
671                gen_ai.system_instructions = preamble,
672                gen_ai.response.id = tracing::field::Empty,
673                gen_ai.response.model = tracing::field::Empty,
674                gen_ai.usage.output_tokens = tracing::field::Empty,
675                gen_ai.usage.input_tokens = tracing::field::Empty,
676                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
677            )
678        } else {
679            tracing::Span::current()
680        };
681
682        tracing::Instrument::instrument(
683            send_compatible_streaming_request(self.client.clone(), req),
684            span,
685        )
686        .await
687    }
688}
689
/// Per-chunk delta payload in a streaming response.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    // `tool_calls` may be null or absent; normalize both to an empty vec.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    /// Reasoning text emitted by reasoning-capable models.
    reasoning_content: Option<String>,
}

/// One choice within a streaming chunk; only the delta is consumed.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// A single SSE chunk from `/chat/completions` with `stream: true`.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<StreamingChoice>,
    // Present on the final chunk when `include_usage` is requested.
    usage: Option<Usage>,
}

/// Final aggregated result of a completed stream (token usage only).
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        self.usage.token_usage()
    }
}
722
/// Marker implementing the shared OpenAI-compatible SSE stream profile for
/// DeepSeek.
#[derive(Clone, Copy)]
struct DeepSeekCompatibleProfile;

impl CompatibleStreamProfile for DeepSeekCompatibleProfile {
    type Usage = Usage;
    type Detail = ();
    type FinalResponse = StreamingCompletionResponse;

    /// Parses one SSE `data:` payload into the normalized chunk shape.
    ///
    /// Unparseable payloads are logged at DEBUG and skipped (`Ok(None)`)
    /// rather than aborting the stream.
    fn normalize_chunk(
        &self,
        data: &str,
    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
            Ok(data) => data,
            Err(error) => {
                tracing::debug!(
                    "Couldn't parse SSE payload as StreamingCompletionChunk: {:?}",
                    error
                );
                return Ok(None);
            }
        };

        Ok(Some(
            openai_chat_completions_compatible::normalize_first_choice_chunk(
                data.id,
                data.model,
                data.usage,
                &data.choices,
                |choice| CompatibleChoiceData {
                    // No finish reason is mapped from DeepSeek chunks here.
                    finish_reason: CompatibleFinishReason::Other,
                    text: choice.delta.content.clone(),
                    reasoning: choice.delta.reasoning_content.clone(),
                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
                        &choice.delta.tool_calls,
                    ),
                    details: Vec::new(),
                },
            ),
        ))
    }

    /// Wraps the accumulated usage into the stream's final response.
    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
        StreamingCompletionResponse { usage }
    }

    // NOTE(review): these flags tune the shared streaming machinery's
    // tool-call accumulation; exact semantics are defined by
    // `CompatibleStreamProfile` — confirm against that trait's docs.
    fn uses_distinct_tool_call_eviction(&self) -> bool {
        true
    }

    fn emits_complete_single_chunk_tool_calls(&self) -> bool {
        true
    }
}
777
/// Drives a streaming request through the shared OpenAI-compatible SSE
/// machinery using the DeepSeek profile.
///
/// # Errors
/// Propagates any [`CompletionError`] raised by the underlying machinery.
pub async fn send_compatible_streaming_request<T>(
    http_client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    openai_chat_completions_compatible::send_compatible_streaming_request(
        http_client,
        req,
        DeepSeekCompatibleProfile,
    )
    .await
}
795
/// Response shape of `GET /models`.
#[derive(Debug, Deserialize)]
struct ListModelsResponse {
    data: Vec<ListModelEntry>,
}

/// A single entry in the model listing.
#[derive(Debug, Deserialize)]
struct ListModelEntry {
    id: String,
    owned_by: String,
}

impl From<ListModelEntry> for Model {
    fn from(value: ListModelEntry) -> Self {
        let mut model = Model::from_id(value.id);
        model.owned_by = Some(value.owned_by);
        model
    }
}
814
/// [`ModelLister`] implementation for the DeepSeek API (`GET /models`).
#[derive(Clone)]
pub struct DeepSeekModelLister<H = reqwest::Client> {
    // Client used to issue the listing request.
    client: Client<H>,
}
820
impl<H> ModelLister<H> for DeepSeekModelLister<H>
where
    H: HttpClientExt + WasmCompatSend + WasmCompatSync + 'static,
{
    type Client = Client<H>;

    fn new(client: Self::Client) -> Self {
        Self { client }
    }

    /// Fetches `GET /models` and converts the entries into a [`ModelList`].
    ///
    /// # Errors
    /// Returns [`ModelListingError`] on transport failures, non-success HTTP
    /// statuses, or an unparseable body — each annotated with provider name
    /// and request path for context.
    async fn list_all(&self) -> Result<ModelList, ModelListingError> {
        let path = "/models";
        let req = self.client.get(path)?.body(http_client::NoBody)?;
        let response = self
            .client
            .send::<_, Vec<u8>>(req)
            .await
            .map_err(|error| match error {
                // The transport may surface an HTTP error status itself;
                // re-wrap it with provider/path context.
                http_client::Error::InvalidStatusCodeWithMessage(status, message) => {
                    ModelListingError::api_error_with_context(
                        "DeepSeek",
                        path,
                        status.as_u16(),
                        message.as_bytes(),
                    )
                }
                other => ModelListingError::from(other),
            })?;

        if !response.status().is_success() {
            let status_code = response.status().as_u16();
            let body = response.into_body().await?;
            return Err(ModelListingError::api_error_with_context(
                "DeepSeek",
                path,
                status_code,
                &body,
            ));
        }

        let body = response.into_body().await?;
        let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
            ModelListingError::parse_error_with_context("DeepSeek", path, &error, &body)
        })?;

        let models = api_resp.data.into_iter().map(Model::from).collect();

        Ok(ModelList::new(models))
    }
}
871
// ================================================================
// DeepSeek Completion API
// ================================================================
#[deprecated(
    note = "The model names `deepseek-chat` and `deepseek-reasoner` will be deprecated on 2026/07/24. \
    For compatibility, they correspond to the non-thinking mode and thinking mode of `deepseek-v4-flash`, \
    respectively."
)]
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
#[deprecated(
    note = "The model names `deepseek-chat` and `deepseek-reasoner` will be deprecated on 2026/07/24. \
    For compatibility, they correspond to the non-thinking mode and thinking mode of `deepseek-v4-flash`, \
    respectively."
)]
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
/// Current, non-deprecated model identifiers.
pub const DEEPSEEK_V4_FLASH: &str = "deepseek-v4-flash";
pub const DEEPSEEK_V4_PRO: &str = "deepseek-v4-pro";
889
890// Tests
891#[cfg(test)]
892mod tests {
893    use super::*;
894    use crate::client::ModelListingClient;
895    use crate::http_client::{LazyBody, MultipartForm, Request as HttpRequest, Response};
896    use bytes::Bytes;
897    use std::future::{self, Future};
898    use std::sync::{Arc, Mutex};
899
900    #[test]
901    fn test_deserialize_vec_choice() {
902        let data = r#"[{
903            "finish_reason": "stop",
904            "index": 0,
905            "logprobs": null,
906            "message":{"role":"assistant","content":"Hello, world!"}
907            }]"#;
908
909        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
910        assert_eq!(choices.len(), 1);
911        match &choices.first().unwrap().message {
912            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
913            _ => panic!("Expected assistant message"),
914        }
915    }
916
917    #[test]
918    fn test_deserialize_deepseek_response() {
919        let data = r#"{
920            "choices":[{
921                "finish_reason": "stop",
922                "index": 0,
923                "logprobs": null,
924                "message":{"role":"assistant","content":"Hello, world!"}
925            }],
926            "usage": {
927                "completion_tokens": 0,
928                "prompt_tokens": 0,
929                "prompt_cache_hit_tokens": 0,
930                "prompt_cache_miss_tokens": 0,
931                "total_tokens": 0
932            }
933        }"#;
934
935        let jd = &mut serde_json::Deserializer::from_str(data);
936        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
937        match result {
938            Ok(response) => match &response.choices.first().unwrap().message {
939                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
940                _ => panic!("Expected assistant message"),
941            },
942            Err(err) => {
943                panic!("Deserialization error at {}: {}", err.path(), err);
944            }
945        }
946    }
947
948    #[test]
949    fn test_deserialize_example_response() {
950        let data = r#"
951        {
952            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
953            "object": "chat.completion",
954            "created": 0,
955            "model": "deepseek-chat",
956            "choices": [
957                {
958                    "index": 0,
959                    "message": {
960                        "role": "assistant",
961                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
962                    },
963                    "logprobs": null,
964                    "finish_reason": "stop"
965                }
966            ],
967            "usage": {
968                "prompt_tokens": 13,
969                "completion_tokens": 32,
970                "total_tokens": 45,
971                "prompt_tokens_details": {
972                    "cached_tokens": 0
973                },
974                "prompt_cache_hit_tokens": 0,
975                "prompt_cache_miss_tokens": 13
976            },
977            "system_fingerprint": "fp_4b6881f2c5"
978        }
979        "#;
980        let jd = &mut serde_json::Deserializer::from_str(data);
981        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
982
983        match result {
984            Ok(response) => match &response.choices.first().unwrap().message {
985                Message::Assistant { content, .. } => assert_eq!(
986                    content,
987                    "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
988                ),
989                _ => panic!("Expected assistant message"),
990            },
991            Err(err) => {
992                panic!("Deserialization error at {}: {}", err.path(), err);
993            }
994        }
995    }
996
997    #[test]
998    fn test_serialize_deserialize_tool_call_message() {
999        let tool_call_choice_json = r#"
1000            {
1001              "finish_reason": "tool_calls",
1002              "index": 0,
1003              "logprobs": null,
1004              "message": {
1005                "content": "",
1006                "role": "assistant",
1007                "tool_calls": [
1008                  {
1009                    "function": {
1010                      "arguments": "{\"x\":2,\"y\":5}",
1011                      "name": "subtract"
1012                    },
1013                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
1014                    "index": 0,
1015                    "type": "function"
1016                  }
1017                ]
1018              }
1019            }
1020        "#;
1021
1022        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
1023
1024        let expected_choice: Choice = Choice {
1025            finish_reason: "tool_calls".to_string(),
1026            index: 0,
1027            logprobs: None,
1028            message: Message::Assistant {
1029                content: "".to_string(),
1030                name: None,
1031                tool_calls: vec![ToolCall {
1032                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
1033                    function: Function {
1034                        name: "subtract".to_string(),
1035                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
1036                    },
1037                    index: 0,
1038                    r#type: ToolType::Function,
1039                }],
1040                reasoning_content: None,
1041            },
1042        };
1043
1044        assert_eq!(choice, expected_choice);
1045    }
1046    #[test]
1047    fn test_user_message_multiple_text_items_merged() {
1048        use crate::completion::message::{Message as RigMessage, UserContent};
1049
1050        let rig_msg = RigMessage::User {
1051            content: OneOrMany::many(vec![
1052                UserContent::text("first part"),
1053                UserContent::text("second part"),
1054            ])
1055            .expect("content should not be empty"),
1056        };
1057
1058        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1059
1060        let user_messages: Vec<&Message> = messages
1061            .iter()
1062            .filter(|m| matches!(m, Message::User { .. }))
1063            .collect();
1064
1065        assert_eq!(
1066            user_messages.len(),
1067            1,
1068            "multiple text items should produce a single user message"
1069        );
1070        match &user_messages[0] {
1071            Message::User { content, .. } => {
1072                assert_eq!(content, "first part\nsecond part");
1073            }
1074            _ => unreachable!(),
1075        }
1076    }
1077
1078    #[test]
1079    fn test_assistant_message_with_reasoning_and_tool_calls() {
1080        use crate::completion::message::{AssistantContent, Message as RigMessage};
1081
1082        let rig_msg = RigMessage::Assistant {
1083            id: None,
1084            content: OneOrMany::many(vec![
1085                AssistantContent::reasoning("thinking about the problem"),
1086                AssistantContent::text("I'll call the tool"),
1087                AssistantContent::tool_call(
1088                    "call_1",
1089                    "subtract",
1090                    serde_json::json!({"x": 2, "y": 5}),
1091                ),
1092            ])
1093            .expect("content should not be empty"),
1094        };
1095
1096        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1097
1098        assert_eq!(messages.len(), 1, "should produce exactly one message");
1099        match &messages[0] {
1100            Message::Assistant {
1101                content,
1102                tool_calls,
1103                reasoning_content,
1104                ..
1105            } => {
1106                assert_eq!(content, "I'll call the tool");
1107                assert_eq!(
1108                    reasoning_content.as_deref(),
1109                    Some("thinking about the problem")
1110                );
1111                assert_eq!(tool_calls.len(), 1);
1112                assert_eq!(tool_calls[0].function.name, "subtract");
1113            }
1114            _ => panic!("Expected assistant message"),
1115        }
1116    }
1117
1118    #[test]
1119    fn test_assistant_message_without_reasoning() {
1120        use crate::completion::message::{AssistantContent, Message as RigMessage};
1121
1122        let rig_msg = RigMessage::Assistant {
1123            id: None,
1124            content: OneOrMany::many(vec![
1125                AssistantContent::text("calling tool"),
1126                AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})),
1127            ])
1128            .expect("content should not be empty"),
1129        };
1130
1131        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1132
1133        assert_eq!(messages.len(), 1);
1134        match &messages[0] {
1135            Message::Assistant {
1136                reasoning_content,
1137                tool_calls,
1138                ..
1139            } => {
1140                assert!(reasoning_content.is_none());
1141                assert_eq!(tool_calls.len(), 1);
1142            }
1143            _ => panic!("Expected assistant message"),
1144        }
1145    }
1146
1147    #[test]
1148    fn test_client_initialization() {
1149        let _client =
1150            crate::providers::deepseek::Client::new("dummy-key").expect("Client::new() failed");
1151        let _client_from_builder = crate::providers::deepseek::Client::builder()
1152            .api_key("dummy-key")
1153            .build()
1154            .expect("Client::builder() failed");
1155    }
1156
1157    #[test]
1158    fn test_deserialize_list_models_response() {
1159        let data = r#"{
1160            "object": "list",
1161            "data": [
1162                {
1163                    "id": "deepseek-v4-flash",
1164                    "object": "model",
1165                    "owned_by": "deepseek"
1166                },
1167                {
1168                    "id": "deepseek-v4-pro",
1169                    "object": "model",
1170                    "owned_by": "deepseek"
1171                }
1172            ]
1173        }"#;
1174
1175        let response: ListModelsResponse = serde_json::from_str(data).unwrap();
1176
1177        assert_eq!(response.data.len(), 2);
1178        assert_eq!(response.data[0].id, "deepseek-v4-flash");
1179        assert_eq!(response.data[0].owned_by, "deepseek");
1180    }
1181
    /// Record of one outgoing HTTP request captured by [`RecordingHttpClient`].
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct CapturedRequest {
        // Full request URI as a string, e.g. "https://api.deepseek.com/models".
        uri: String,
    }
1186
    /// Canned reply the mock HTTP client returns for every request:
    /// either a successful body or an error status with a message.
    #[derive(Clone)]
    enum MockResponse {
        Success(Bytes),
        Error(http::StatusCode, String),
    }
1192
1193    impl Default for MockResponse {
1194        fn default() -> Self {
1195            Self::Success(Bytes::new())
1196        }
1197    }
1198
    /// Test double for the HTTP layer: records every request's URI and
    /// answers with a preconfigured [`MockResponse`]. `Arc`-backed so clones
    /// share the same request log and canned response.
    #[derive(Clone, Default)]
    struct RecordingHttpClient {
        requests: Arc<Mutex<Vec<CapturedRequest>>>,
        response: Arc<Mutex<MockResponse>>,
    }
1204
1205    impl RecordingHttpClient {
1206        fn new(response_body: impl Into<Bytes>) -> Self {
1207            Self {
1208                requests: Arc::new(Mutex::new(Vec::new())),
1209                response: Arc::new(Mutex::new(MockResponse::Success(response_body.into()))),
1210            }
1211        }
1212
1213        fn with_error(status: http::StatusCode, message: impl Into<String>) -> Self {
1214            Self {
1215                requests: Arc::new(Mutex::new(Vec::new())),
1216                response: Arc::new(Mutex::new(MockResponse::Error(status, message.into()))),
1217            }
1218        }
1219
1220        fn requests(&self) -> Vec<CapturedRequest> {
1221            self.requests.lock().expect("requests lock").clone()
1222        }
1223    }
1224
1225    impl HttpClientExt for RecordingHttpClient {
1226        fn send<T, U>(
1227            &self,
1228            req: HttpRequest<T>,
1229        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
1230        where
1231            T: Into<Bytes> + WasmCompatSend,
1232            U: From<Bytes> + WasmCompatSend + 'static,
1233        {
1234            let requests = Arc::clone(&self.requests);
1235            let response = self.response.lock().expect("response lock").clone();
1236            let (parts, _body) = req.into_parts();
1237
1238            requests
1239                .lock()
1240                .expect("requests lock")
1241                .push(CapturedRequest {
1242                    uri: parts.uri.to_string(),
1243                });
1244
1245            async move {
1246                let response_body = match response {
1247                    MockResponse::Success(response_body) => response_body,
1248                    MockResponse::Error(status, message) => {
1249                        return Err(http_client::Error::InvalidStatusCodeWithMessage(
1250                            status, message,
1251                        ));
1252                    }
1253                };
1254                let body: LazyBody<U> = Box::pin(async move { Ok(U::from(response_body)) });
1255                Response::builder()
1256                    .status(http::StatusCode::OK)
1257                    .body(body)
1258                    .map_err(http_client::Error::Protocol)
1259            }
1260        }
1261
1262        fn send_multipart<U>(
1263            &self,
1264            _req: HttpRequest<MultipartForm>,
1265        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
1266        where
1267            U: From<Bytes> + WasmCompatSend + 'static,
1268        {
1269            future::ready(Err(http_client::Error::InvalidStatusCode(
1270                http::StatusCode::NOT_IMPLEMENTED,
1271            )))
1272        }
1273
1274        fn send_streaming<T>(
1275            &self,
1276            _req: HttpRequest<T>,
1277        ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
1278        where
1279            T: Into<Bytes> + WasmCompatSend,
1280        {
1281            future::ready(Err(http_client::Error::InvalidStatusCode(
1282                http::StatusCode::NOT_IMPLEMENTED,
1283            )))
1284        }
1285    }
1286
1287    #[tokio::test]
1288    async fn test_list_models_uses_models_endpoint() {
1289        let response_body = r#"{
1290            "object": "list",
1291            "data": [
1292                {
1293                    "id": "deepseek-v4-flash",
1294                    "object": "model",
1295                    "owned_by": "deepseek"
1296                },
1297                {
1298                    "id": "deepseek-v4-pro",
1299                    "object": "model",
1300                    "owned_by": "deepseek"
1301                }
1302            ]
1303        }"#;
1304
1305        let http_client = RecordingHttpClient::new(response_body);
1306        let client = Client::builder()
1307            .api_key("dummy-key")
1308            .http_client(http_client.clone())
1309            .build()
1310            .expect("client should build");
1311
1312        let models = client
1313            .list_models()
1314            .await
1315            .expect("list_models should succeed");
1316
1317        assert_eq!(models.len(), 2);
1318        assert_eq!(models.data[0].id, "deepseek-v4-flash");
1319        assert_eq!(models.data[0].r#type, None);
1320        assert_eq!(models.data[0].owned_by.as_deref(), Some("deepseek"));
1321        assert_eq!(
1322            http_client.requests(),
1323            vec![CapturedRequest {
1324                uri: "https://api.deepseek.com/models".to_string()
1325            }]
1326        );
1327    }
1328
1329    #[tokio::test]
1330    async fn test_list_models_preserves_api_error_context() {
1331        let http_client = RecordingHttpClient::with_error(
1332            http::StatusCode::UNAUTHORIZED,
1333            r#"{"error":{"message":"invalid api key"}}"#,
1334        );
1335        let client = Client::builder()
1336            .api_key("dummy-key")
1337            .http_client(http_client)
1338            .build()
1339            .expect("client should build");
1340
1341        let error = client
1342            .list_models()
1343            .await
1344            .expect_err("list_models should fail");
1345
1346        match error {
1347            ModelListingError::ApiError {
1348                status_code,
1349                message,
1350            } => {
1351                assert_eq!(status_code, 401);
1352                assert!(message.contains("provider=DeepSeek"));
1353                assert!(message.contains("path=/models"));
1354                assert!(message.contains("invalid api key"));
1355            }
1356            other => panic!("expected api error, got {other:?}"),
1357        }
1358    }
1359}