1use bytes::Bytes;
13use http::Request;
14use tracing::{Instrument, Level, enabled, info_span};
15
16use crate::client::{
17 self, BearerAuth, Capabilities, Capable, DebugExt, ModelLister, Nothing, Provider,
18 ProviderBuilder, ProviderClient,
19};
20use crate::completion::GetTokenUsage;
21use crate::http_client::{self, HttpClientExt};
22use crate::message::{Document, DocumentSourceKind};
23use crate::model::{Model, ModelList, ModelListingError};
24use crate::providers::internal::openai_chat_completions_compatible::{
25 self, CompatibleChoiceData, CompatibleChunk, CompatibleFinishReason, CompatibleStreamProfile,
26};
27use crate::{
28 OneOrMany,
29 completion::{self, CompletionError, CompletionRequest},
30 json_utils, message,
31 wasm_compat::{WasmCompatSend, WasmCompatSync},
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
/// Base URL for the hosted DeepSeek API.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";

/// Marker type implementing the rig provider extension for DeepSeek.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExt;
/// Builder marker for [`DeepSeekExt`]; carries no state of its own.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExtBuilder;

/// DeepSeek authenticates with a standard `Authorization: Bearer <key>` header.
type DeepSeekApiKey = BearerAuth;
48
impl Provider for DeepSeekExt {
    type Builder = DeepSeekExtBuilder;
    /// Endpoint used to verify an API key; the balance route requires valid auth.
    const VERIFY_PATH: &'static str = "/user/balance";
}
53
/// Capability matrix: DeepSeek supports chat completions and model listing;
/// embeddings, transcription, image and audio generation are not provided.
impl<H> Capabilities<H> for DeepSeekExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Capable<DeepSeekModelLister<H>>;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
64
// Default Debug-extension behavior is sufficient; DeepSeekExt carries no state.
impl DebugExt for DeepSeekExt {}
66
impl ProviderBuilder for DeepSeekExtBuilder {
    type Extension<H>
        = DeepSeekExt
    where
        H: HttpClientExt;
    type ApiKey = DeepSeekApiKey;

    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;

    /// Construct the (stateless) extension; never fails for DeepSeek.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(DeepSeekExt)
    }
}
85
/// Convenience alias: a rig client wired to the DeepSeek extension.
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
/// Builder alias for [`Client`]; the API key is supplied as a plain string.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
88
89impl ProviderClient for Client {
90 type Input = DeepSeekApiKey;
91 type Error = crate::client::ProviderClientError;
92
93 fn from_env() -> Result<Self, Self::Error> {
95 let api_key = crate::client::required_env_var("DEEPSEEK_API_KEY")?;
96 let mut client_builder = Self::builder();
97 client_builder.headers_mut().insert(
98 http::header::CONTENT_TYPE,
99 http::HeaderValue::from_static("application/json"),
100 );
101 let client_builder = client_builder.api_key(&api_key);
102 client_builder.build().map_err(Into::into)
103 }
104
105 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
106 Self::new(input).map_err(Into::into)
107 }
108}
109
/// Error payload returned by the DeepSeek API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// A DeepSeek reply is either the expected payload or an error body;
/// `untagged` makes serde try the success shape first, then the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
121
122impl From<ApiErrorResponse> for CompletionError {
123 fn from(err: ApiErrorResponse) -> Self {
124 CompletionError::ProviderError(err.message)
125 }
126}
127
/// Successful chat-completion payload returned by DeepSeek.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Candidate completions; conversion below consumes only the first entry.
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
136
/// Token accounting reported by DeepSeek, including prompt-cache statistics.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    pub completion_tokens: u32,
    pub prompt_tokens: u32,
    /// Prompt tokens served from DeepSeek's context cache.
    pub prompt_cache_hit_tokens: u32,
    /// Prompt tokens that missed the context cache.
    pub prompt_cache_miss_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
149
150impl GetTokenUsage for Usage {
151 fn token_usage(&self) -> Option<crate::completion::Usage> {
152 Some(crate::providers::internal::completion_usage(
153 self.prompt_tokens as u64,
154 self.completion_tokens as u64,
155 self.total_tokens as u64,
156 self.prompt_tokens_details
157 .as_ref()
158 .and_then(|details| details.cached_tokens)
159 .map(u64::from)
160 .unwrap_or(0),
161 ))
162 }
163}
164
/// Optional breakdown of completion tokens (e.g. reasoning tokens).
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}

