Skip to main content

rig_core/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```no_run
5//! use rig_core::{client::CompletionClient, providers::deepseek};
6//!
7//! # fn run() -> Result<(), Box<dyn std::error::Error>> {
8//! let client = deepseek::Client::new("DEEPSEEK_API_KEY")?;
9//!
//! let deepseek_flash = client.completion_model(deepseek::DEEPSEEK_V4_FLASH);
11//! # Ok(())
12//! # }
13//! ```
14
15use bytes::Bytes;
16use http::Request;
17use tracing::{Instrument, Level, enabled, info_span};
18
19use crate::client::{
20    self, BearerAuth, Capabilities, Capable, DebugExt, ModelLister, Nothing, Provider,
21    ProviderBuilder, ProviderClient,
22};
23use crate::completion::GetTokenUsage;
24use crate::http_client::{self, HttpClientExt};
25use crate::message::{Document, DocumentSourceKind};
26use crate::model::{Model, ModelList, ModelListingError};
27use crate::providers::internal::openai_chat_completions_compatible::{
28    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
29};
30use crate::{
31    OneOrMany,
32    completion::{self, CompletionError, CompletionRequest},
33    json_utils, message,
34    wasm_compat::{WasmCompatSend, WasmCompatSync},
35};
36use serde::{Deserialize, Serialize};
37
38use super::openai::StreamingToolCall;
39
// ================================================================
// Main DeepSeek Client
// ================================================================
/// Base URL of the DeepSeek API; exposed to the generic client as
/// `ProviderBuilder::BASE_URL` for `DeepSeekExtBuilder`.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
44
/// Zero-sized provider extension that selects DeepSeek behaviour on the generic client.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExt;

/// Zero-sized builder counterpart of [`DeepSeekExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExtBuilder;
49
/// API key type used by the DeepSeek client: an alias for [`BearerAuth`].
type DeepSeekApiKey = BearerAuth;
51
/// Registers DeepSeek as a provider for the generic client.
impl Provider for DeepSeekExt {
    type Builder = DeepSeekExtBuilder;
    // Path handed to the generic client as its verification endpoint.
    // NOTE(review): presumably requested to validate credentials — confirm against `Provider`.
    const VERIFY_PATH: &'static str = "/user/balance";
}
56
/// Capability table for DeepSeek: completion and model listing are marked
/// `Capable`; every other capability is marked `Nothing`.
impl<H> Capabilities<H> for DeepSeekExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Capable<DeepSeekModelLister<H>>;
    // The two generation capabilities only exist when the matching crate feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
