Skip to main content

rig_core/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```no_run
5//! use rig_core::{client::CompletionClient, providers::deepseek};
6//!
7//! # fn run() -> Result<(), Box<dyn std::error::Error>> {
8//! let client = deepseek::Client::new("DEEPSEEK_API_KEY")?;
9//!
10//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
11//! # Ok(())
12//! # }
13//! ```
14
15use bytes::Bytes;
16use http::Request;
17use tracing::{Instrument, Level, enabled, info_span};
18
19use crate::client::{
20    self, BearerAuth, Capabilities, Capable, DebugExt, ModelLister, Nothing, Provider,
21    ProviderBuilder, ProviderClient,
22};
23use crate::completion::GetTokenUsage;
24use crate::http_client::{self, HttpClientExt};
25use crate::message::{Document, DocumentSourceKind};
26use crate::model::{Model, ModelList, ModelListingError};
27use crate::providers::internal::openai_chat_completions_compatible::{
28    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
29};
30use crate::{
31    OneOrMany,
32    completion::{self, CompletionError, CompletionRequest},
33    json_utils, message,
34    wasm_compat::{WasmCompatSend, WasmCompatSync},
35};
36use serde::{Deserialize, Serialize};
37
38use super::openai::StreamingToolCall;
39
40// ================================================================
41// Main DeepSeek Client
42// ================================================================
43const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
44
45#[derive(Debug, Default, Clone, Copy)]
46pub struct DeepSeekExt;
47#[derive(Debug, Default, Clone, Copy)]
48pub struct DeepSeekExtBuilder;
49
50type DeepSeekApiKey = BearerAuth;
51
52impl Provider for DeepSeekExt {
53    type Builder = DeepSeekExtBuilder;
54    const VERIFY_PATH: &'static str = "/user/balance";
55}
56
57impl<H> Capabilities<H> for DeepSeekExt {
58    type Completion = Capable<CompletionModel<H>>;
59    type Embeddings = Nothing;
60    type Transcription = Nothing;
61    type ModelListing = Capable<DeepSeekModelLister<H>>;
62    #[cfg(feature = "image")]
63    type ImageGeneration = Nothing;
64    #[cfg(feature = "audio")]
65    type AudioGeneration = Nothing;
66}
67
68impl DebugExt for DeepSeekExt {}
69
70impl ProviderBuilder for DeepSeekExtBuilder {
71    type Extension<H>
72        = DeepSeekExt
73    where
74        H: HttpClientExt;
75    type ApiKey = DeepSeekApiKey;
76
77    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
78
79    fn build<H>(
80        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
81    ) -> http_client::Result<Self::Extension<H>>
82    where
83        H: HttpClientExt,
84    {
85        Ok(DeepSeekExt)
86    }
87}
88
89pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
90pub type ClientBuilder<H = crate::markers::Missing> =
91    client::ClientBuilder<DeepSeekExtBuilder, String, H>;
92
93impl ProviderClient for Client {
94    type Input = DeepSeekApiKey;
95    type Error = crate::client::ProviderClientError;
96
97    // If you prefer the environment variable approach:
98    fn from_env() -> Result<Self, Self::Error> {
99        let api_key = crate::client::required_env_var("DEEPSEEK_API_KEY")?;
100        let mut client_builder = Self::builder();
101        client_builder.headers_mut().insert(
102            http::header::CONTENT_TYPE,
103            http::HeaderValue::from_static("application/json"),
104        );
105        let client_builder = client_builder.api_key(&api_key);
106        client_builder.build().map_err(Into::into)
107    }
108
109    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
110        Self::new(input).map_err(Into::into)
111    }
112}
113
114#[derive(Debug, Deserialize)]
115struct ApiErrorResponse {
116    message: String,
117}
118
119#[derive(Debug, Deserialize)]
120#[serde(untagged)]
121enum ApiResponse<T> {
122    Ok(T),
123    Err(ApiErrorResponse),
124}
125
126impl From<ApiErrorResponse> for CompletionError {
127    fn from(err: ApiErrorResponse) -> Self {
128        CompletionError::ProviderError(err.message)
129    }
130}
131
132/// The response shape from the DeepSeek API
133#[derive(Clone, Debug, Serialize, Deserialize)]
134pub struct CompletionResponse {
135    // We'll match the JSON:
136    pub choices: Vec<Choice>,
137    pub usage: Usage,
138    // you may want other fields
139}
140
141#[derive(Clone, Debug, Serialize, Deserialize, Default)]
142pub struct Usage {
143    pub completion_tokens: u32,
144    pub prompt_tokens: u32,
145    pub prompt_cache_hit_tokens: u32,
146    pub prompt_cache_miss_tokens: u32,
147    pub total_tokens: u32,
148    #[serde(skip_serializing_if = "Option::is_none")]
149    pub completion_tokens_details: Option<CompletionTokensDetails>,
150    #[serde(skip_serializing_if = "Option::is_none")]
151    pub prompt_tokens_details: Option<PromptTokensDetails>,
152}
153
154impl GetTokenUsage for Usage {
155    fn token_usage(&self) -> Option<crate::completion::Usage> {
156        Some(crate::providers::internal::completion_usage(
157            self.prompt_tokens as u64,
158            self.completion_tokens as u64,
159            self.total_tokens as u64,
160            self.prompt_tokens_details
161                .as_ref()
162                .and_then(|details| details.cached_tokens)
163                .map(u64::from)
164                .unwrap_or(0),
165        ))
166    }
167}
168
169#[derive(Clone, Debug, Serialize, Deserialize, Default)]
170pub struct CompletionTokensDetails {
171    #[serde(skip_serializing_if = "Option::is_none")]
172    pub reasoning_tokens: Option<u32>,
173}
174
175#[derive(Clone, Debug, Serialize, Deserialize, Default)]
176pub struct PromptTokensDetails {
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub cached_tokens: Option<u32>,
179}
180
181#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
182pub struct Choice {
183    pub index: usize,
184    pub message: Message,
185    pub logprobs: Option<serde_json::Value>,
186    pub finish_reason: String,
187}
188
189#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
190#[serde(tag = "role", rename_all = "lowercase")]
191pub enum Message {
192    System {
193        content: String,
194        #[serde(skip_serializing_if = "Option::is_none")]
195        name: Option<String>,
196    },
197    User {
198        content: String,
199        #[serde(skip_serializing_if = "Option::is_none")]
200        name: Option<String>,
201    },
202    Assistant {
203        content: String,
204        #[serde(skip_serializing_if = "Option::is_none")]
205        name: Option<String>,
206        #[serde(
207            default,
208            deserialize_with = "json_utils::null_or_vec",
209            skip_serializing_if = "Vec::is_empty"
210        )]
211        tool_calls: Vec<ToolCall>,
212        /// only exists on `deepseek-reasoner` model at time of addition
213        #[serde(skip_serializing_if = "Option::is_none")]
214        reasoning_content: Option<String>,
215    },
216    #[serde(rename = "tool")]
217    ToolResult {
218        tool_call_id: String,
219        content: String,
220    },
221}
222
223impl Message {
224    pub fn system(content: &str) -> Self {
225        Message::System {
226            content: content.to_owned(),
227            name: None,
228        }
229    }
230}
231
232impl From<message::ToolResult> for Message {
233    fn from(tool_result: message::ToolResult) -> Self {
234        let content = match tool_result.content.first() {
235            message::ToolResultContent::Text(text) => text.text,
236            message::ToolResultContent::Image(_) => String::from("[Image]"),
237        };
238
239        Message::ToolResult {
240            tool_call_id: tool_result.id,
241            content,
242        }
243    }
244}
245
246impl From<message::ToolCall> for ToolCall {
247    fn from(tool_call: message::ToolCall) -> Self {
248        Self {
249            id: tool_call.id,
250            // TODO: update index when we have it
251            index: 0,
252            r#type: ToolType::Function,
253            function: Function {
254                name: tool_call.function.name,
255                arguments: tool_call.function.arguments,
256            },
257        }
258    }
259}
260
261impl TryFrom<message::Message> for Vec<Message> {
262    type Error = message::MessageError;
263
264    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
265        match message {
266            message::Message::System { content } => Ok(vec![Message::System {
267                content,
268                name: None,
269            }]),
270            message::Message::User { content } => {
271                // extract tool results
272                let mut messages = vec![];
273
274                let tool_results = content
275                    .clone()
276                    .into_iter()
277                    .filter_map(|content| match content {
278                        message::UserContent::ToolResult(tool_result) => {
279                            Some(Message::from(tool_result))
280                        }
281                        _ => None,
282                    })
283                    .collect::<Vec<_>>();
284
285                messages.extend(tool_results);
286
287                let text_content: String = content
288                    .into_iter()
289                    .filter_map(|content| match content {
290                        message::UserContent::Text(text) => Some(text.text),
291                        message::UserContent::Document(Document {
292                            data:
293                                DocumentSourceKind::Base64(content)
294                                | DocumentSourceKind::String(content),
295                            ..
296                        }) => Some(content),
297                        _ => None,
298                    })
299                    .collect::<Vec<_>>()
300                    .join("\n");
301
302                if !text_content.is_empty() {
303                    messages.push(Message::User {
304                        content: text_content,
305                        name: None,
306                    });
307                }
308
309                Ok(messages)
310            }
311            message::Message::Assistant { content, .. } => {
312                let mut text_content = String::new();
313                let mut reasoning_content = String::new();
314                let mut tool_calls = Vec::new();
315
316                for item in content.iter() {
317                    match item {
318                        message::AssistantContent::Text(text) => {
319                            text_content.push_str(text.text());
320                        }
321                        message::AssistantContent::Reasoning(reasoning) => {
322                            reasoning_content.push_str(&reasoning.display_text());
323                        }
324                        message::AssistantContent::ToolCall(tool_call) => {
325                            tool_calls.push(ToolCall::from(tool_call.clone()));
326                        }
327                        _ => {}
328                    }
329                }
330
331                let reasoning = if reasoning_content.is_empty() {
332                    None
333                } else {
334                    Some(reasoning_content)
335                };
336
337                Ok(vec![Message::Assistant {
338                    content: text_content,
339                    name: None,
340                    tool_calls,
341                    reasoning_content: reasoning,
342                }])
343            }
344        }
345    }
346}
347
348#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
349pub struct ToolCall {
350    pub id: String,
351    pub index: usize,
352    #[serde(default)]
353    pub r#type: ToolType,
354    pub function: Function,
355}
356
357#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
358pub struct Function {
359    pub name: String,
360    #[serde(with = "json_utils::stringified_json")]
361    pub arguments: serde_json::Value,
362}
363
364#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
365#[serde(rename_all = "lowercase")]
366pub enum ToolType {
367    #[default]
368    Function,
369}
370
371#[derive(Clone, Debug, Deserialize, Serialize)]
372pub struct ToolDefinition {
373    pub r#type: String,
374    pub function: completion::ToolDefinition,
375}
376
377impl From<crate::completion::ToolDefinition> for ToolDefinition {
378    fn from(tool: crate::completion::ToolDefinition) -> Self {
379        Self {
380            r#type: "function".into(),
381            function: tool,
382        }
383    }
384}
385
386impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
387    type Error = CompletionError;
388
389    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
390        let choice = response.choices.first().ok_or_else(|| {
391            CompletionError::ResponseError("Response contained no choices".to_owned())
392        })?;
393        let content = match &choice.message {
394            Message::Assistant {
395                content,
396                tool_calls,
397                reasoning_content,
398                ..
399            } => {
400                let mut content = if content.trim().is_empty() {
401                    vec![]
402                } else {
403                    vec![completion::AssistantContent::text(content)]
404                };
405
406                content.extend(
407                    tool_calls
408                        .iter()
409                        .map(|call| {
410                            completion::AssistantContent::tool_call(
411                                &call.id,
412                                &call.function.name,
413                                call.function.arguments.clone(),
414                            )
415                        })
416                        .collect::<Vec<_>>(),
417                );
418
419                if let Some(reasoning_content) = reasoning_content {
420                    content.push(completion::AssistantContent::reasoning(reasoning_content));
421                }
422
423                Ok(content)
424            }
425            _ => Err(CompletionError::ResponseError(
426                "Response did not contain a valid message or tool call".into(),
427            )),
428        }?;
429
430        let choice = OneOrMany::many(content).map_err(|_| {
431            CompletionError::ResponseError(
432                "Response contained no message or tool call (empty)".to_owned(),
433            )
434        })?;
435
436        let usage = completion::Usage {
437            input_tokens: response.usage.prompt_tokens as u64,
438            output_tokens: response.usage.completion_tokens as u64,
439            total_tokens: response.usage.total_tokens as u64,
440            cached_input_tokens: response
441                .usage
442                .prompt_tokens_details
443                .as_ref()
444                .and_then(|d| d.cached_tokens)
445                .map(|c| c as u64)
446                .unwrap_or(0),
447            cache_creation_input_tokens: 0,
448            reasoning_tokens: 0,
449        };
450
451        Ok(completion::CompletionResponse {
452            choice,
453            usage,
454            raw_response: response,
455            message_id: None,
456        })
457    }
458}
459
460#[derive(Debug, Serialize, Deserialize)]
461pub(super) struct DeepseekCompletionRequest {
462    model: String,
463    pub messages: Vec<Message>,
464    #[serde(skip_serializing_if = "Option::is_none")]
465    temperature: Option<f64>,
466    #[serde(skip_serializing_if = "Vec::is_empty")]
467    tools: Vec<ToolDefinition>,
468    #[serde(skip_serializing_if = "Option::is_none")]
469    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
470    #[serde(flatten, skip_serializing_if = "Option::is_none")]
471    pub additional_params: Option<serde_json::Value>,
472}
473
474impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
475    type Error = CompletionError;
476
477    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
478        if req.output_schema.is_some() {
479            tracing::warn!("Structured outputs currently not supported for DeepSeek");
480        }
481        let model = req.model.clone().unwrap_or_else(|| model.to_string());
482        let mut full_history: Vec<Message> = match &req.preamble {
483            Some(preamble) => vec![Message::system(preamble)],
484            None => vec![],
485        };
486
487        if let Some(docs) = req.normalized_documents() {
488            let docs: Vec<Message> = docs.try_into()?;
489            full_history.extend(docs);
490        }
491
492        let chat_history: Vec<Message> = req
493            .chat_history
494            .clone()
495            .into_iter()
496            .map(|message| message.try_into())
497            .collect::<Result<Vec<Vec<Message>>, _>>()?
498            .into_iter()
499            .flatten()
500            .collect();
501
502        full_history.extend(chat_history);
503
504        let tool_choice = req
505            .tool_choice
506            .clone()
507            .map(crate::providers::openrouter::ToolChoice::try_from)
508            .transpose()?;
509
510        Ok(Self {
511            model: model.to_string(),
512            messages: full_history,
513            temperature: req.temperature,
514            tools: req
515                .tools
516                .clone()
517                .into_iter()
518                .map(ToolDefinition::from)
519                .collect::<Vec<_>>(),
520            tool_choice,
521            additional_params: req.additional_params,
522        })
523    }
524}
525
526/// The struct implementing the `CompletionModel` trait
527#[derive(Clone)]
528pub struct CompletionModel<T = reqwest::Client> {
529    pub client: Client<T>,
530    pub model: String,
531}
532
533impl<T> completion::CompletionModel for CompletionModel<T>
534where
535    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
536{
537    type Response = CompletionResponse;
538    type StreamingResponse = StreamingCompletionResponse;
539
540    type Client = Client<T>;
541
542    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
543        Self {
544            client: client.clone(),
545            model: model.into().to_string(),
546        }
547    }
548
549    async fn completion(
550        &self,
551        completion_request: CompletionRequest,
552    ) -> Result<
553        completion::CompletionResponse<CompletionResponse>,
554        crate::completion::CompletionError,
555    > {
556        let span = if tracing::Span::current().is_disabled() {
557            info_span!(
558                target: "rig::completions",
559                "chat",
560                gen_ai.operation.name = "chat",
561                gen_ai.provider.name = "deepseek",
562                gen_ai.request.model = self.model,
563                gen_ai.system_instructions = tracing::field::Empty,
564                gen_ai.response.id = tracing::field::Empty,
565                gen_ai.response.model = tracing::field::Empty,
566                gen_ai.usage.output_tokens = tracing::field::Empty,
567                gen_ai.usage.input_tokens = tracing::field::Empty,
568                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
569            )
570        } else {
571            tracing::Span::current()
572        };
573
574        span.record("gen_ai.system_instructions", &completion_request.preamble);
575
576        let request =
577            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
578
579        if enabled!(Level::TRACE) {
580            tracing::trace!(target: "rig::completions",
581                "DeepSeek completion request: {}",
582                serde_json::to_string_pretty(&request)?
583            );
584        }
585
586        let body = serde_json::to_vec(&request)?;
587        let req = self
588            .client
589            .post("/chat/completions")?
590            .body(body)
591            .map_err(|e| CompletionError::HttpError(e.into()))?;
592
593        async move {
594            let response = self.client.send::<_, Bytes>(req).await?;
595            let status = response.status();
596            let response_body = response.into_body().into_future().await?.to_vec();
597
598            if status.is_success() {
599                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
600                    ApiResponse::Ok(response) => {
601                        let span = tracing::Span::current();
602                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
603                        span.record(
604                            "gen_ai.usage.output_tokens",
605                            response.usage.completion_tokens,
606                        );
607                        span.record(
608                            "gen_ai.usage.cache_read.input_tokens",
609                            response
610                                .usage
611                                .prompt_tokens_details
612                                .as_ref()
613                                .and_then(|d| d.cached_tokens)
614                                .unwrap_or(0),
615                        );
616                        if enabled!(Level::TRACE) {
617                            tracing::trace!(target: "rig::completions",
618                                "DeepSeek completion response: {}",
619                                serde_json::to_string_pretty(&response)?
620                            );
621                        }
622                        response.try_into()
623                    }
624                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
625                }
626            } else {
627                Err(CompletionError::ProviderError(
628                    String::from_utf8_lossy(&response_body).to_string(),
629                ))
630            }
631        }
632        .instrument(span)
633        .await
634    }
635
636    async fn stream(
637        &self,
638        completion_request: CompletionRequest,
639    ) -> Result<
640        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
641        CompletionError,
642    > {
643        let preamble = completion_request.preamble.clone();
644        let mut request =
645            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
646
647        let params = json_utils::merge(
648            request.additional_params.unwrap_or(serde_json::json!({})),
649            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
650        );
651
652        request.additional_params = Some(params);
653
654        if enabled!(Level::TRACE) {
655            tracing::trace!(target: "rig::completions",
656                "DeepSeek streaming completion request: {}",
657                serde_json::to_string_pretty(&request)?
658            );
659        }
660
661        let body = serde_json::to_vec(&request)?;
662
663        let req = self
664            .client
665            .post("/chat/completions")?
666            .body(body)
667            .map_err(|e| CompletionError::HttpError(e.into()))?;
668
669        let span = if tracing::Span::current().is_disabled() {
670            info_span!(
671                target: "rig::completions",
672                "chat_streaming",
673                gen_ai.operation.name = "chat_streaming",
674                gen_ai.provider.name = "deepseek",
675                gen_ai.request.model = self.model,
676                gen_ai.system_instructions = preamble,
677                gen_ai.response.id = tracing::field::Empty,
678                gen_ai.response.model = tracing::field::Empty,
679                gen_ai.usage.output_tokens = tracing::field::Empty,
680                gen_ai.usage.input_tokens = tracing::field::Empty,
681                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
682            )
683        } else {
684            tracing::Span::current()
685        };
686
687        tracing::Instrument::instrument(
688            send_compatible_streaming_request(self.client.clone(), req),
689            span,
690        )
691        .await
692    }
693}
694
695#[derive(Deserialize, Debug)]
696pub struct StreamingDelta {
697    #[serde(default)]
698    content: Option<String>,
699    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
700    tool_calls: Vec<StreamingToolCall>,
701    reasoning_content: Option<String>,
702}
703
704#[derive(Deserialize, Debug)]
705struct StreamingChoice {
706    delta: StreamingDelta,
707}
708
709#[derive(Deserialize, Debug)]
710struct StreamingCompletionChunk {
711    id: Option<String>,
712    model: Option<String>,
713    choices: Vec<StreamingChoice>,
714    usage: Option<Usage>,
715}
716
717#[derive(Clone, Deserialize, Serialize, Debug)]
718pub struct StreamingCompletionResponse {
719    pub usage: Usage,
720}
721
722impl GetTokenUsage for StreamingCompletionResponse {
723    fn token_usage(&self) -> Option<crate::completion::Usage> {
724        self.usage.token_usage()
725    }
726}
727
728#[derive(Clone, Copy)]
729struct DeepSeekCompatibleProfile;
730
731impl CompatibleStreamProfile for DeepSeekCompatibleProfile {
732    type Usage = Usage;
733    type Detail = ();
734    type FinalResponse = StreamingCompletionResponse;
735
736    fn normalize_chunk(
737        &self,
738        data: &str,
739    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
740        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
741            Ok(data) => data,
742            Err(error) => {
743                tracing::debug!(
744                    "Couldn't parse SSE payload as StreamingCompletionChunk: {:?}",
745                    error
746                );
747                return Ok(None);
748            }
749        };
750
751        Ok(Some(
752            openai_chat_completions_compatible::normalize_first_choice_chunk(
753                data.id,
754                data.model,
755                data.usage,
756                &data.choices,
757                |choice| CompatibleChoiceData {
758                    finish_reason: CompatibleFinishReason::Other,
759                    text: choice.delta.content.clone(),
760                    reasoning: choice.delta.reasoning_content.clone(),
761                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
762                        &choice.delta.tool_calls,
763                    ),
764                    details: Vec::new(),
765                },
766            ),
767        ))
768    }
769
770    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
771        StreamingCompletionResponse { usage }
772    }
773
774    fn uses_distinct_tool_call_eviction(&self) -> bool {
775        true
776    }
777
778    fn emits_complete_single_chunk_tool_calls(&self) -> bool {
779        true
780    }
781}
782
783pub async fn send_compatible_streaming_request<T>(
784    http_client: T,
785    req: Request<Vec<u8>>,
786) -> Result<
787    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
788    CompletionError,
789>
790where
791    T: HttpClientExt + Clone + 'static,
792{
793    openai_chat_completions_compatible::send_compatible_streaming_request(
794        http_client,
795        req,
796        DeepSeekCompatibleProfile,
797    )
798    .await
799}
800
801#[derive(Debug, Deserialize)]
802struct ListModelsResponse {
803    data: Vec<ListModelEntry>,
804}
805
806#[derive(Debug, Deserialize)]
807struct ListModelEntry {
808    id: String,
809    owned_by: String,
810}
811
812impl From<ListModelEntry> for Model {
813    fn from(value: ListModelEntry) -> Self {
814        let mut model = Model::from_id(value.id);
815        model.owned_by = Some(value.owned_by);
816        model
817    }
818}
819
820/// [`ModelLister`] implementation for the DeepSeek API (`GET /models`).
821#[derive(Clone)]
822pub struct DeepSeekModelLister<H = reqwest::Client> {
823    client: Client<H>,
824}
825
826impl<H> ModelLister<H> for DeepSeekModelLister<H>
827where
828    H: HttpClientExt + WasmCompatSend + WasmCompatSync + 'static,
829{
830    type Client = Client<H>;
831
832    fn new(client: Self::Client) -> Self {
833        Self { client }
834    }
835
836    async fn list_all(&self) -> Result<ModelList, ModelListingError> {
837        let path = "/models";
838        let req = self.client.get(path)?.body(http_client::NoBody)?;
839        let response = self
840            .client
841            .send::<_, Vec<u8>>(req)
842            .await
843            .map_err(|error| match error {
844                http_client::Error::InvalidStatusCodeWithMessage(status, message) => {
845                    ModelListingError::api_error_with_context(
846                        "DeepSeek",
847                        path,
848                        status.as_u16(),
849                        message.as_bytes(),
850                    )
851                }
852                other => ModelListingError::from(other),
853            })?;
854
855        if !response.status().is_success() {
856            let status_code = response.status().as_u16();
857            let body = response.into_body().await?;
858            return Err(ModelListingError::api_error_with_context(
859                "DeepSeek",
860                path,
861                status_code,
862                &body,
863            ));
864        }
865
866        let body = response.into_body().await?;
867        let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
868            ModelListingError::parse_error_with_context("DeepSeek", path, &error, &body)
869        })?;
870
871        let models = api_resp.data.into_iter().map(Model::from).collect();
872
873        Ok(ModelList::new(models))
874    }
875}
876
877// ================================================================
878// DeepSeek Completion API
879// ================================================================
/// `deepseek-chat` model name.
#[deprecated(
    note = "The model names `deepseek-chat` and `deepseek-reasoner` will be deprecated on 2026/07/24. \
    For compatibility, they correspond to the non-thinking mode and thinking mode of `deepseek-v4-flash`, \
    respectively."
)]
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";