/// Optional breakdown of prompt tokens (cache hits).
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
176
/// One candidate completion within a DeepSeek response.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    /// Log-probability payload; kept opaque since rig does not interpret it.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
184
/// DeepSeek chat message, tagged by the `role` field on the wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `null` and a missing field both deserialize to an empty Vec;
        // an empty Vec is omitted on serialization.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        /// Reasoning text from models that emit it (converted to/from
        /// `AssistantContent::Reasoning` elsewhere in this module).
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    /// Result of a prior tool call, serialized with role `"tool"`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
218
219impl Message {
220 pub fn system(content: &str) -> Self {
221 Message::System {
222 content: content.to_owned(),
223 name: None,
224 }
225 }
226}
227
impl From<message::ToolResult> for Message {
    /// Map a rig tool result onto DeepSeek's `tool` role message.
    ///
    /// Only the first content item is used; an image result is replaced with a
    /// `"[Image]"` placeholder because the wire message carries text only.
    fn from(tool_result: message::ToolResult) -> Self {
        let content = match tool_result.content.first() {
            message::ToolResultContent::Text(text) => text.text,
            message::ToolResultContent::Image(_) => String::from("[Image]"),
        };

        Message::ToolResult {
            tool_call_id: tool_result.id,
            content,
        }
    }
}
241
242impl From<message::ToolCall> for ToolCall {
243 fn from(tool_call: message::ToolCall) -> Self {
244 Self {
245 id: tool_call.id,
246 index: 0,
248 r#type: ToolType::Function,
249 function: Function {
250 name: tool_call.function.name,
251 arguments: tool_call.function.arguments,
252 },
253 }
254 }
255}
256
257impl TryFrom<message::Message> for Vec<Message> {
258 type Error = message::MessageError;
259
260 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
261 match message {
262 message::Message::System { content } => Ok(vec![Message::System {
263 content,
264 name: None,
265 }]),
266 message::Message::User { content } => {
267 let mut messages = vec![];
269
270 let tool_results = content
271 .clone()
272 .into_iter()
273 .filter_map(|content| match content {
274 message::UserContent::ToolResult(tool_result) => {
275 Some(Message::from(tool_result))
276 }
277 _ => None,
278 })
279 .collect::<Vec<_>>();
280
281 messages.extend(tool_results);
282
283 let text_content: String = content
284 .into_iter()
285 .filter_map(|content| match content {
286 message::UserContent::Text(text) => Some(text.text),
287 message::UserContent::Document(Document {
288 data:
289 DocumentSourceKind::Base64(content)
290 | DocumentSourceKind::String(content),
291 ..
292 }) => Some(content),
293 _ => None,
294 })
295 .collect::<Vec<_>>()
296 .join("\n");
297
298 if !text_content.is_empty() {
299 messages.push(Message::User {
300 content: text_content,
301 name: None,
302 });
303 }
304
305 Ok(messages)
306 }
307 message::Message::Assistant { content, .. } => {
308 let mut text_content = String::new();
309 let mut reasoning_content = String::new();
310 let mut tool_calls = Vec::new();
311
312 for item in content.iter() {
313 match item {
314 message::AssistantContent::Text(text) => {
315 text_content.push_str(text.text());
316 }
317 message::AssistantContent::Reasoning(reasoning) => {
318 reasoning_content.push_str(&reasoning.display_text());
319 }
320 message::AssistantContent::ToolCall(tool_call) => {
321 tool_calls.push(ToolCall::from(tool_call.clone()));
322 }
323 _ => {}
324 }
325 }
326
327 let reasoning = if reasoning_content.is_empty() {
328 None
329 } else {
330 Some(reasoning_content)
331 };
332
333 Ok(vec![Message::Assistant {
334 content: text_content,
335 name: None,
336 tool_calls,
337 reasoning_content: reasoning,
338 }])
339 }
340 }
341 }
342}
343
/// A tool invocation as serialized on the DeepSeek wire.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    /// Position of this call within the assistant message's tool-call list.
    pub index: usize,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

/// Function name plus arguments. Arguments travel as a JSON-encoded *string*
/// on the wire but are exposed as a structured `serde_json::Value` in Rust.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