67
// Empty impl: nothing from `DebugExt` is overridden for DeepSeek.
impl DebugExt for DeepSeekExt {}
69
70impl ProviderBuilder for DeepSeekExtBuilder {
71    type Extension<H>
72        = DeepSeekExt
73    where
74        H: HttpClientExt;
75    type ApiKey = DeepSeekApiKey;
76
77    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
78
79    fn build<H>(
80        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
81    ) -> http_client::Result<Self::Extension<H>>
82    where
83        H: HttpClientExt,
84    {
85        Ok(DeepSeekExt)
86    }
87}
88
/// DeepSeek client, generic over the HTTP backend `H` (defaults to `reqwest::Client`).
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
/// Builder for [`Client`]; `H` defaults to the `Missing` marker type.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<DeepSeekExtBuilder, String, H>;
92
93impl ProviderClient for Client {
94    type Input = DeepSeekApiKey;
95    type Error = crate::client::ProviderClientError;
96
97    // If you prefer the environment variable approach:
98    fn from_env() -> Result<Self, Self::Error> {
99        let api_key = crate::client::required_env_var("DEEPSEEK_API_KEY")?;
100        let mut client_builder = Self::builder();
101        client_builder.headers_mut().insert(
102            http::header::CONTENT_TYPE,
103            http::HeaderValue::from_static("application/json"),
104        );
105        let client_builder = client_builder.api_key(&api_key);
106        client_builder.build().map_err(Into::into)
107    }
108
109    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
110        Self::new(input).map_err(Into::into)
111    }
112}
113
/// Error payload shape: an object carrying a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text; forwarded into `CompletionError::ProviderError`.
    message: String,
}
118
/// Body of a successful HTTP response: either the expected payload `T` or an
/// error object.
///
/// Untagged, so serde tries `Ok(T)` first and falls back to `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
125
126impl From<ApiErrorResponse> for CompletionError {
127    fn from(err: ApiErrorResponse) -> Self {
128        CompletionError::ProviderError(err.message)
129    }
130}
131
/// The response shape from the DeepSeek API
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Choices returned by the API. Only the first one is used when this is
    /// converted into a `completion::CompletionResponse`.
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
    // Other top-level fields of the response are not captured by this struct.
}
140
141#[derive(Clone, Debug, Serialize, Deserialize, Default)]
142pub struct Usage {
143    pub completion_tokens: u32,
144    pub prompt_tokens: u32,
145    pub prompt_cache_hit_tokens: u32,
146    pub prompt_cache_miss_tokens: u32,
147    pub total_tokens: u32,
148    #[serde(skip_serializing_if = "Option::is_none")]
149    pub completion_tokens_details: Option<CompletionTokensDetails>,
150    #[serde(skip_serializing_if = "Option::is_none")]
151    pub prompt_tokens_details: Option<PromptTokensDetails>,
152}
153
154impl GetTokenUsage for Usage {
155    fn token_usage(&self) -> Option<crate::completion::Usage> {
156        Some(crate::providers::internal::completion_usage(
157            self.prompt_tokens as u64,
158            self.completion_tokens as u64,
159            self.total_tokens as u64,
160            self.prompt_tokens_details
161                .as_ref()
162                .and_then(|details| details.cached_tokens)
163                .map(u64::from)
164                .unwrap_or(0),
165        ))
166    }
167}
168
/// Optional breakdown of completion tokens inside [`Usage`].
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Reasoning token count, when reported; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
174
/// Optional breakdown of prompt tokens inside [`Usage`].
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// Cached prompt token count, when reported; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
180
/// One entry of the `choices` array in a completion response.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: usize,
    /// The message carried by this choice.
    pub message: Message,
    /// Log-probability payload, kept as raw JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Finish reason as the raw string sent by the API.
    pub finish_reason: String,
}
188
189#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
190#[serde(tag = "role", rename_all = "lowercase")]
191pub enum Message {
192    System {
193        content: String,
194        #[serde(skip_serializing_if = "Option::is_none")]
195        name: Option<String>,
196    },
197    User {
198        content: String,
199        #[serde(skip_serializing_if = "Option::is_none")]
200        name: Option<String>,
201    },
202    Assistant {
203        content: String,
204        #[serde(skip_serializing_if = "Option::is_none")]
205        name: Option<String>,
206        #[serde(
207            default,
208            deserialize_with = "json_utils::null_or_vec",
209            skip_serializing_if = "Vec::is_empty"
210        )]
211        tool_calls: Vec<ToolCall>,
212        /// only exists on `deepseek-reasoner` model at time of addition
213        #[serde(skip_serializing_if = "Option::is_none")]
214        reasoning_content: Option<String>,
215    },
216    #[serde(rename = "tool")]
217    ToolResult {
218        tool_call_id: String,
219        content: String,
220    },
221}
222
223impl Message {
224    pub fn system(content: &str) -> Self {
225        Message::System {
226            content: content.to_owned(),
227            name: None,
228        }
229    }
230}
231
232impl From<message::ToolResult> for Message {
233    fn from(tool_result: message::ToolResult) -> Self {
234        let content = match tool_result.content.first() {
235            message::ToolResultContent::Text(text) => text.text,
236            message::ToolResultContent::Image(_) => String::from("[Image]"),
237        };
238
239        Message::ToolResult {
240            tool_call_id: tool_result.id,
241            content,
242        }
243    }
244}
245
246impl From<message::ToolCall> for ToolCall {
247    fn from(tool_call: message::ToolCall) -> Self {
248        Self {
249            id: tool_call.id,
250            // TODO: update index when we have it
251            index: 0,
252            r#type: ToolType::Function,
253            function: Function {
254                name: tool_call.function.name,
255                arguments: tool_call.function.arguments,
256            },
257        }
258    }
259}
260
261impl TryFrom<message::Message> for Vec<Message> {
262    type Error = message::MessageError;
263
264    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
265        match message {
266            message::Message::System { content } => Ok(vec![Message::System {
267                content,
268                name: None,
269            }]),
270            message::Message::User { content } => {
271                // extract tool results
272                let mut messages = vec![];
273
274                let tool_results = content
275                    .clone()
276                    .into_iter()
277                    .filter_map(|content| match content {
278                        message::UserContent::ToolResult(tool_result) => {
279                            Some(Message::from(tool_result))
280                        }
281                        _ => None,
282                    })
283                    .collect::<Vec<_>>();
284
285                messages.extend(tool_results);
286
287                let text_content: String = content
288                    .into_iter()
289                    .filter_map(|content| match content {
290                        message::UserContent::Text(text) => Some(text.text),
291                        message::UserContent::Document(Document {
292                            data:
293                                DocumentSourceKind::Base64(content)
294                                | DocumentSourceKind::String(content),
295                            ..
296                        }) => Some(content),
297                        _ => None,
298                    })
299                    .collect::<Vec<_>>()
300                    .join("\n");
301
302                if !text_content.is_empty() {
303                    messages.push(Message::User {
304                        content: text_content,
305                        name: None,
306                    });
307                }
308
309                Ok(messages)
310            }
311            message::Message::Assistant { content, .. } => {
312                let mut text_content = String::new();
313                let mut reasoning_content = String::new();
314                let mut tool_calls = Vec::new();
315
316                for item in content.iter() {
317                    match item {
318                        message::AssistantContent::Text(text) => {
319                            text_content.push_str(text.text());
320                        }
321                        message::AssistantContent::Reasoning(reasoning) => {
322                            reasoning_content.push_str(&reasoning.display_text());
323                        }
324                        message::AssistantContent::ToolCall(tool_call) => {
325                            tool_calls.push(ToolCall::from(tool_call.clone()));
326                        }
327                        _ => {}
328                    }
329                }
330
331                let reasoning = if reasoning_content.is_empty() {
332                    None
333                } else {
334                    Some(reasoning_content)
335                };
336
337                Ok(vec![Message::Assistant {
338                    content: text_content,
339                    name: None,
340                    tool_calls,
341                    reasoning_content: reasoning,
342                }])
343            }
344        }
345    }
346}
347
348#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
349pub struct ToolCall {
350    pub id: String,
351    pub index: usize,
352    #[serde(default)]
353    pub r#type: ToolType,
354    pub function: Function,
355}
356
/// Function part of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    /// Call arguments; (de)serialized through `json_utils::stringified_json`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
363
/// Kind of tool in a [`ToolCall`]; serialized in lowercase (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// The only variant, and the default.
    #[default]
    Function,
}
370
/// Tool definition as sent in the request's `tools` array.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; set to `"function"` by the `From` conversion in this module.
    pub r#type: String,
    /// The generic tool definition, nested under `function`.
    pub function: completion::ToolDefinition,
}
376
377impl From<crate::completion::ToolDefinition> for ToolDefinition {
378    fn from(tool: crate::completion::ToolDefinition) -> Self {
379        Self {
380            r#type: "function".into(),
381            function: tool,
382        }
383    }
384}
385
386impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
387    type Error = CompletionError;
388
389    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
390        let choice = response.choices.first().ok_or_else(|| {
391            CompletionError::ResponseError("Response contained no choices".to_owned())
392        })?;
393        let content = match &choice.message {
394            Message::Assistant {
395                content,
396                tool_calls,
397                reasoning_content,
398                ..
399            } => {
400                let mut content = if content.trim().is_empty() {
401                    vec![]
402                } else {
403                    vec![completion::AssistantContent::text(content)]
404                };
405
406                content.extend(
407                    tool_calls
408                        .iter()
409                        .map(|call| {
410                            completion::AssistantContent::tool_call(
411                                &call.id,
412                                &call.function.name,
413                                call.function.arguments.clone(),
414                            )
415                        })
416                        .collect::<Vec<_>>(),
417                );
418
419                if let Some(reasoning_content) = reasoning_content {
420                    content.push(completion::AssistantContent::reasoning(reasoning_content));
421                }
422
423                Ok(content)
424            }
425            _ => Err(CompletionError::ResponseError(
426                "Response did not contain a valid message or tool call".into(),
427            )),
428        }?;
429
430        let choice = OneOrMany::many(content).map_err(|_| {
431            CompletionError::ResponseError(
432                "Response contained no message or tool call (empty)".to_owned(),
433            )
434        })?;
435
436        let usage = completion::Usage {
437            input_tokens: response.usage.prompt_tokens as u64,
438            output_tokens: response.usage.completion_tokens as u64,
439            total_tokens: response.usage.total_tokens as u64,
440            cached_input_tokens: response
441                .usage
442                .prompt_tokens_details
443                .as_ref()
444                .and_then(|d| d.cached_tokens)
445                .map(|c| c as u64)
446                .unwrap_or(0),
447            cache_creation_input_tokens: 0,
448            tool_use_prompt_tokens: 0,
449            reasoning_tokens: 0,
450        };
451
452        Ok(completion::CompletionResponse {
453            choice,
454            usage,
455            raw_response: response,
456            message_id: None,
457        })
458    }
459}
460
/// JSON body sent to `/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct DeepseekCompletionRequest {
    /// Model identifier.
    model: String,
    /// Full conversation: optional system preamble, documents, then chat history.
    pub messages: Vec<Message>,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool definitions; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Tool choice, reusing the OpenRouter provider's type; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    /// Extra parameters flattened into the top level of the body; the streaming
    /// path stores `stream` and `stream_options` here.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
