Skip to main content

rig_core/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```no_run
5//! use rig_core::{client::CompletionClient, providers::deepseek};
6//!
7//! # fn run() -> Result<(), Box<dyn std::error::Error>> {
8//! let client = deepseek::Client::new("DEEPSEEK_API_KEY")?;
9//!
//! let deepseek_flash = client.completion_model(deepseek::DEEPSEEK_V4_FLASH);
11//! # Ok(())
12//! # }
13//! ```
14
15use bytes::Bytes;
16use http::Request;
17use tracing::{Instrument, Level, enabled, info_span};
18
19use crate::client::{
20    self, BearerAuth, Capabilities, Capable, DebugExt, ModelLister, Nothing, Provider,
21    ProviderBuilder, ProviderClient,
22};
23use crate::completion::GetTokenUsage;
24use crate::http_client::{self, HttpClientExt};
25use crate::message::{Document, DocumentSourceKind};
26use crate::model::{Model, ModelList, ModelListingError};
27use crate::providers::internal::openai_chat_completions_compatible::{
28    self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
29};
30use crate::{
31    OneOrMany,
32    completion::{self, CompletionError, CompletionRequest},
33    json_utils, message,
34    wasm_compat::{WasmCompatSend, WasmCompatSync},
35};
36use serde::{Deserialize, Serialize};
37
38use super::openai::StreamingToolCall;
39
40// ================================================================
41// Main DeepSeek Client
42// ================================================================
/// Root URL of the hosted DeepSeek REST API.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
44
/// Zero-sized marker naming DeepSeek as the provider of a generic client.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExt;