/// Tool kind; only `"function"` exists on this API surface.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
366
/// Tool definition envelope: rig's definition wrapped with a `type` tag.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
372
373impl From<crate::completion::ToolDefinition> for ToolDefinition {
374 fn from(tool: crate::completion::ToolDefinition) -> Self {
375 Self {
376 r#type: "function".into(),
377 function: tool,
378 }
379 }
380}
381
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Lift a raw DeepSeek response into rig's generic completion response.
    ///
    /// Only the first choice is inspected. Assistant text, tool calls, and
    /// reasoning are collected — in that order — into the content list; an
    /// empty result is an error because rig requires at least one content item.
    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;
        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                reasoning_content,
                ..
            } => {
                // Whitespace-only text is dropped so a tool-call-only reply
                // doesn't produce an empty text item.
                let mut content = if content.trim().is_empty() {
                    vec![]
                } else {
                    vec![completion::AssistantContent::text(content)]
                };

                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );

                if let Some(reasoning_content) = reasoning_content {
                    content.push(completion::AssistantContent::reasoning(reasoning_content));
                }

                Ok(content)
            }
            // Any non-assistant first message is unexpected for a completion.
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        // OneOrMany::many rejects an empty vec — i.e. no text, no tool calls,
        // and no reasoning came back at all.
        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        let usage = completion::Usage {
            input_tokens: response.usage.prompt_tokens as u64,
            output_tokens: response.usage.completion_tokens as u64,
            total_tokens: response.usage.total_tokens as u64,
            cached_input_tokens: response
                .usage
                .prompt_tokens_details
                .as_ref()
                .and_then(|d| d.cached_tokens)
                .map(|c| c as u64)
                .unwrap_or(0),
            // DeepSeek does not report cache-creation tokens.
            cache_creation_input_tokens: 0,
        };

        Ok(completion::CompletionResponse {
            choice,
            usage,
            raw_response: response,
            message_id: None,
        })
    }
}
454
/// Request body sent to DeepSeek's `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct DeepseekCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    // DeepSeek's tool_choice shape matches OpenRouter's, so that type is reused.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    /// Extra provider parameters flattened into the top-level JSON object
    /// (e.g. `stream` / `stream_options` are injected here when streaming).
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
468
469impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
470 type Error = CompletionError;
471
472 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
473 if req.output_schema.is_some() {
474 tracing::warn!("Structured outputs currently not supported for DeepSeek");
475 }
476 let model = req.model.clone().unwrap_or_else(|| model.to_string());
477 let mut full_history: Vec<Message> = match &req.preamble {
478 Some(preamble) => vec![Message::system(preamble)],
479 None => vec![],
480 };
481
482 if let Some(docs) = req.normalized_documents() {
483 let docs: Vec<Message> = docs.try_into()?;
484 full_history.extend(docs);
485 }
486
487 let chat_history: Vec<Message> = req
488 .chat_history
489 .clone()
490 .into_iter()
491 .map(|message| message.try_into())
492 .collect::<Result<Vec<Vec<Message>>, _>>()?
493 .into_iter()
494 .flatten()
495 .collect();
496
497 full_history.extend(chat_history);
498
499 let tool_choice = req
500 .tool_choice
501 .clone()
502 .map(crate::providers::openrouter::ToolChoice::try_from)
503 .transpose()?;
504
505 Ok(Self {
506 model: model.to_string(),
507 messages: full_history,
508 temperature: req.temperature,
509 tools: req
510 .tools
511 .clone()
512 .into_iter()
513 .map(ToolDefinition::from)
514 .collect::<Vec<_>>(),
515 tool_choice,
516 additional_params: req.additional_params,
517 })
518 }
519}
520
/// DeepSeek chat-completion model handle: a client plus a model identifier.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub client: Client<T>,
    pub model: String,
}
527
528impl<T> completion::CompletionModel for CompletionModel<T>
529where
530 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
531{
532 type Response = CompletionResponse;
533 type StreamingResponse = StreamingCompletionResponse;
534
535 type Client = Client<T>;
536
537 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
538 Self {
539 client: client.clone(),
540 model: model.into().to_string(),
541 }
542 }
543
544 async fn completion(
545 &self,
546 completion_request: CompletionRequest,
547 ) -> Result<
548 completion::CompletionResponse<CompletionResponse>,
549 crate::completion::CompletionError,
550 > {
551 let span = if tracing::Span::current().is_disabled() {
552 info_span!(
553 target: "rig::completions",
554 "chat",
555 gen_ai.operation.name = "chat",
556 gen_ai.provider.name = "deepseek",
557 gen_ai.request.model = self.model,
558 gen_ai.system_instructions = tracing::field::Empty,
559 gen_ai.response.id = tracing::field::Empty,
560 gen_ai.response.model = tracing::field::Empty,
561 gen_ai.usage.output_tokens = tracing::field::Empty,
562 gen_ai.usage.input_tokens = tracing::field::Empty,
563 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
564 )
565 } else {
566 tracing::Span::current()
567 };
568
569 span.record("gen_ai.system_instructions", &completion_request.preamble);
570
571 let request =
572 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
573
574 if enabled!(Level::TRACE) {
575 tracing::trace!(target: "rig::completions",
576 "DeepSeek completion request: {}",
577 serde_json::to_string_pretty(&request)?
578 );
579 }
580
581 let body = serde_json::to_vec(&request)?;
582 let req = self
583 .client
584 .post("/chat/completions")?
585 .body(body)
586 .map_err(|e| CompletionError::HttpError(e.into()))?;
587
588 async move {
589 let response = self.client.send::<_, Bytes>(req).await?;
590 let status = response.status();
591 let response_body = response.into_body().into_future().await?.to_vec();
592
593 if status.is_success() {
594 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
595 ApiResponse::Ok(response) => {
596 let span = tracing::Span::current();
597 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
598 span.record(
599 "gen_ai.usage.output_tokens",
600 response.usage.completion_tokens,
601 );
602 span.record(
603 "gen_ai.usage.cache_read.input_tokens",
604 response
605 .usage
606 .prompt_tokens_details
607 .as_ref()
608 .and_then(|d| d.cached_tokens)
609 .unwrap_or(0),
610 );
611 if enabled!(Level::TRACE) {
612 tracing::trace!(target: "rig::completions",
613 "DeepSeek completion response: {}",
614 serde_json::to_string_pretty(&response)?
615 );
616 }
617 response.try_into()
618 }
619 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
620 }
621 } else {
622 Err(CompletionError::ProviderError(
623 String::from_utf8_lossy(&response_body).to_string(),
624 ))
625 }
626 }
627 .instrument(span)
628 .await
629 }
630
631 async fn stream(
632 &self,
633 completion_request: CompletionRequest,
634 ) -> Result<
635 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
636 CompletionError,
637 > {
638 let preamble = completion_request.preamble.clone();
639 let mut request =
640 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
641
642 let params = json_utils::merge(
643 request.additional_params.unwrap_or(serde_json::json!({})),
644 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
645 );
646
647 request.additional_params = Some(params);
648
649 if enabled!(Level::TRACE) {
650 tracing::trace!(target: "rig::completions",
651 "DeepSeek streaming completion request: {}",
652 serde_json::to_string_pretty(&request)?
653 );
654 }
655
656 let body = serde_json::to_vec(&request)?;
657
658 let req = self
659 .client
660 .post("/chat/completions")?
661 .body(body)
662 .map_err(|e| CompletionError::HttpError(e.into()))?;
663
664 let span = if tracing::Span::current().is_disabled() {
665 info_span!(
666 target: "rig::completions",
667 "chat_streaming",
668 gen_ai.operation.name = "chat_streaming",
669 gen_ai.provider.name = "deepseek",
670 gen_ai.request.model = self.model,
671 gen_ai.system_instructions = preamble,
672 gen_ai.response.id = tracing::field::Empty,
673 gen_ai.response.model = tracing::field::Empty,
674 gen_ai.usage.output_tokens = tracing::field::Empty,
675 gen_ai.usage.input_tokens = tracing::field::Empty,
676 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
677 )
678 } else {
679 tracing::Span::current()
680 };
681
682 tracing::Instrument::instrument(
683 send_compatible_streaming_request(self.client.clone(), req),
684 span,
685 )
686 .await
687 }
688}
689
/// Incremental message delta within a streaming chunk.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    // `null` or a missing field both deserialize to an empty Vec.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    reasoning_content: Option<String>,
}