/// `deepseek-reasoner` model name.
#[deprecated(
    note = "The model names `deepseek-chat` and `deepseek-reasoner` will be deprecated on 2026/07/24. \
    For compatibility, they correspond to the non-thinking mode and thinking mode of `deepseek-v4-flash`, \
    respectively."
)]
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";

/// `deepseek-v4-flash` model name.
pub const DEEPSEEK_V4_FLASH: &str = "deepseek-v4-flash";

/// `deepseek-v4-pro` model name.
pub const DEEPSEEK_V4_PRO: &str = "deepseek-v4-pro";
894
895// Tests
896#[cfg(test)]
897mod tests {
898    use super::*;
899    use crate::client::ModelListingClient;
900    use crate::test_utils::RecordingHttpClient;
901
902    #[test]
903    fn test_deserialize_vec_choice() {
904        let data = r#"[{
905            "finish_reason": "stop",
906            "index": 0,
907            "logprobs": null,
908            "message":{"role":"assistant","content":"Hello, world!"}
909            }]"#;
910
911        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
912        assert_eq!(choices.len(), 1);
913        match &choices.first().unwrap().message {
914            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
915            _ => panic!("Expected assistant message"),
916        }
917    }
918
919    #[test]
920    fn test_deserialize_deepseek_response() {
921        let data = r#"{
922            "choices":[{
923                "finish_reason": "stop",
924                "index": 0,
925                "logprobs": null,
926                "message":{"role":"assistant","content":"Hello, world!"}
927            }],
928            "usage": {
929                "completion_tokens": 0,
930                "prompt_tokens": 0,
931                "prompt_cache_hit_tokens": 0,
932                "prompt_cache_miss_tokens": 0,
933                "total_tokens": 0
934            }
935        }"#;
936
937        let jd = &mut serde_json::Deserializer::from_str(data);
938        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
939        match result {
940            Ok(response) => match &response.choices.first().unwrap().message {
941                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
942                _ => panic!("Expected assistant message"),
943            },
944            Err(err) => {
945                panic!("Deserialization error at {}: {}", err.path(), err);
946            }
947        }
948    }
949
950    #[test]
951    fn test_deserialize_example_response() {
952        let data = r#"
953        {
954            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
955            "object": "chat.completion",
956            "created": 0,
957            "model": "deepseek-chat",
958            "choices": [
959                {
960                    "index": 0,
961                    "message": {
962                        "role": "assistant",
963                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
964                    },
965                    "logprobs": null,
966                    "finish_reason": "stop"
967                }
968            ],
969            "usage": {
970                "prompt_tokens": 13,
971                "completion_tokens": 32,
972                "total_tokens": 45,
973                "prompt_tokens_details": {
974                    "cached_tokens": 0
975                },
976                "prompt_cache_hit_tokens": 0,
977                "prompt_cache_miss_tokens": 13
978            },
979            "system_fingerprint": "fp_4b6881f2c5"
980        }
981        "#;
982        let jd = &mut serde_json::Deserializer::from_str(data);
983        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
984
985        match result {
986            Ok(response) => match &response.choices.first().unwrap().message {
987                Message::Assistant { content, .. } => assert_eq!(
988                    content,
989                    "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
990                ),
991                _ => panic!("Expected assistant message"),
992            },
993            Err(err) => {
994                panic!("Deserialization error at {}: {}", err.path(), err);
995            }
996        }
997    }
998
999    #[test]
1000    fn test_serialize_deserialize_tool_call_message() {
1001        let tool_call_choice_json = r#"
1002            {
1003              "finish_reason": "tool_calls",
1004              "index": 0,
1005              "logprobs": null,
1006              "message": {
1007                "content": "",
1008                "role": "assistant",
1009                "tool_calls": [
1010                  {
1011                    "function": {
1012                      "arguments": "{\"x\":2,\"y\":5}",
1013                      "name": "subtract"
1014                    },
1015                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
1016                    "index": 0,
1017                    "type": "function"
1018                  }
1019                ]
1020              }
1021            }
1022        "#;
1023
1024        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
1025
1026        let expected_choice: Choice = Choice {
1027            finish_reason: "tool_calls".to_string(),
1028            index: 0,
1029            logprobs: None,
1030            message: Message::Assistant {
1031                content: "".to_string(),
1032                name: None,
1033                tool_calls: vec![ToolCall {
1034                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
1035                    function: Function {
1036                        name: "subtract".to_string(),
1037                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
1038                    },
1039                    index: 0,
1040                    r#type: ToolType::Function,
1041                }],
1042                reasoning_content: None,
1043            },
1044        };
1045
1046        assert_eq!(choice, expected_choice);
1047    }
1048    #[test]
1049    fn test_user_message_multiple_text_items_merged() {
1050        use crate::completion::message::{Message as RigMessage, UserContent};
1051
1052        let rig_msg = RigMessage::User {
1053            content: OneOrMany::many(vec![
1054                UserContent::text("first part"),
1055                UserContent::text("second part"),
1056            ])
1057            .expect("content should not be empty"),
1058        };
1059
1060        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1061
1062        let user_messages: Vec<&Message> = messages
1063            .iter()
1064            .filter(|m| matches!(m, Message::User { .. }))
1065            .collect();
1066
1067        assert_eq!(
1068            user_messages.len(),
1069            1,
1070            "multiple text items should produce a single user message"
1071        );
1072        match &user_messages[0] {
1073            Message::User { content, .. } => {
1074                assert_eq!(content, "first part\nsecond part");
1075            }
1076            _ => unreachable!(),
1077        }
1078    }
1079
1080    #[test]
1081    fn test_assistant_message_with_reasoning_and_tool_calls() {
1082        use crate::completion::message::{AssistantContent, Message as RigMessage};
1083
1084        let rig_msg = RigMessage::Assistant {
1085            id: None,
1086            content: OneOrMany::many(vec![
1087                AssistantContent::reasoning("thinking about the problem"),
1088                AssistantContent::text("I'll call the tool"),
1089                AssistantContent::tool_call(
1090                    "call_1",
1091                    "subtract",
1092                    serde_json::json!({"x": 2, "y": 5}),
1093                ),
1094            ])
1095            .expect("content should not be empty"),
1096        };
1097
1098        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1099
1100        assert_eq!(messages.len(), 1, "should produce exactly one message");
1101        match &messages[0] {
1102            Message::Assistant {
1103                content,
1104                tool_calls,
1105                reasoning_content,
1106                ..
1107            } => {
1108                assert_eq!(content, "I'll call the tool");
1109                assert_eq!(
1110                    reasoning_content.as_deref(),
1111                    Some("thinking about the problem")
1112                );
1113                assert_eq!(tool_calls.len(), 1);
1114                assert_eq!(tool_calls[0].function.name, "subtract");
1115            }
1116            _ => panic!("Expected assistant message"),
1117        }
1118    }
1119
1120    #[test]
1121    fn test_assistant_message_without_reasoning() {
1122        use crate::completion::message::{AssistantContent, Message as RigMessage};
1123
1124        let rig_msg = RigMessage::Assistant {
1125            id: None,
1126            content: OneOrMany::many(vec![
1127                AssistantContent::text("calling tool"),
1128                AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})),
1129            ])
1130            .expect("content should not be empty"),
1131        };
1132
1133        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1134
1135        assert_eq!(messages.len(), 1);
1136        match &messages[0] {
1137            Message::Assistant {
1138                reasoning_content,
1139                tool_calls,
1140                ..
1141            } => {
1142                assert!(reasoning_content.is_none());
1143                assert_eq!(tool_calls.len(), 1);
1144            }
1145            _ => panic!("Expected assistant message"),
1146        }
1147    }
1148
1149    #[test]
1150    fn test_client_initialization() {
1151        let _client =
1152            crate::providers::deepseek::Client::new("dummy-key").expect("Client::new() failed");
1153        let _client_from_builder = crate::providers::deepseek::Client::builder()
1154            .api_key("dummy-key")
1155            .build()
1156            .expect("Client::builder() failed");
1157    }
1158
1159    #[test]
1160    fn test_deserialize_list_models_response() {
1161        let data = r#"{
1162            "object": "list",
1163            "data": [
1164                {
1165                    "id": "deepseek-v4-flash",
1166                    "object": "model",
1167                    "owned_by": "deepseek"
1168                },
1169                {
1170                    "id": "deepseek-v4-pro",
1171                    "object": "model",
1172                    "owned_by": "deepseek"
1173                }
1174            ]
1175        }"#;
1176
1177        let response: ListModelsResponse = serde_json::from_str(data).unwrap();
1178
1179        assert_eq!(response.data.len(), 2);
1180        assert_eq!(response.data[0].id, "deepseek-v4-flash");
1181        assert_eq!(response.data[0].owned_by, "deepseek");
1182    }
1183
1184    #[tokio::test]
1185    async fn test_list_models_uses_models_endpoint() {
1186        let response_body = r#"{
1187            "object": "list",
1188            "data": [
1189                {
1190                    "id": "deepseek-v4-flash",
1191                    "object": "model",
1192                    "owned_by": "deepseek"
1193                },
1194                {
1195                    "id": "deepseek-v4-pro",
1196                    "object": "model",
1197                    "owned_by": "deepseek"
1198                }
1199            ]
1200        }"#;
1201
1202        let http_client = RecordingHttpClient::new(response_body);
1203        let client = Client::builder()
1204            .api_key("dummy-key")
1205            .http_client(http_client.clone())
1206            .build()
1207            .expect("client should build");
1208
1209        let models = client
1210            .list_models()
1211            .await
1212            .expect("list_models should succeed");
1213
1214        assert_eq!(models.len(), 2);
1215        assert_eq!(models.data[0].id, "deepseek-v4-flash");
1216        assert_eq!(models.data[0].r#type, None);
1217        assert_eq!(models.data[0].owned_by.as_deref(), Some("deepseek"));
1218        let requests = http_client.requests();
1219        assert_eq!(requests.len(), 1);
1220        assert_eq!(requests[0].uri, "https://api.deepseek.com/models");
1221    }
1222
1223    #[tokio::test]
1224    async fn test_list_models_preserves_api_error_context() {
1225        let http_client = RecordingHttpClient::with_error(
1226            http::StatusCode::UNAUTHORIZED,
1227            r#"{"error":{"message":"invalid api key"}}"#,
1228        );
1229        let client = Client::builder()
1230            .api_key("dummy-key")
1231            .http_client(http_client)
1232            .build()
1233            .expect("client should build");
1234
1235        let error = client
1236            .list_models()
1237            .await
1238            .expect_err("list_models should fail");
1239
1240        match error {
1241            ModelListingError::ApiError {
1242                status_code,
1243                message,
1244            } => {
1245                assert_eq!(status_code, 401);
1246                assert!(message.contains("provider=DeepSeek"));
1247                assert!(message.contains("path=/models"));
1248                assert!(message.contains("invalid api key"));
1249            }
1250            other => panic!("expected api error, got {other:?}"),
1251        }
1252    }
1253}