/// Zero-sized marker naming DeepSeek on the generic client builder.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExtBuilder;
49
50type DeepSeekApiKey = BearerAuth;
51
52impl Provider for DeepSeekExt {
53    type Builder = DeepSeekExtBuilder;
54    const VERIFY_PATH: &'static str = "/user/balance";
55}
56
57impl<H> Capabilities<H> for DeepSeekExt {
58    type Completion = Capable<CompletionModel<H>>;
59    type Embeddings = Nothing;
60    type Transcription = Nothing;
61    type ModelListing = Capable<DeepSeekModelLister<H>>;
62    #[cfg(feature = "image")]
63    type ImageGeneration = Nothing;
64    #[cfg(feature = "audio")]
65    type AudioGeneration = Nothing;
66    type Rerank = Nothing;
67}
68
69impl DebugExt for DeepSeekExt {}
70
71impl ProviderBuilder for DeepSeekExtBuilder {
72    type Extension<H>
73        = DeepSeekExt
74    where
75        H: HttpClientExt;
76    type ApiKey = DeepSeekApiKey;
77
78    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
79
80    fn build<H>(
81        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
82    ) -> http_client::Result<Self::Extension<H>>
83    where
84        H: HttpClientExt,
85    {
86        Ok(DeepSeekExt)
87    }
88}
89
90pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
91pub type ClientBuilder<H = crate::markers::Missing> =
92    client::ClientBuilder<DeepSeekExtBuilder, String, H>;
93
94impl ProviderClient for Client {
95    type Input = DeepSeekApiKey;
96    type Error = crate::client::ProviderClientError;
97
98    // If you prefer the environment variable approach:
99    fn from_env() -> Result<Self, Self::Error> {
100        let api_key = crate::client::required_env_var("DEEPSEEK_API_KEY")?;
101        let mut client_builder = Self::builder();
102        client_builder.headers_mut().insert(
103            http::header::CONTENT_TYPE,
104            http::HeaderValue::from_static("application/json"),
105        );
106        let client_builder = client_builder.api_key(&api_key);
107        client_builder.build().map_err(Into::into)
108    }
109
110    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
111        Self::new(input).map_err(Into::into)
112    }
113}
114
115#[derive(Debug, Deserialize)]
116struct ApiErrorResponse {
117    message: String,
118}
119
120#[derive(Debug, Deserialize)]
121#[serde(untagged)]
122enum ApiResponse<T> {
123    Ok(T),
124    Err(ApiErrorResponse),
125}
126
127impl From<ApiErrorResponse> for CompletionError {
128    fn from(err: ApiErrorResponse) -> Self {
129        CompletionError::ProviderError(err.message)
130    }
131}
132
133/// The response shape from the DeepSeek API
134#[derive(Clone, Debug, Serialize, Deserialize)]
135pub struct CompletionResponse {
136    // We'll match the JSON:
137    pub choices: Vec<Choice>,
138    pub usage: Usage,
139    // you may want other fields
140}
141
142#[derive(Clone, Debug, Serialize, Deserialize, Default)]
143pub struct Usage {
144    pub completion_tokens: u32,
145    pub prompt_tokens: u32,
146    pub prompt_cache_hit_tokens: u32,
147    pub prompt_cache_miss_tokens: u32,
148    pub total_tokens: u32,
149    #[serde(skip_serializing_if = "Option::is_none")]
150    pub completion_tokens_details: Option<CompletionTokensDetails>,
151    #[serde(skip_serializing_if = "Option::is_none")]
152    pub prompt_tokens_details: Option<PromptTokensDetails>,
153}
154
155impl GetTokenUsage for Usage {
156    fn token_usage(&self) -> crate::completion::Usage {
157        crate::providers::internal::completion_usage(
158            self.prompt_tokens as u64,
159            self.completion_tokens as u64,
160            self.total_tokens as u64,
161            self.prompt_tokens_details
162                .as_ref()
163                .and_then(|details| details.cached_tokens)
164                .map(u64::from)
165                .unwrap_or(0),
166        )
167    }
168}
169
170#[derive(Clone, Debug, Serialize, Deserialize, Default)]
171pub struct CompletionTokensDetails {
172    #[serde(skip_serializing_if = "Option::is_none")]
173    pub reasoning_tokens: Option<u32>,
174}
175
176#[derive(Clone, Debug, Serialize, Deserialize, Default)]
177pub struct PromptTokensDetails {
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub cached_tokens: Option<u32>,
180}
181
182#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
183pub struct Choice {
184    pub index: usize,
185    pub message: Message,
186    pub logprobs: Option<serde_json::Value>,
187    pub finish_reason: String,
188}
189
190#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
191#[serde(tag = "role", rename_all = "lowercase")]
192pub enum Message {
193    System {
194        content: String,
195        #[serde(skip_serializing_if = "Option::is_none")]
196        name: Option<String>,
197    },
198    User {
199        content: String,
200        #[serde(skip_serializing_if = "Option::is_none")]
201        name: Option<String>,
202    },
203    Assistant {
204        content: String,
205        #[serde(skip_serializing_if = "Option::is_none")]
206        name: Option<String>,
207        #[serde(
208            default,
209            deserialize_with = "json_utils::null_or_vec",
210            skip_serializing_if = "Vec::is_empty"
211        )]
212        tool_calls: Vec<ToolCall>,
213        /// only exists on `deepseek-reasoner` model at time of addition
214        #[serde(skip_serializing_if = "Option::is_none")]
215        reasoning_content: Option<String>,
216    },
217    #[serde(rename = "tool")]
218    ToolResult {
219        tool_call_id: String,
220        content: String,
221    },
222}
223
224impl Message {
225    pub fn system(content: &str) -> Self {
226        Message::System {
227            content: content.to_owned(),
228            name: None,
229        }
230    }
231}
232
233impl From<message::ToolResult> for Message {
234    fn from(tool_result: message::ToolResult) -> Self {
235        let content = match tool_result.content.first() {
236            message::ToolResultContent::Text(text) => text.text,
237            message::ToolResultContent::Image(_) => String::from("[Image]"),
238        };
239
240        Message::ToolResult {
241            tool_call_id: tool_result.id,
242            content,
243        }
244    }
245}
246
247impl From<message::ToolCall> for ToolCall {
248    fn from(tool_call: message::ToolCall) -> Self {
249        Self {
250            id: tool_call.id,
251            // TODO: update index when we have it
252            index: 0,
253            r#type: ToolType::Function,
254            function: Function {
255                name: tool_call.function.name,
256                arguments: tool_call.function.arguments,
257            },
258        }
259    }
260}
261
262impl TryFrom<message::Message> for Vec<Message> {
263    type Error = message::MessageError;
264
265    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
266        match message {
267            message::Message::System { content } => Ok(vec![Message::System {
268                content,
269                name: None,
270            }]),
271            message::Message::User { content } => {
272                // extract tool results
273                let mut messages = vec![];
274
275                let tool_results = content
276                    .clone()
277                    .into_iter()
278                    .filter_map(|content| match content {
279                        message::UserContent::ToolResult(tool_result) => {
280                            Some(Message::from(tool_result))
281                        }
282                        _ => None,
283                    })
284                    .collect::<Vec<_>>();
285
286                messages.extend(tool_results);
287
288                let text_content: String = content
289                    .into_iter()
290                    .filter_map(|content| match content {
291                        message::UserContent::Text(text) => Some(text.text),
292                        message::UserContent::Document(Document {
293                            data:
294                                DocumentSourceKind::Base64(content)
295                                | DocumentSourceKind::String(content),
296                            ..
297                        }) => Some(content),
298                        _ => None,
299                    })
300                    .collect::<Vec<_>>()
301                    .join("\n");
302
303                if !text_content.is_empty() {
304                    messages.push(Message::User {
305                        content: text_content,
306                        name: None,
307                    });
308                }
309
310                Ok(messages)
311            }
312            message::Message::Assistant { content, .. } => {
313                let mut text_content = String::new();
314                let mut reasoning_content = String::new();
315                let mut tool_calls = Vec::new();
316
317                for item in content.iter() {
318                    match item {
319                        message::AssistantContent::Text(text) => {
320                            text_content.push_str(text.text());
321                        }
322                        message::AssistantContent::Reasoning(reasoning) => {
323                            reasoning_content.push_str(&reasoning.display_text());
324                        }
325                        message::AssistantContent::ToolCall(tool_call) => {
326                            tool_calls.push(ToolCall::from(tool_call.clone()));
327                        }
328                        _ => {}
329                    }
330                }
331
332                let reasoning = if reasoning_content.is_empty() {
333                    None
334                } else {
335                    Some(reasoning_content)
336                };
337
338                Ok(vec![Message::Assistant {
339                    content: text_content,
340                    name: None,
341                    tool_calls,
342                    reasoning_content: reasoning,
343                }])
344            }
345        }
346    }
347}
348
349#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
350pub struct ToolCall {
351    pub id: String,
352    pub index: usize,
353    #[serde(default)]
354    pub r#type: ToolType,
355    pub function: Function,
356}
357
358#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
359pub struct Function {
360    pub name: String,
361    #[serde(with = "json_utils::stringified_json")]
362    pub arguments: serde_json::Value,
363}
364
365#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
366#[serde(rename_all = "lowercase")]
367pub enum ToolType {
368    #[default]
369    Function,
370}
371
372#[derive(Clone, Debug, Deserialize, Serialize)]
373pub struct ToolDefinition {
374    pub r#type: String,
375    pub function: completion::ToolDefinition,
376}
377
378impl From<crate::completion::ToolDefinition> for ToolDefinition {
379    fn from(tool: crate::completion::ToolDefinition) -> Self {
380        Self {
381            r#type: "function".into(),
382            function: tool,
383        }
384    }
385}
386
387impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
388    type Error = CompletionError;
389
390    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
391        let choice = response.choices.first().ok_or_else(|| {
392            CompletionError::ResponseError("Response contained no choices".to_owned())
393        })?;
394        let content = match &choice.message {
395            Message::Assistant {
396                content,
397                tool_calls,
398                reasoning_content,
399                ..
400            } => {
401                let mut content = if content.trim().is_empty() {
402                    vec![]
403                } else {
404                    vec![completion::AssistantContent::text(content)]
405                };
406
407                content.extend(
408                    tool_calls
409                        .iter()
410                        .map(|call| {
411                            completion::AssistantContent::tool_call(
412                                &call.id,
413                                &call.function.name,
414                                call.function.arguments.clone(),
415                            )
416                        })
417                        .collect::<Vec<_>>(),
418                );
419
420                if let Some(reasoning_content) = reasoning_content {
421                    content.push(completion::AssistantContent::reasoning(reasoning_content));
422                }
423
424                Ok(content)
425            }
426            _ => Err(CompletionError::ResponseError(
427                "Response did not contain a valid message or tool call".into(),
428            )),
429        }?;
430
431        let choice = OneOrMany::many(content).map_err(|_| {
432            CompletionError::ResponseError(
433                "Response contained no message or tool call (empty)".to_owned(),
434            )
435        })?;
436
437        let usage = completion::Usage {
438            input_tokens: response.usage.prompt_tokens as u64,
439            output_tokens: response.usage.completion_tokens as u64,
440            total_tokens: response.usage.total_tokens as u64,
441            cached_input_tokens: response
442                .usage
443                .prompt_tokens_details
444                .as_ref()
445                .and_then(|d| d.cached_tokens)
446                .map(|c| c as u64)
447                .unwrap_or(0),
448            cache_creation_input_tokens: 0,
449            tool_use_prompt_tokens: 0,
450            reasoning_tokens: 0,
451        };
452
453        Ok(completion::CompletionResponse {
454            choice,
455            usage,
456            raw_response: response,
457            message_id: None,
458        })
459    }
460}
461
462#[derive(Debug, Serialize, Deserialize)]
463pub(super) struct DeepseekCompletionRequest {
464    model: String,
465    pub messages: Vec<Message>,
466    #[serde(skip_serializing_if = "Option::is_none")]
467    temperature: Option<f64>,
468    #[serde(skip_serializing_if = "Vec::is_empty")]
469    tools: Vec<ToolDefinition>,
470    #[serde(skip_serializing_if = "Option::is_none")]
471    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
472    #[serde(flatten, skip_serializing_if = "Option::is_none")]
473    pub additional_params: Option<serde_json::Value>,
474}
475
476impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
477    type Error = CompletionError;
478
479    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
480        let chat_history = req.chat_history_with_documents();
481        if req.output_schema.is_some() {
482            tracing::warn!("Structured outputs currently not supported for DeepSeek");
483        }
484        let model = req.model.clone().unwrap_or_else(|| model.to_string());
485        let mut full_history: Vec<Message> = match &req.preamble {
486            Some(preamble) => vec![Message::system(preamble)],
487            None => vec![],
488        };
489
490        let chat_history: Vec<Message> = chat_history
491            .into_iter()
492            .map(|message| message.try_into())
493            .collect::<Result<Vec<Vec<Message>>, _>>()?
494            .into_iter()
495            .flatten()
496            .collect();
497
498        full_history.extend(chat_history);
499
500        let tool_choice = req
501            .tool_choice
502            .clone()
503            .map(crate::providers::openrouter::ToolChoice::try_from)
504            .transpose()?;
505
506        Ok(Self {
507            model: model.to_string(),
508            messages: full_history,
509            temperature: req.temperature,
510            tools: req
511                .tools
512                .clone()
513                .into_iter()
514                .map(ToolDefinition::from)
515                .collect::<Vec<_>>(),
516            tool_choice,
517            additional_params: req.additional_params,
518        })
519    }
520}
521
522/// The struct implementing the `CompletionModel` trait
523#[derive(Clone)]
524pub struct CompletionModel<T = reqwest::Client> {
525    pub client: Client<T>,
526    pub model: String,
527}
528
529impl<T> completion::CompletionModel for CompletionModel<T>
530where
531    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
532{
533    type Response = CompletionResponse;
534    type StreamingResponse = StreamingCompletionResponse;
535
536    type Client = Client<T>;
537
538    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
539        Self {
540            client: client.clone(),
541            model: model.into().to_string(),
542        }
543    }
544
545    async fn completion(
546        &self,
547        completion_request: CompletionRequest,
548    ) -> Result<
549        completion::CompletionResponse<CompletionResponse>,
550        crate::completion::CompletionError,
551    > {
552        let span = if tracing::Span::current().is_disabled() {
553            info_span!(
554                target: "rig::completions",
555                "chat",
556                gen_ai.operation.name = "chat",
557                gen_ai.provider.name = "deepseek",
558                gen_ai.request.model = self.model,
559                gen_ai.system_instructions = tracing::field::Empty,
560                gen_ai.response.id = tracing::field::Empty,
561                gen_ai.response.model = tracing::field::Empty,
562                gen_ai.usage.output_tokens = tracing::field::Empty,
563                gen_ai.usage.input_tokens = tracing::field::Empty,
564                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
565            )
566        } else {
567            tracing::Span::current()
568        };
569
570        span.record("gen_ai.system_instructions", &completion_request.preamble);
571
572        let request =
573            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
574
575        if enabled!(Level::TRACE) {
576            tracing::trace!(target: "rig::completions",
577                "DeepSeek completion request: {}",
578                serde_json::to_string_pretty(&request)?
579            );
580        }
581
582        let body = serde_json::to_vec(&request)?;
583        let req = self
584            .client
585            .post("/chat/completions")?
586            .body(body)
587            .map_err(|e| CompletionError::HttpError(e.into()))?;
588
589        async move {
590            let response = self.client.send::<_, Bytes>(req).await?;
591            let status = response.status();
592            let response_body = response.into_body().into_future().await?.to_vec();
593
594            if status.is_success() {
595                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
596                    ApiResponse::Ok(response) => {
597                        let span = tracing::Span::current();
598                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
599                        span.record(
600                            "gen_ai.usage.output_tokens",
601                            response.usage.completion_tokens,
602                        );
603                        span.record(
604                            "gen_ai.usage.cache_read.input_tokens",
605                            response
606                                .usage
607                                .prompt_tokens_details
608                                .as_ref()
609                                .and_then(|d| d.cached_tokens)
610                                .unwrap_or(0),
611                        );
612                        if enabled!(Level::TRACE) {
613                            tracing::trace!(target: "rig::completions",
614                                "DeepSeek completion response: {}",
615                                serde_json::to_string_pretty(&response)?
616                            );
617                        }
618                        response.try_into()
619                    }
620                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
621                }
622            } else {
623                Err(CompletionError::ProviderError(
624                    String::from_utf8_lossy(&response_body).to_string(),
625                ))
626            }
627        }
628        .instrument(span)
629        .await
630    }
631
632    async fn stream(
633        &self,
634        completion_request: CompletionRequest,
635    ) -> Result<
636        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
637        CompletionError,
638    > {
639        let preamble = completion_request.preamble.clone();
640        let mut request =
641            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
642
643        let params = json_utils::merge(
644            request.additional_params.unwrap_or(serde_json::json!({})),
645            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
646        );
647
648        request.additional_params = Some(params);
649
650        if enabled!(Level::TRACE) {
651            tracing::trace!(target: "rig::completions",
652                "DeepSeek streaming completion request: {}",
653                serde_json::to_string_pretty(&request)?
654            );
655        }
656
657        let body = serde_json::to_vec(&request)?;
658
659        let req = self
660            .client
661            .post("/chat/completions")?
662            .body(body)
663            .map_err(|e| CompletionError::HttpError(e.into()))?;
664
665        let span = if tracing::Span::current().is_disabled() {
666            info_span!(
667                target: "rig::completions",
668                "chat_streaming",
669                gen_ai.operation.name = "chat_streaming",
670                gen_ai.provider.name = "deepseek",
671                gen_ai.request.model = self.model,
672                gen_ai.system_instructions = preamble,
673                gen_ai.response.id = tracing::field::Empty,
674                gen_ai.response.model = tracing::field::Empty,
675                gen_ai.usage.output_tokens = tracing::field::Empty,
676                gen_ai.usage.input_tokens = tracing::field::Empty,
677                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
678            )
679        } else {
680            tracing::Span::current()
681        };
682
683        tracing::Instrument::instrument(
684            send_compatible_streaming_request(self.client.clone(), req),
685            span,
686        )
687        .await
688    }
689}
690
691#[derive(Deserialize, Debug)]
692pub struct StreamingDelta {
693    #[serde(default)]
694    content: Option<String>,
695    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
696    tool_calls: Vec<StreamingToolCall>,
697    reasoning_content: Option<String>,
698}
699
700#[derive(Deserialize, Debug)]
701struct StreamingChoice {
702    delta: StreamingDelta,
703}
704
705#[derive(Deserialize, Debug)]
706struct StreamingCompletionChunk {
707    id: Option<String>,
708    model: Option<String>,
709    choices: Vec<StreamingChoice>,
710    usage: Option<Usage>,
711}
712
713#[derive(Clone, Deserialize, Serialize, Debug)]
714pub struct StreamingCompletionResponse {
715    pub usage: Usage,
716}
717
718impl GetTokenUsage for StreamingCompletionResponse {
719    fn token_usage(&self) -> crate::completion::Usage {
720        self.usage.token_usage()
721    }
722}
723
/// Zero-sized profile describing how DeepSeek's stream is parsed by the shared
/// OpenAI-compatible streaming code.
#[derive(Copy, Clone)]
struct DeepSeekCompatibleProfile;
726
727impl CompatibleStreamProfile for DeepSeekCompatibleProfile {
728    type Usage = Usage;
729    type Detail = ();
730    type FinalResponse = StreamingCompletionResponse;
731
732    fn normalize_chunk(
733        &self,
734        data: &str,
735    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
736        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
737            Ok(data) => data,
738            Err(error) => {
739                tracing::debug!(
740                    "Couldn't parse SSE payload as StreamingCompletionChunk: {:?}",
741                    error
742                );
743                return Ok(None);
744            }
745        };
746
747        Ok(Some(
748            openai_chat_completions_compatible::normalize_first_choice_chunk(
749                data.id,
750                data.model,
751                data.usage,
752                &data.choices,
753                |choice| CompatibleChoiceData {
754                    finish_reason: CompatibleFinishReason::Other,
755                    text: choice.delta.content.clone(),
756                    reasoning: choice.delta.reasoning_content.clone(),
757                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
758                        &choice.delta.tool_calls,
759                    ),
760                    details: Vec::new(),
761                },
762            ),
763        ))
764    }
765
766    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
767        StreamingCompletionResponse { usage }
768    }
769
770    fn uses_distinct_tool_call_eviction(&self) -> bool {
771        true
772    }
773
774    fn emits_complete_single_chunk_tool_calls(&self) -> bool {
775        true
776    }
777}
778
779pub async fn send_compatible_streaming_request<T>(
780    http_client: T,
781    req: Request<Vec<u8>>,
782) -> Result<
783    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
784    CompletionError,
785>
786where
787    T: HttpClientExt + Clone + 'static,
788{
789    openai_chat_completions_compatible::send_compatible_streaming_request(
790        http_client,
791        req,
792        DeepSeekCompatibleProfile,
793    )
794    .await
795}
796
797#[derive(Debug, Deserialize)]
798struct ListModelsResponse {
799    data: Vec<ListModelEntry>,
800}
801
802#[derive(Debug, Deserialize)]
803struct ListModelEntry {
804    id: String,
805    owned_by: String,
806}
807
808impl From<ListModelEntry> for Model {
809    fn from(value: ListModelEntry) -> Self {
810        let mut model = Model::from_id(value.id);
811        model.owned_by = Some(value.owned_by);
812        model
813    }
814}
815
816/// [`ModelLister`] implementation for the DeepSeek API (`GET /models`).
817#[derive(Clone)]
818pub struct DeepSeekModelLister<H = reqwest::Client> {
819    client: Client<H>,
820}
821
822impl<H> ModelLister<H> for DeepSeekModelLister<H>
823where
824    H: HttpClientExt + WasmCompatSend + WasmCompatSync + 'static,
825{
826    type Client = Client<H>;
827
828    fn new(client: Self::Client) -> Self {
829        Self { client }
830    }
831
832    async fn list_all(&self) -> Result<ModelList, ModelListingError> {
833        let path = "/models";
834        let req = self.client.get(path)?.body(http_client::NoBody)?;
835        let response = self
836            .client
837            .send::<_, Vec<u8>>(req)
838            .await
839            .map_err(|error| match error {
840                http_client::Error::InvalidStatusCodeWithMessage(status, message) => {
841                    ModelListingError::api_error_with_context(
842                        "DeepSeek",
843                        path,
844                        status.as_u16(),
845                        message.as_bytes(),
846                    )
847                }
848                other => ModelListingError::from(other),
849            })?;
850
851        if !response.status().is_success() {
852            let status_code = response.status().as_u16();
853            let body = response.into_body().await?;
854            return Err(ModelListingError::api_error_with_context(
855                "DeepSeek",
856                path,
857                status_code,
858                &body,
859            ));
860        }
861
862        let body = response.into_body().await?;
863        let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
864            ModelListingError::parse_error_with_context("DeepSeek", path, &error, &body)
865        })?;
866
867        let models = api_resp.data.into_iter().map(Model::from).collect();
868
869        Ok(ModelList::new(models))
870    }
871}
872
873// ================================================================
874// DeepSeek Completion API
875// ================================================================
/// Legacy model name `deepseek-chat`.
#[deprecated(
    note = "The model names `deepseek-chat` and `deepseek-reasoner` will be deprecated on 2026/07/24. \
    For compatibility, they correspond to the non-thinking mode and thinking mode of `deepseek-v4-flash`, \
    respectively."
)]
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";