/// One choice entry in a streaming chunk; only the delta is consumed.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// A single SSE payload from the DeepSeek streaming endpoint.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<StreamingChoice>,
    /// Present on the final chunk when `stream_options.include_usage` is set.
    usage: Option<Usage>,
}
711
/// Final aggregate produced once a DeepSeek stream completes.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
716
717impl GetTokenUsage for StreamingCompletionResponse {
718 fn token_usage(&self) -> Option<crate::completion::Usage> {
719 self.usage.token_usage()
720 }
721}
722
/// Stream-normalization profile plugging DeepSeek into the shared
/// OpenAI-chat-completions-compatible streaming machinery.
#[derive(Clone, Copy)]
struct DeepSeekCompatibleProfile;

impl CompatibleStreamProfile for DeepSeekCompatibleProfile {
    type Usage = Usage;
    type Detail = ();
    type FinalResponse = StreamingCompletionResponse;

    /// Parse one SSE `data:` payload into the shared chunk representation.
    ///
    /// Unparseable payloads are logged at DEBUG and skipped (`Ok(None)`)
    /// rather than aborting the stream.
    fn normalize_chunk(
        &self,
        data: &str,
    ) -> Result<Option<CompatibleChunk<Self::Usage, Self::Detail>>, CompletionError> {
        let data = match serde_json::from_str::<StreamingCompletionChunk>(data) {
            Ok(data) => data,
            Err(error) => {
                tracing::debug!(
                    "Couldn't parse SSE payload as StreamingCompletionChunk: {:?}",
                    error
                );
                return Ok(None);
            }
        };

        Ok(Some(
            openai_chat_completions_compatible::normalize_first_choice_chunk(
                data.id,
                data.model,
                data.usage,
                &data.choices,
                |choice| CompatibleChoiceData {
                    // DeepSeek finish reasons are not individually mapped here.
                    finish_reason: CompatibleFinishReason::Other,
                    text: choice.delta.content.clone(),
                    reasoning: choice.delta.reasoning_content.clone(),
                    tool_calls: openai_chat_completions_compatible::tool_call_chunks(
                        &choice.delta.tool_calls,
                    ),
                    details: Vec::new(),
                },
            ),
        ))
    }