474
475impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
476    type Error = CompletionError;
477
478    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
479        if req.output_schema.is_some() {
480            tracing::warn!("Structured outputs currently not supported for DeepSeek");
481        }
482        let model = req.model.clone().unwrap_or_else(|| model.to_string());
483        let mut full_history: Vec<Message> = match &req.preamble {
484            Some(preamble) => vec![Message::system(preamble)],
485            None => vec![],
486        };
487
488        if let Some(docs) = req.normalized_documents() {
489            let docs: Vec<Message> = docs.try_into()?;
490            full_history.extend(docs);
491        }
492
493        let chat_history: Vec<Message> = req
494            .chat_history
495            .clone()
496            .into_iter()
497            .map(|message| message.try_into())
498            .collect::<Result<Vec<Vec<Message>>, _>>()?
499            .into_iter()
500            .flatten()
501            .collect();
502
503        full_history.extend(chat_history);
504
505        let tool_choice = req
506            .tool_choice
507            .clone()
508            .map(crate::providers::openrouter::ToolChoice::try_from)
509            .transpose()?;
510
511        Ok(Self {
512            model: model.to_string(),
513            messages: full_history,
514            temperature: req.temperature,
515            tools: req
516                .tools
517                .clone()
518                .into_iter()
519                .map(ToolDefinition::from)
520                .collect::<Vec<_>>(),
521            tool_choice,
522            additional_params: req.additional_params,
523        })
524    }
525}
526
/// The struct implementing the `CompletionModel` trait
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    pub client: Client<T>,
    /// Default model name; a request's own `model` overrides it.
    pub model: String,
}
533
534impl<T> completion::CompletionModel for CompletionModel<T>
535where
536    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
537{
538    type Response = CompletionResponse;
539    type StreamingResponse = StreamingCompletionResponse;
540
541    type Client = Client<T>;
542
543    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
544        Self {
545            client: client.clone(),
546            model: model.into().to_string(),
547        }
548    }
549
550    async fn completion(
551        &self,
552        completion_request: CompletionRequest,
553    ) -> Result<
554        completion::CompletionResponse<CompletionResponse>,
555        crate::completion::CompletionError,
556    > {
557        let span = if tracing::Span::current().is_disabled() {
558            info_span!(
559                target: "rig::completions",
560                "chat",
561                gen_ai.operation.name = "chat",
562                gen_ai.provider.name = "deepseek",
563                gen_ai.request.model = self.model,
564                gen_ai.system_instructions = tracing::field::Empty,
565                gen_ai.response.id = tracing::field::Empty,
566                gen_ai.response.model = tracing::field::Empty,
567                gen_ai.usage.output_tokens = tracing::field::Empty,
568                gen_ai.usage.input_tokens = tracing::field::Empty,
569                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
570            )
571        } else {
572            tracing::Span::current()
573        };
574
575        span.record("gen_ai.system_instructions", &completion_request.preamble);
576
577        let request =
578            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
579
580        if enabled!(Level::TRACE) {
581            tracing::trace!(target: "rig::completions",
582                "DeepSeek completion request: {}",
583                serde_json::to_string_pretty(&request)?
584            );
585        }
586
587        let body = serde_json::to_vec(&request)?;
588        let req = self
589            .client
590            .post("/chat/completions")?
591            .body(body)
592            .map_err(|e| CompletionError::HttpError(e.into()))?;
593
594        async move {
595            let response = self.client.send::<_, Bytes>(req).await?;
596            let status = response.status();
597            let response_body = response.into_body().into_future().await?.to_vec();
598
599            if status.is_success() {
600                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
601                    ApiResponse::Ok(response) => {
602                        let span = tracing::Span::current();
603                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
604                        span.record(
605                            "gen_ai.usage.output_tokens",
606                            response.usage.completion_tokens,
607                        );
608                        span.record(
609                            "gen_ai.usage.cache_read.input_tokens",
610                            response
611                                .usage
612                                .prompt_tokens_details
613                                .as_ref()
614                                .and_then(|d| d.cached_tokens)
615                                .unwrap_or(0),
616                        );
617                        if enabled!(Level::TRACE) {
618                            tracing::trace!(target: "rig::completions",
619                                "DeepSeek completion response: {}",
620                                serde_json::to_string_pretty(&response)?
621                            );
622                        }
623                        response.try_into()
624                    }
625                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
626                }
627            } else {
628                Err(CompletionError::ProviderError(
629                    String::from_utf8_lossy(&response_body).to_string(),
630                ))
631            }
632        }
633        .instrument(span)
634        .await
635    }
636
637    async fn stream(
638        &self,
639        completion_request: CompletionRequest,
640    ) -> Result<
641        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
642        CompletionError,
643    > {
644        let preamble = completion_request.preamble.clone();
645        let mut request =
646            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
647
648        let params = json_utils::merge(
649            request.additional_params.unwrap_or(serde_json::json!({})),
650            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
651        );
652
653        request.additional_params = Some(params);
654
655        if enabled!(Level::TRACE) {
656            tracing::trace!(target: "rig::completions",
657                "DeepSeek streaming completion request: {}",
658                serde_json::to_string_pretty(&request)?
659            );
660        }
661
662        let body = serde_json::to_vec(&request)?;
663
664        let req = self
665            .client
666            .post("/chat/completions")?
667            .body(body)
668            .map_err(|e| CompletionError::HttpError(e.into()))?;
669
670        let span = if tracing::Span::current().is_disabled() {
671            info_span!(
672                target: "rig::completions",
673                "chat_streaming",
674                gen_ai.operation.name = "chat_streaming",
675                gen_ai.provider.name = "deepseek",
676                gen_ai.request.model = self.model,
677                gen_ai.system_instructions = preamble,
678                gen_ai.response.id = tracing::field::Empty,
679                gen_ai.response.model = tracing::field::Empty,
680                gen_ai.usage.output_tokens = tracing::field::Empty,
681                gen_ai.usage.input_tokens = tracing::field::Empty,
682                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
683            )
684        } else {
685            tracing::Span::current()
686        };
687
688        tracing::Instrument::instrument(
689            send_compatible_streaming_request(self.client.clone(), req),
690            span,
691        )
692        .await
693    }
694}
695
/// Incremental message fragment carried by one streaming choice.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    /// Text fragment, if any.
    #[serde(default)]
    content: Option<String>,
    /// Tool call fragments; `null` or missing becomes an empty vector.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    /// Reasoning text fragment, if any.
    reasoning_content: Option<String>,
}
704
/// One entry of a streaming chunk's `choices` array; only `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
709
/// One SSE payload of a streaming completion, as parsed by `normalize_chunk`.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    /// Response id, when present.
    id: Option<String>,
    /// Model name, when present.
    model: Option<String>,
    /// Streaming choices carried by this chunk.
    choices: Vec<StreamingChoice>,
    /// Token usage, when the chunk carries it.
    usage: Option<Usage>,
}
717
/// Final response of a streaming completion; carries only the token usage.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Usage passed to `build_final_response` by the streaming machinery.
    pub usage: Usage,
}
722
723impl GetTokenUsage for StreamingCompletionResponse {
724    fn token_usage(&self) -> Option<crate::completion::Usage> {
725        self.usage.token_usage()
726    }
727}
728
/// Zero-sized stream profile that plugs DeepSeek into the shared
/// OpenAI-compatible streaming implementation.
#[derive(Copy, Clone)]
struct DeepSeekCompatibleProfile;
731
732impl CompatibleStreamProfile for DeepSeekCompatibleProfile {
733    type Usage = Usage;
734    type Detail = ();
735    type FinalResponse = StreamingCompletionResponse;
736
737    fn normalize_chunk(
738        &self,
739        data: &str,
740    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
741        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
742            Ok(data) => data,
743            Err(error) => {
744                tracing::debug!(
745                    "Couldn't parse SSE payload as StreamingCompletionChunk: {:?}",
746                    error
747                );
748                return Ok(None);
749            }
750        };
751
752        Ok(Some(
753            openai_chat_completions_compatible::normalize_first_choice_chunk(
754                data.id,
755                data.model,
756                data.usage,
757                &data.choices,
758                |choice| CompatibleChoiceData {
759                    finish_reason: CompatibleFinishReason::Other,
760                    text: choice.delta.content.clone(),
761                    reasoning: choice.delta.reasoning_content.clone(),
762                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
763                        &choice.delta.tool_calls,
764                    ),
765                    details: Vec::new(),
766                },
767            ),
768        ))
769    }
770
771    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
772        StreamingCompletionResponse { usage }
773    }
774
775    fn uses_distinct_tool_call_eviction(&self) -> bool {
776        true
777    }
778
779    fn emits_complete_single_chunk_tool_calls(&self) -> bool {
780        true
781    }
782}
783
784pub async fn send_compatible_streaming_request<T>(
785    http_client: T,
786    req: Request<Vec<u8>>,
787) -> Result<
788    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
789    CompletionError,
790>
791where
792    T: HttpClientExt + Clone + 'static,
793{
794    openai_chat_completions_compatible::send_compatible_streaming_request(
795        http_client,
796        req,
797        DeepSeekCompatibleProfile,
798    )
799    .await
800}
801
/// Body of the `GET /models` response; only the `data` array is read.
#[derive(Debug, Deserialize)]
struct ListModelsResponse {
    data: Vec<ListModelEntry>,
}
806
/// One model entry of the `GET /models` response.
#[derive(Debug, Deserialize)]
struct ListModelEntry {
    /// Model identifier.
    id: String,
    /// Owner of the model as reported by the API.
    owned_by: String,
}
812
813impl From<ListModelEntry> for Model {
814    fn from(value: ListModelEntry) -> Self {
815        let mut model = Model::from_id(value.id);
816        model.owned_by = Some(value.owned_by);
817        model
818    }
819}
820
/// [`ModelLister`] implementation for the DeepSeek API (`GET /models`).
#[derive(Clone)]
pub struct DeepSeekModelLister<H = reqwest::Client> {
    /// Client used to send the listing request.
    client: Client<H>,
}
826
827impl<H> ModelLister<H> for DeepSeekModelLister<H>
828where
829    H: HttpClientExt + WasmCompatSend + WasmCompatSync + 'static,
830{
831    type Client = Client<H>;
832
833    fn new(client: Self::Client) -> Self {
834        Self { client }
835    }
836
837    async fn list_all(&self) -> Result<ModelList, ModelListingError> {
838        let path = "/models";
839        let req = self.client.get(path)?.body(http_client::NoBody)?;
840        let response = self
841            .client
842            .send::<_, Vec<u8>>(req)
843            .await
844            .map_err(|error| match error {
845                http_client::Error::InvalidStatusCodeWithMessage(status, message) => {
846                    ModelListingError::api_error_with_context(
847                        "DeepSeek",
848                        path,
849                        status.as_u16(),
850                        message.as_bytes(),
851                    )
852                }
853                other => ModelListingError::from(other),
854            })?;
855
856        if !response.status().is_success() {
857            let status_code = response.status().as_u16();
858            let body = response.into_body().await?;
859            return Err(ModelListingError::api_error_with_context(
860                "DeepSeek",
861                path,
862                status_code,
863                &body,
864            ));
865        }
866
867        let body = response.into_body().await?;
868        let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
869            ModelListingError::parse_error_with_context("DeepSeek", path, &error, &body)
870        })?;
871
872        let models = api_resp.data.into_iter().map(Model::from).collect();
873
874        Ok(ModelList::new(models))
875    }
876}
877
// ================================================================
// DeepSeek Completion API: model name constants
// ================================================================
/// The `deepseek-chat` model name (deprecated; see the deprecation note).
#[deprecated(
    note = "The model names `deepseek-chat` and `deepseek-reasoner` will be deprecated on 2026/07/24. \
    For compatibility, they correspond to the non-thinking mode and thinking mode of `deepseek-v4-flash`, \
    respectively."
)]
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// The `deepseek-reasoner` model name (deprecated; see the deprecation note).
#[deprecated(
    note = "The model names `deepseek-chat` and `deepseek-reasoner` will be deprecated on 2026/07/24. \
    For compatibility, they correspond to the non-thinking mode and thinking mode of `deepseek-v4-flash`, \
    respectively."
)]
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
/// The `deepseek-v4-flash` model name.
pub const DEEPSEEK_V4_FLASH: &str = "deepseek-v4-flash";
/// The `deepseek-v4-pro` model name.
pub const DEEPSEEK_V4_PRO: &str = "deepseek-v4-pro";
895
896// Tests
897#[cfg(test)]
898mod tests {
899    use super::*;
900    use crate::client::ModelListingClient;
901    use crate::test_utils::RecordingHttpClient;
902
903    #[test]
904    fn test_deserialize_vec_choice() {
905        let data = r#"[{
906            "finish_reason": "stop",
907            "index": 0,
908            "logprobs": null,
909            "message":{"role":"assistant","content":"Hello, world!"}
910            }]"#;
911
912        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
913        assert_eq!(choices.len(), 1);
914        match &choices.first().unwrap().message {
915            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
916            _ => panic!("Expected assistant message"),
917        }
918    }
919
920    #[test]
921    fn test_deserialize_deepseek_response() {
922        let data = r#"{
923            "choices":[{
924                "finish_reason": "stop",
925                "index": 0,
926                "logprobs": null,
927                "message":{"role":"assistant","content":"Hello, world!"}
928            }],
929            "usage": {
930                "completion_tokens": 0,
931                "prompt_tokens": 0,
932                "prompt_cache_hit_tokens": 0,
933                "prompt_cache_miss_tokens": 0,
934                "total_tokens": 0
935            }
936        }"#;
937
938        let jd = &mut serde_json::Deserializer::from_str(data);
939        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
940        match result {
941            Ok(response) => match &response.choices.first().unwrap().message {
942                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
943                _ => panic!("Expected assistant message"),
944            },
945            Err(err) => {
946                panic!("Deserialization error at {}: {}", err.path(), err);
947            }
948        }
949    }
950
951    #[test]
952    fn test_deserialize_example_response() {
953        let data = r#"
954        {
955            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
956            "object": "chat.completion",
957            "created": 0,
958            "model": "deepseek-chat",
959            "choices": [
960                {
961                    "index": 0,
962                    "message": {
963                        "role": "assistant",
964                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
965                    },
966                    "logprobs": null,
967                    "finish_reason": "stop"
968                }
969            ],
970            "usage": {
971                "prompt_tokens": 13,
972                "completion_tokens": 32,
973                "total_tokens": 45,
974                "prompt_tokens_details": {
975                    "cached_tokens": 0
976                },
977                "prompt_cache_hit_tokens": 0,
978                "prompt_cache_miss_tokens": 13
979            },
980            "system_fingerprint": "fp_4b6881f2c5"
981        }
982        "#;
983        let jd = &mut serde_json::Deserializer::from_str(data);
984        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
985
986        match result {
987            Ok(response) => match &response.choices.first().unwrap().message {
988                Message::Assistant { content, .. } => assert_eq!(
989                    content,
990                    "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
991                ),
992                _ => panic!("Expected assistant message"),
993            },
994            Err(err) => {
995                panic!("Deserialization error at {}: {}", err.path(), err);
996            }
997        }
998    }
999
1000    #[test]
1001    fn test_serialize_deserialize_tool_call_message() {
1002        let tool_call_choice_json = r#"
1003            {
1004              "finish_reason": "tool_calls",
1005              "index": 0,
1006              "logprobs": null,
1007              "message": {
1008                "content": "",
1009                "role": "assistant",
1010                "tool_calls": [
1011                  {
1012                    "function": {
1013                      "arguments": "{\"x\":2,\"y\":5}",
1014                      "name": "subtract"
1015                    },
1016                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
1017                    "index": 0,
1018                    "type": "function"
1019                  }
1020                ]
1021              }
1022            }
1023        "#;
1024
1025        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
1026
1027        let expected_choice: Choice = Choice {
1028            finish_reason: "tool_calls".to_string(),
1029            index: 0,
1030            logprobs: None,
1031            message: Message::Assistant {
1032                content: "".to_string(),
1033                name: None,
1034                tool_calls: vec![ToolCall {
1035                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
1036                    function: Function {
1037                        name: "subtract".to_string(),
1038                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
1039                    },
1040                    index: 0,
1041                    r#type: ToolType::Function,
1042                }],
1043                reasoning_content: None,
1044            },
1045        };
1046
1047        assert_eq!(choice, expected_choice);
1048    }
1049    #[test]
1050    fn test_user_message_multiple_text_items_merged() {
1051        use crate::completion::message::{Message as RigMessage, UserContent};
1052
1053        let rig_msg = RigMessage::User {
1054            content: OneOrMany::many(vec![
1055                UserContent::text("first part"),
1056                UserContent::text("second part"),
1057            ])
1058            .expect("content should not be empty"),
1059        };
1060
1061        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1062
1063        let user_messages: Vec<&Message> = messages
1064            .iter()
1065            .filter(|m| matches!(m, Message::User { .. }))
1066            .collect();
1067
1068        assert_eq!(
1069            user_messages.len(),
1070            1,
1071            "multiple text items should produce a single user message"
1072        );
1073        match &user_messages[0] {
1074            Message::User { content, .. } => {
1075                assert_eq!(content, "first part\nsecond part");
1076            }
1077            _ => unreachable!(),
1078        }
1079    }
1080
1081    #[test]
1082    fn test_assistant_message_with_reasoning_and_tool_calls() {
1083        use crate::completion::message::{AssistantContent, Message as RigMessage};
1084
1085        let rig_msg = RigMessage::Assistant {
1086            id: None,
1087            content: OneOrMany::many(vec![
1088                AssistantContent::reasoning("thinking about the problem"),
1089                AssistantContent::text("I'll call the tool"),
1090                AssistantContent::tool_call(
1091                    "call_1",
1092                    "subtract",
1093                    serde_json::json!({"x": 2, "y": 5}),
1094                ),
1095            ])
1096            .expect("content should not be empty"),
1097        };
1098
1099        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1100
1101        assert_eq!(messages.len(), 1, "should produce exactly one message");
1102        match &messages[0] {
1103            Message::Assistant {
1104                content,
1105                tool_calls,
1106                reasoning_content,
1107                ..
1108            } => {
1109                assert_eq!(content, "I'll call the tool");
1110                assert_eq!(
1111                    reasoning_content.as_deref(),
1112                    Some("thinking about the problem")
1113                );
1114                assert_eq!(tool_calls.len(), 1);
1115                assert_eq!(tool_calls[0].function.name, "subtract");
1116            }
1117            _ => panic!("Expected assistant message"),
1118        }
1119    }
1120
1121    #[test]
1122    fn test_assistant_message_without_reasoning() {
1123        use crate::completion::message::{AssistantContent, Message as RigMessage};
1124
1125        let rig_msg = RigMessage::Assistant {
1126            id: None,
1127            content: OneOrMany::many(vec![
1128                AssistantContent::text("calling tool"),
1129                AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})),
1130            ])
1131            .expect("content should not be empty"),
1132        };
1133
1134        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1135
1136        assert_eq!(messages.len(), 1);
1137        match &messages[0] {
1138            Message::Assistant {
1139                reasoning_content,
1140                tool_calls,
1141                ..
1142            } => {
1143                assert!(reasoning_content.is_none());
1144                assert_eq!(tool_calls.len(), 1);
1145            }
1146            _ => panic!("Expected assistant message"),
1147        }
1148    }
1149
1150    #[test]
1151    fn test_client_initialization() {
1152        let _client =
1153            crate::providers::deepseek::Client::new("dummy-key").expect("Client::new() failed");
1154        let _client_from_builder = crate::providers::deepseek::Client::builder()
1155            .api_key("dummy-key")
1156            .build()
1157            .expect("Client::builder() failed");
1158    }
1159
1160    #[test]
1161    fn test_deserialize_list_models_response() {
1162        let data = r#"{
1163            "object": "list",
1164            "data": [
1165                {
1166                    "id": "deepseek-v4-flash",
1167                    "object": "model",
1168                    "owned_by": "deepseek"
1169                },
1170                {
1171                    "id": "deepseek-v4-pro",
1172                    "object": "model",
1173                    "owned_by": "deepseek"
1174                }
1175            ]
1176        }"#;
1177
1178        let response: ListModelsResponse = serde_json::from_str(data).unwrap();
1179
1180        assert_eq!(response.data.len(), 2);
1181        assert_eq!(response.data[0].id, "deepseek-v4-flash");
1182        assert_eq!(response.data[0].owned_by, "deepseek");
1183    }
1184
1185    #[tokio::test]
1186    async fn test_list_models_uses_models_endpoint() {
1187        let response_body = r#"{
1188            "object": "list",
1189            "data": [
1190                {
1191                    "id": "deepseek-v4-flash",
1192                    "object": "model",
1193                    "owned_by": "deepseek"
1194                },
1195                {
1196                    "id": "deepseek-v4-pro",
1197                    "object": "model",
1198                    "owned_by": "deepseek"
1199                }
1200            ]
1201        }"#;
1202
1203        let http_client = RecordingHttpClient::new(response_body);
1204        let client = Client::builder()
1205            .api_key("dummy-key")
1206            .http_client(http_client.clone())
1207            .build()
1208            .expect("client should build");
1209
1210        let models = client
1211            .list_models()
1212            .await
1213            .expect("list_models should succeed");
1214
1215        assert_eq!(models.len(), 2);
1216        assert_eq!(models.data[0].id, "deepseek-v4-flash");
1217        assert_eq!(models.data[0].r#type, None);
1218        assert_eq!(models.data[0].owned_by.as_deref(), Some("deepseek"));
1219        let requests = http_client.requests();
1220        assert_eq!(requests.len(), 1);
1221        assert_eq!(requests[0].uri, "https://api.deepseek.com/models");
1222    }
1223
1224    #[tokio::test]
1225    async fn test_list_models_preserves_api_error_context() {
1226        let http_client = RecordingHttpClient::with_error(
1227            http::StatusCode::UNAUTHORIZED,
1228            r#"{"error":{"message":"invalid api key"}}"#,
1229        );
1230        let client = Client::builder()
1231            .api_key("dummy-key")
1232            .http_client(http_client)
1233            .build()
1234            .expect("client should build");
1235
1236        let error = client
1237            .list_models()
1238            .await
1239            .expect_err("list_models should fail");
1240
1241        match error {
1242            ModelListingError::ApiError {
1243                status_code,
1244                message,
1245            } => {
1246                assert_eq!(status_code, 401);
1247                assert!(message.contains("provider=DeepSeek"));
1248                assert!(message.contains("path=/models"));
1249                assert!(message.contains("invalid api key"));
1250            }
1251            other => panic!("expected api error, got {other:?}"),
1252        }
1253    }
1254}