/// Legacy model name `deepseek-reasoner`.
#[deprecated(
    note = "The model names `deepseek-chat` and `deepseek-reasoner` will be deprecated on 2026/07/24. \
    For compatibility, they correspond to the non-thinking mode and thinking mode of `deepseek-v4-flash`, \
    respectively."
)]
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";

/// Model name `deepseek-v4-flash`.
pub const DEEPSEEK_V4_FLASH: &str = "deepseek-v4-flash";

/// Model name `deepseek-v4-pro`.
pub const DEEPSEEK_V4_PRO: &str = "deepseek-v4-pro";
890
891// Tests
892#[cfg(test)]
893mod tests {
894    use super::*;
895    use crate::client::ModelListingClient;
896    use crate::test_utils::RecordingHttpClient;
897
898    #[test]
899    fn test_deserialize_vec_choice() {
900        let data = r#"[{
901            "finish_reason": "stop",
902            "index": 0,
903            "logprobs": null,
904            "message":{"role":"assistant","content":"Hello, world!"}
905            }]"#;
906
907        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
908        assert_eq!(choices.len(), 1);
909        match &choices.first().unwrap().message {
910            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
911            _ => panic!("Expected assistant message"),
912        }
913    }
914
915    #[test]
916    fn test_deserialize_deepseek_response() {
917        let data = r#"{
918            "choices":[{
919                "finish_reason": "stop",
920                "index": 0,
921                "logprobs": null,
922                "message":{"role":"assistant","content":"Hello, world!"}
923            }],
924            "usage": {
925                "completion_tokens": 0,
926                "prompt_tokens": 0,
927                "prompt_cache_hit_tokens": 0,
928                "prompt_cache_miss_tokens": 0,
929                "total_tokens": 0
930            }
931        }"#;
932
933        let jd = &mut serde_json::Deserializer::from_str(data);
934        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
935        match result {
936            Ok(response) => match &response.choices.first().unwrap().message {
937                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
938                _ => panic!("Expected assistant message"),
939            },
940            Err(err) => {
941                panic!("Deserialization error at {}: {}", err.path(), err);
942            }
943        }
944    }
945
946    #[test]
947    fn test_deserialize_example_response() {
948        let data = r#"
949        {
950            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
951            "object": "chat.completion",
952            "created": 0,
953            "model": "deepseek-chat",
954            "choices": [
955                {
956                    "index": 0,
957                    "message": {
958                        "role": "assistant",
959                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
960                    },
961                    "logprobs": null,
962                    "finish_reason": "stop"
963                }
964            ],
965            "usage": {
966                "prompt_tokens": 13,
967                "completion_tokens": 32,
968                "total_tokens": 45,
969                "prompt_tokens_details": {
970                    "cached_tokens": 0
971                },
972                "prompt_cache_hit_tokens": 0,
973                "prompt_cache_miss_tokens": 13
974            },
975            "system_fingerprint": "fp_4b6881f2c5"
976        }
977        "#;
978        let jd = &mut serde_json::Deserializer::from_str(data);
979        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
980
981        match result {
982            Ok(response) => match &response.choices.first().unwrap().message {
983                Message::Assistant { content, .. } => assert_eq!(
984                    content,
985                    "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
986                ),
987                _ => panic!("Expected assistant message"),
988            },
989            Err(err) => {
990                panic!("Deserialization error at {}: {}", err.path(), err);
991            }
992        }
993    }
994
995    #[test]
996    fn test_serialize_deserialize_tool_call_message() {
997        let tool_call_choice_json = r#"
998            {
999              "finish_reason": "tool_calls",
1000              "index": 0,
1001              "logprobs": null,
1002              "message": {
1003                "content": "",
1004                "role": "assistant",
1005                "tool_calls": [
1006                  {
1007                    "function": {
1008                      "arguments": "{\"x\":2,\"y\":5}",
1009                      "name": "subtract"
1010                    },
1011                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
1012                    "index": 0,
1013                    "type": "function"
1014                  }
1015                ]
1016              }
1017            }
1018        "#;
1019
1020        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
1021
1022        let expected_choice: Choice = Choice {
1023            finish_reason: "tool_calls".to_string(),
1024            index: 0,
1025            logprobs: None,
1026            message: Message::Assistant {
1027                content: "".to_string(),
1028                name: None,
1029                tool_calls: vec![ToolCall {
1030                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
1031                    function: Function {
1032                        name: "subtract".to_string(),
1033                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
1034                    },
1035                    index: 0,
1036                    r#type: ToolType::Function,
1037                }],
1038                reasoning_content: None,
1039            },
1040        };
1041
1042        assert_eq!(choice, expected_choice);
1043    }
1044    #[test]
1045    fn test_user_message_multiple_text_items_merged() {
1046        use crate::completion::message::{Message as RigMessage, UserContent};
1047
1048        let rig_msg = RigMessage::User {
1049            content: OneOrMany::many(vec![
1050                UserContent::text("first part"),
1051                UserContent::text("second part"),
1052            ])
1053            .expect("content should not be empty"),
1054        };
1055
1056        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1057
1058        let user_messages: Vec<&Message> = messages
1059            .iter()
1060            .filter(|m| matches!(m, Message::User { .. }))
1061            .collect();
1062
1063        assert_eq!(
1064            user_messages.len(),
1065            1,
1066            "multiple text items should produce a single user message"
1067        );
1068        match &user_messages[0] {
1069            Message::User { content, .. } => {
1070                assert_eq!(content, "first part\nsecond part");
1071            }
1072            _ => unreachable!(),
1073        }
1074    }
1075
1076    #[test]
1077    fn test_assistant_message_with_reasoning_and_tool_calls() {
1078        use crate::completion::message::{AssistantContent, Message as RigMessage};
1079
1080        let rig_msg = RigMessage::Assistant {
1081            id: None,
1082            content: OneOrMany::many(vec![
1083                AssistantContent::reasoning("thinking about the problem"),
1084                AssistantContent::text("I'll call the tool"),
1085                AssistantContent::tool_call(
1086                    "call_1",
1087                    "subtract",
1088                    serde_json::json!({"x": 2, "y": 5}),
1089                ),
1090            ])
1091            .expect("content should not be empty"),
1092        };
1093
1094        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1095
1096        assert_eq!(messages.len(), 1, "should produce exactly one message");
1097        match &messages[0] {
1098            Message::Assistant {
1099                content,
1100                tool_calls,
1101                reasoning_content,
1102                ..
1103            } => {
1104                assert_eq!(content, "I'll call the tool");
1105                assert_eq!(
1106                    reasoning_content.as_deref(),
1107                    Some("thinking about the problem")
1108                );
1109                assert_eq!(tool_calls.len(), 1);
1110                assert_eq!(tool_calls[0].function.name, "subtract");
1111            }
1112            _ => panic!("Expected assistant message"),
1113        }
1114    }
1115
1116    #[test]
1117    fn test_assistant_message_without_reasoning() {
1118        use crate::completion::message::{AssistantContent, Message as RigMessage};
1119
1120        let rig_msg = RigMessage::Assistant {
1121            id: None,
1122            content: OneOrMany::many(vec![
1123                AssistantContent::text("calling tool"),
1124                AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})),
1125            ])
1126            .expect("content should not be empty"),
1127        };
1128
1129        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1130
1131        assert_eq!(messages.len(), 1);
1132        match &messages[0] {
1133            Message::Assistant {
1134                reasoning_content,
1135                tool_calls,
1136                ..
1137            } => {
1138                assert!(reasoning_content.is_none());
1139                assert_eq!(tool_calls.len(), 1);
1140            }
1141            _ => panic!("Expected assistant message"),
1142        }
1143    }
1144
1145    #[test]
1146    fn test_client_initialization() {
1147        let _client =
1148            crate::providers::deepseek::Client::new("dummy-key").expect("Client::new() failed");
1149        let _client_from_builder = crate::providers::deepseek::Client::builder()
1150            .api_key("dummy-key")
1151            .build()
1152            .expect("Client::builder() failed");
1153    }
1154
1155    #[test]
1156    fn test_deserialize_list_models_response() {
1157        let data = r#"{
1158            "object": "list",
1159            "data": [
1160                {
1161                    "id": "deepseek-v4-flash",
1162                    "object": "model",
1163                    "owned_by": "deepseek"
1164                },
1165                {
1166                    "id": "deepseek-v4-pro",
1167                    "object": "model",
1168                    "owned_by": "deepseek"
1169                }
1170            ]
1171        }"#;
1172
1173        let response: ListModelsResponse = serde_json::from_str(data).unwrap();
1174
1175        assert_eq!(response.data.len(), 2);
1176        assert_eq!(response.data[0].id, "deepseek-v4-flash");
1177        assert_eq!(response.data[0].owned_by, "deepseek");
1178    }
1179
1180    #[tokio::test]
1181    async fn test_list_models_uses_models_endpoint() {
1182        let response_body = r#"{
1183            "object": "list",
1184            "data": [
1185                {
1186                    "id": "deepseek-v4-flash",
1187                    "object": "model",
1188                    "owned_by": "deepseek"
1189                },
1190                {
1191                    "id": "deepseek-v4-pro",
1192                    "object": "model",
1193                    "owned_by": "deepseek"
1194                }
1195            ]
1196        }"#;
1197
1198        let http_client = RecordingHttpClient::new(response_body);
1199        let client = Client::builder()
1200            .api_key("dummy-key")
1201            .http_client(http_client.clone())
1202            .build()
1203            .expect("client should build");
1204
1205        let models = client
1206            .list_models()
1207            .await
1208            .expect("list_models should succeed");
1209
1210        assert_eq!(models.len(), 2);
1211        assert_eq!(models.data[0].id, "deepseek-v4-flash");
1212        assert_eq!(models.data[0].r#type, None);
1213        assert_eq!(models.data[0].owned_by.as_deref(), Some("deepseek"));
1214        let requests = http_client.requests();
1215        assert_eq!(requests.len(), 1);
1216        assert_eq!(requests[0].uri, "https://api.deepseek.com/models");
1217    }
1218
1219    #[tokio::test]
1220    async fn test_list_models_preserves_api_error_context() {
1221        let http_client = RecordingHttpClient::with_error(
1222            http::StatusCode::UNAUTHORIZED,
1223            r#"{"error":{"message":"invalid api key"}}"#,
1224        );
1225        let client = Client::builder()
1226            .api_key("dummy-key")
1227            .http_client(http_client)
1228            .build()
1229            .expect("client should build");
1230
1231        let error = client
1232            .list_models()
1233            .await
1234            .expect_err("list_models should fail");
1235
1236        match error {
1237            ModelListingError::ApiError {
1238                status_code,
1239                message,
1240            } => {
1241                assert_eq!(status_code, 401);
1242                assert!(message.contains("provider=DeepSeek"));
1243                assert!(message.contains("path=/models"));
1244                assert!(message.contains("invalid api key"));
1245            }
1246            other => panic!("expected api error, got {other:?}"),
1247        }
1248    }
1249}