    fn build_final_response(&self, usage: Self::Usage) -> Self::FinalResponse {
        StreamingCompletionResponse { usage }
    }

    // NOTE(review): these two flags configure how the shared machinery
    // accumulates tool-call fragments; their exact semantics are defined in
    // the openai_chat_completions_compatible module — confirm there before
    // changing either value.
    fn uses_distinct_tool_call_eviction(&self) -> bool {
        true
    }

    fn emits_complete_single_chunk_tool_calls(&self) -> bool {
        true
    }
}
777
/// Drive a DeepSeek SSE request through the shared OpenAI-compatible
/// streaming pipeline, using [`DeepSeekCompatibleProfile`] for chunk parsing.
pub async fn send_compatible_streaming_request<T>(
    http_client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    openai_chat_completions_compatible::send_compatible_streaming_request(
        http_client,
        req,
        DeepSeekCompatibleProfile,
    )
    .await
}
795
/// Wire shape of DeepSeek's `GET /models` response.
#[derive(Debug, Deserialize)]
struct ListModelsResponse {
    data: Vec<ListModelEntry>,
}

/// A single model entry from the listing endpoint.
#[derive(Debug, Deserialize)]
struct ListModelEntry {
    id: String,
    owned_by: String,
}
806
807impl From<ListModelEntry> for Model {
808 fn from(value: ListModelEntry) -> Self {
809 let mut model = Model::from_id(value.id);
810 model.owned_by = Some(value.owned_by);
811 model
812 }
813}
814
/// Lists the models available from the DeepSeek API.
#[derive(Clone)]
pub struct DeepSeekModelLister<H = reqwest::Client> {
    client: Client<H>,
}
820
impl<H> ModelLister<H> for DeepSeekModelLister<H>
where
    H: HttpClientExt + WasmCompatSend + WasmCompatSync + 'static,
{
    type Client = Client<H>;

    fn new(client: Self::Client) -> Self {
        Self { client }
    }

    /// Fetch `/models` and map every entry into a rig [`Model`].
    ///
    /// Three failure layers are distinguished: transport errors that already
    /// carry a status/message, non-2xx responses (body attached as context),
    /// and JSON parse failures (body attached for diagnosis).
    async fn list_all(&self) -> Result<ModelList, ModelListingError> {
        let path = "/models";
        let req = self.client.get(path)?.body(http_client::NoBody)?;
        let response = self
            .client
            .send::<_, Vec<u8>>(req)
            .await
            .map_err(|error| match error {
                // The HTTP layer may reject bad statuses before a response is
                // surfaced; keep its message but add provider/path context.
                http_client::Error::InvalidStatusCodeWithMessage(status, message) => {
                    ModelListingError::api_error_with_context(
                        "DeepSeek",
                        path,
                        status.as_u16(),
                        message.as_bytes(),
                    )
                }
                other => ModelListingError::from(other),
            })?;

        if !response.status().is_success() {
            let status_code = response.status().as_u16();
            let body = response.into_body().await?;
            return Err(ModelListingError::api_error_with_context(
                "DeepSeek",
                path,
                status_code,
                &body,
            ));
        }

        let body = response.into_body().await?;
        let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
            ModelListingError::parse_error_with_context("DeepSeek", path, &error, &body)
        })?;

        let models = api_resp.data.into_iter().map(Model::from).collect();

        Ok(ModelList::new(models))
    }
}
871
// Model name constants. The legacy aliases remain exported for backwards
// compatibility but are slated for removal upstream.
#[deprecated(
    note = "The model names `deepseek-chat` and `deepseek-reasoner` will be deprecated on 2026/07/24. \
    For compatibility, they correspond to the non-thinking mode and thinking mode of `deepseek-v4-flash`, \
    respectively."
)]
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
#[deprecated(
    note = "The model names `deepseek-chat` and `deepseek-reasoner` will be deprecated on 2026/07/24. \
    For compatibility, they correspond to the non-thinking mode and thinking mode of `deepseek-v4-flash`, \
    respectively."
)]
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
pub const DEEPSEEK_V4_FLASH: &str = "deepseek-v4-flash";
pub const DEEPSEEK_V4_PRO: &str = "deepseek-v4-pro";
889
890#[cfg(test)]
892mod tests {
893 use super::*;
894 use crate::client::ModelListingClient;
895 use crate::http_client::{LazyBody, MultipartForm, Request as HttpRequest, Response};
896 use bytes::Bytes;
897 use std::future::{self, Future};
898 use std::sync::{Arc, Mutex};
899
900 #[test]
901 fn test_deserialize_vec_choice() {
902 let data = r#"[{
903 "finish_reason": "stop",
904 "index": 0,
905 "logprobs": null,
906 "message":{"role":"assistant","content":"Hello, world!"}
907 }]"#;
908
909 let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
910 assert_eq!(choices.len(), 1);
911 match &choices.first().unwrap().message {
912 Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
913 _ => panic!("Expected assistant message"),
914 }
915 }
916
917 #[test]
918 fn test_deserialize_deepseek_response() {
919 let data = r#"{
920 "choices":[{
921 "finish_reason": "stop",
922 "index": 0,
923 "logprobs": null,
924 "message":{"role":"assistant","content":"Hello, world!"}
925 }],
926 "usage": {
927 "completion_tokens": 0,
928 "prompt_tokens": 0,
929 "prompt_cache_hit_tokens": 0,
930 "prompt_cache_miss_tokens": 0,
931 "total_tokens": 0
932 }
933 }"#;
934
935 let jd = &mut serde_json::Deserializer::from_str(data);
936 let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
937 match result {
938 Ok(response) => match &response.choices.first().unwrap().message {
939 Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
940 _ => panic!("Expected assistant message"),
941 },
942 Err(err) => {
943 panic!("Deserialization error at {}: {}", err.path(), err);
944 }
945 }
946 }
947
948 #[test]
949 fn test_deserialize_example_response() {
950 let data = r#"
951 {
952 "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
953 "object": "chat.completion",
954 "created": 0,
955 "model": "deepseek-chat",
956 "choices": [
957 {
958 "index": 0,
959 "message": {
960 "role": "assistant",
961 "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
962 },
963 "logprobs": null,
964 "finish_reason": "stop"
965 }
966 ],
967 "usage": {
968 "prompt_tokens": 13,
969 "completion_tokens": 32,
970 "total_tokens": 45,
971 "prompt_tokens_details": {
972 "cached_tokens": 0
973 },
974 "prompt_cache_hit_tokens": 0,
975 "prompt_cache_miss_tokens": 13
976 },
977 "system_fingerprint": "fp_4b6881f2c5"
978 }
979 "#;
980 let jd = &mut serde_json::Deserializer::from_str(data);
981 let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
982
983 match result {
984 Ok(response) => match &response.choices.first().unwrap().message {
985 Message::Assistant { content, .. } => assert_eq!(
986 content,
987 "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
988 ),
989 _ => panic!("Expected assistant message"),
990 },
991 Err(err) => {
992 panic!("Deserialization error at {}: {}", err.path(), err);
993 }
994 }
995 }
996
997 #[test]
998 fn test_serialize_deserialize_tool_call_message() {
999 let tool_call_choice_json = r#"
1000 {
1001 "finish_reason": "tool_calls",
1002 "index": 0,
1003 "logprobs": null,
1004 "message": {
1005 "content": "",
1006 "role": "assistant",
1007 "tool_calls": [
1008 {
1009 "function": {
1010 "arguments": "{\"x\":2,\"y\":5}",
1011 "name": "subtract"
1012 },
1013 "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
1014 "index": 0,
1015 "type": "function"
1016 }
1017 ]
1018 }
1019 }
1020 "#;
1021
1022 let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
1023
1024 let expected_choice: Choice = Choice {
1025 finish_reason: "tool_calls".to_string(),
1026 index: 0,
1027 logprobs: None,
1028 message: Message::Assistant {
1029 content: "".to_string(),
1030 name: None,
1031 tool_calls: vec![ToolCall {
1032 id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
1033 function: Function {
1034 name: "subtract".to_string(),
1035 arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
1036 },
1037 index: 0,
1038 r#type: ToolType::Function,
1039 }],
1040 reasoning_content: None,
1041 },
1042 };
1043
1044 assert_eq!(choice, expected_choice);
1045 }
1046 #[test]
1047 fn test_user_message_multiple_text_items_merged() {
1048 use crate::completion::message::{Message as RigMessage, UserContent};
1049
1050 let rig_msg = RigMessage::User {
1051 content: OneOrMany::many(vec![
1052 UserContent::text("first part"),
1053 UserContent::text("second part"),
1054 ])
1055 .expect("content should not be empty"),
1056 };
1057
1058 let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1059
1060 let user_messages: Vec<&Message> = messages
1061 .iter()
1062 .filter(|m| matches!(m, Message::User { .. }))
1063 .collect();
1064
1065 assert_eq!(
1066 user_messages.len(),
1067 1,
1068 "multiple text items should produce a single user message"
1069 );
1070 match &user_messages[0] {
1071 Message::User { content, .. } => {
1072 assert_eq!(content, "first part\nsecond part");
1073 }
1074 _ => unreachable!(),
1075 }
1076 }
1077
1078 #[test]
1079 fn test_assistant_message_with_reasoning_and_tool_calls() {
1080 use crate::completion::message::{AssistantContent, Message as RigMessage};
1081
1082 let rig_msg = RigMessage::Assistant {
1083 id: None,
1084 content: OneOrMany::many(vec![
1085 AssistantContent::reasoning("thinking about the problem"),
1086 AssistantContent::text("I'll call the tool"),
1087 AssistantContent::tool_call(
1088 "call_1",
1089 "subtract",
1090 serde_json::json!({"x": 2, "y": 5}),
1091 ),
1092 ])
1093 .expect("content should not be empty"),
1094 };
1095
1096 let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1097
1098 assert_eq!(messages.len(), 1, "should produce exactly one message");
1099 match &messages[0] {
1100 Message::Assistant {
1101 content,
1102 tool_calls,
1103 reasoning_content,
1104 ..
1105 } => {
1106 assert_eq!(content, "I'll call the tool");
1107 assert_eq!(
1108 reasoning_content.as_deref(),
1109 Some("thinking about the problem")
1110 );
1111 assert_eq!(tool_calls.len(), 1);
1112 assert_eq!(tool_calls[0].function.name, "subtract");
1113 }
1114 _ => panic!("Expected assistant message"),
1115 }
1116 }
1117
1118 #[test]
1119 fn test_assistant_message_without_reasoning() {
1120 use crate::completion::message::{AssistantContent, Message as RigMessage};
1121
1122 let rig_msg = RigMessage::Assistant {
1123 id: None,
1124 content: OneOrMany::many(vec![
1125 AssistantContent::text("calling tool"),
1126 AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})),
1127 ])
1128 .expect("content should not be empty"),
1129 };
1130
1131 let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
1132
1133 assert_eq!(messages.len(), 1);
1134 match &messages[0] {
1135 Message::Assistant {
1136 reasoning_content,
1137 tool_calls,
1138 ..
1139 } => {
1140 assert!(reasoning_content.is_none());
1141 assert_eq!(tool_calls.len(), 1);
1142 }
1143 _ => panic!("Expected assistant message"),
1144 }
1145 }
1146
1147 #[test]
1148 fn test_client_initialization() {
1149 let _client =
1150 crate::providers::deepseek::Client::new("dummy-key").expect("Client::new() failed");
1151 let _client_from_builder = crate::providers::deepseek::Client::builder()
1152 .api_key("dummy-key")
1153 .build()
1154 .expect("Client::builder() failed");
1155 }
1156
1157 #[test]
1158 fn test_deserialize_list_models_response() {
1159 let data = r#"{
1160 "object": "list",
1161 "data": [
1162 {
1163 "id": "deepseek-v4-flash",
1164 "object": "model",
1165 "owned_by": "deepseek"
1166 },
1167 {
1168 "id": "deepseek-v4-pro",
1169 "object": "model",
1170 "owned_by": "deepseek"
1171 }
1172 ]
1173 }"#;
1174
1175 let response: ListModelsResponse = serde_json::from_str(data).unwrap();
1176
1177 assert_eq!(response.data.len(), 2);
1178 assert_eq!(response.data[0].id, "deepseek-v4-flash");
1179 assert_eq!(response.data[0].owned_by, "deepseek");
1180 }
1181
1182 #[derive(Debug, Clone, PartialEq, Eq)]
1183 struct CapturedRequest {
1184 uri: String,
1185 }
1186
    /// Canned reply that the mock HTTP client returns for every request.
    #[derive(Clone)]
    enum MockResponse {
        /// Answer with HTTP 200 and the given body bytes.
        Success(Bytes),
        /// Fail the request with the given status code and error message.
        Error(http::StatusCode, String),
    }
1192
1193 impl Default for MockResponse {
1194 fn default() -> Self {
1195 Self::Success(Bytes::new())
1196 }
1197 }
1198
    /// Test double for `HttpClientExt` that records every request URI and
    /// replies with a configurable canned response.
    #[derive(Clone, Default)]
    struct RecordingHttpClient {
        // URIs of all requests seen so far; shared across clones of the client.
        requests: Arc<Mutex<Vec<CapturedRequest>>>,
        // The canned response handed back to every `send` call.
        response: Arc<Mutex<MockResponse>>,
    }
1204
1205 impl RecordingHttpClient {
1206 fn new(response_body: impl Into<Bytes>) -> Self {
1207 Self {
1208 requests: Arc::new(Mutex::new(Vec::new())),
1209 response: Arc::new(Mutex::new(MockResponse::Success(response_body.into()))),
1210 }
1211 }
1212
1213 fn with_error(status: http::StatusCode, message: impl Into<String>) -> Self {
1214 Self {
1215 requests: Arc::new(Mutex::new(Vec::new())),
1216 response: Arc::new(Mutex::new(MockResponse::Error(status, message.into()))),
1217 }
1218 }
1219
1220 fn requests(&self) -> Vec<CapturedRequest> {
1221 self.requests.lock().expect("requests lock").clone()
1222 }
1223 }
1224
1225 impl HttpClientExt for RecordingHttpClient {
1226 fn send<T, U>(
1227 &self,
1228 req: HttpRequest<T>,
1229 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
1230 where
1231 T: Into<Bytes> + WasmCompatSend,
1232 U: From<Bytes> + WasmCompatSend + 'static,
1233 {
1234 let requests = Arc::clone(&self.requests);
1235 let response = self.response.lock().expect("response lock").clone();
1236 let (parts, _body) = req.into_parts();
1237
1238 requests
1239 .lock()
1240 .expect("requests lock")
1241 .push(CapturedRequest {
1242 uri: parts.uri.to_string(),
1243 });
1244
1245 async move {
1246 let response_body = match response {
1247 MockResponse::Success(response_body) => response_body,
1248 MockResponse::Error(status, message) => {
1249 return Err(http_client::Error::InvalidStatusCodeWithMessage(
1250 status, message,
1251 ));
1252 }
1253 };
1254 let body: LazyBody<U> = Box::pin(async move { Ok(U::from(response_body)) });
1255 Response::builder()
1256 .status(http::StatusCode::OK)
1257 .body(body)
1258 .map_err(http_client::Error::Protocol)
1259 }
1260 }
1261
1262 fn send_multipart<U>(
1263 &self,
1264 _req: HttpRequest<MultipartForm>,
1265 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
1266 where
1267 U: From<Bytes> + WasmCompatSend + 'static,
1268 {
1269 future::ready(Err(http_client::Error::InvalidStatusCode(
1270 http::StatusCode::NOT_IMPLEMENTED,
1271 )))
1272 }
1273
1274 fn send_streaming<T>(
1275 &self,
1276 _req: HttpRequest<T>,
1277 ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
1278 where
1279 T: Into<Bytes> + WasmCompatSend,
1280 {
1281 future::ready(Err(http_client::Error::InvalidStatusCode(
1282 http::StatusCode::NOT_IMPLEMENTED,
1283 )))
1284 }
1285 }
1286
1287 #[tokio::test]
1288 async fn test_list_models_uses_models_endpoint() {
1289 let response_body = r#"{
1290 "object": "list",
1291 "data": [
1292 {
1293 "id": "deepseek-v4-flash",
1294 "object": "model",
1295 "owned_by": "deepseek"
1296 },
1297 {
1298 "id": "deepseek-v4-pro",
1299 "object": "model",
1300 "owned_by": "deepseek"
1301 }
1302 ]
1303 }"#;
1304
1305 let http_client = RecordingHttpClient::new(response_body);
1306 let client = Client::builder()
1307 .api_key("dummy-key")
1308 .http_client(http_client.clone())
1309 .build()
1310 .expect("client should build");
1311
1312 let models = client
1313 .list_models()
1314 .await
1315 .expect("list_models should succeed");
1316
1317 assert_eq!(models.len(), 2);
1318 assert_eq!(models.data[0].id, "deepseek-v4-flash");
1319 assert_eq!(models.data[0].r#type, None);
1320 assert_eq!(models.data[0].owned_by.as_deref(), Some("deepseek"));
1321 assert_eq!(
1322 http_client.requests(),
1323 vec![CapturedRequest {
1324 uri: "https://api.deepseek.com/models".to_string()
1325 }]
1326 );
1327 }
1328
1329 #[tokio::test]
1330 async fn test_list_models_preserves_api_error_context() {
1331 let http_client = RecordingHttpClient::with_error(
1332 http::StatusCode::UNAUTHORIZED,
1333 r#"{"error":{"message":"invalid api key"}}"#,
1334 );
1335 let client = Client::builder()
1336 .api_key("dummy-key")
1337 .http_client(http_client)
1338 .build()
1339 .expect("client should build");
1340
1341 let error = client
1342 .list_models()
1343 .await
1344 .expect_err("list_models should fail");
1345
1346 match error {
1347 ModelListingError::ApiError {
1348 status_code,
1349 message,
1350 } => {
1351 assert_eq!(status_code, 401);
1352 assert!(message.contains("provider=DeepSeek"));
1353 assert!(message.contains("path=/models"));
1354 assert!(message.contains("invalid api key"));
1355 }
1356 other => panic!("expected api error, got {other:?}"),
1357 }
1358 }
1359}