1use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
/// Base URL of the DeepSeek HTTP API; used as `ProviderBuilder::BASE_URL` for this provider.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";

/// Zero-sized provider marker that plugs DeepSeek into the generic `client::Client`.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExt;
/// Zero-sized builder marker that produces a [`DeepSeekExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExtBuilder;

/// API-key representation for DeepSeek (a `BearerAuth` credential).
type DeepSeekApiKey = BearerAuth;
48
impl Provider for DeepSeekExt {
    type Builder = DeepSeekExtBuilder;
    // NOTE(review): presumably the endpoint the generic client calls to verify credentials
    // (DeepSeek's account-balance route) — confirm against `client::Provider`.
    const VERIFY_PATH: &'static str = "/user/balance";
}

// Capability table: only chat completion is implemented for DeepSeek; every other capability
// (including the feature-gated image/audio generation) is marked `Nothing`.
impl<H> Capabilities<H> for DeepSeekExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}

// Uses the trait's default debug behaviour; no provider-specific debug output.
impl DebugExt for DeepSeekExt {}
66
67impl ProviderBuilder for DeepSeekExtBuilder {
68 type Extension<H>
69 = DeepSeekExt
70 where
71 H: HttpClientExt;
72 type ApiKey = DeepSeekApiKey;
73
74 const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
75
76 fn build<H>(
77 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
78 ) -> http_client::Result<Self::Extension<H>>
79 where
80 H: HttpClientExt,
81 {
82 Ok(DeepSeekExt)
83 }
84}
85
/// DeepSeek client; defaults to `reqwest::Client` as the HTTP transport.
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
/// Builder for [`Client`]; defaults to `reqwest::Client` as the HTTP transport.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
88
89impl ProviderClient for Client {
90 type Input = DeepSeekApiKey;
91
92 fn from_env() -> Self {
94 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
95 let mut client_builder = Self::builder();
96 client_builder.headers_mut().insert(
97 http::header::CONTENT_TYPE,
98 http::HeaderValue::from_static("application/json"),
99 );
100 let client_builder = client_builder.api_key(&api_key);
101 client_builder.build().unwrap()
102 }
103
104 fn from_val(input: Self::Input) -> Self {
105 Self::new(input).unwrap()
106 }
107}
108
/// Error payload returned by DeepSeek; only the top-level `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Response body that is either the expected payload `T` or an [`ApiErrorResponse`].
///
/// `untagged`: serde tries `Ok(T)` first and falls back to `Err` only if `T` does not match.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
120
121impl From<ApiErrorResponse> for CompletionError {
122 fn from(err: ApiErrorResponse) -> Self {
123 CompletionError::ProviderError(err.message)
124 }
125}
126
/// Non-streaming `/chat/completions` response body: the generated choices plus token usage.
///
/// Other top-level fields in the JSON (e.g. `id`, `model`) are ignored on deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
135
136#[derive(Clone, Debug, Serialize, Deserialize, Default)]
137pub struct Usage {
138 pub completion_tokens: u32,
139 pub prompt_tokens: u32,
140 pub prompt_cache_hit_tokens: u32,
141 pub prompt_cache_miss_tokens: u32,
142 pub total_tokens: u32,
143 #[serde(skip_serializing_if = "Option::is_none")]
144 pub completion_tokens_details: Option<CompletionTokensDetails>,
145 #[serde(skip_serializing_if = "Option::is_none")]
146 pub prompt_tokens_details: Option<PromptTokensDetails>,
147}
148
149impl Usage {
150 fn new() -> Self {
151 Self {
152 completion_tokens: 0,
153 prompt_tokens: 0,
154 prompt_cache_hit_tokens: 0,
155 prompt_cache_miss_tokens: 0,
156 total_tokens: 0,
157 completion_tokens_details: None,
158 prompt_tokens_details: None,
159 }
160 }
161}
162
/// Optional breakdown of completion tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Tokens spent on reasoning, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}

/// Optional breakdown of prompt tokens (OpenAI-compatible `prompt_tokens_details` block).
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// Prompt tokens served from cache, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
174
/// One entry of `choices` in a non-streaming response.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // Kept as raw JSON; this module never inspects it.
    pub logprobs: Option<serde_json::Value>,
    // e.g. "stop" or "tool_calls" (values seen in this module's tests).
    pub finish_reason: String,
}
182
/// A chat message in DeepSeek's wire format.
///
/// Internally tagged on `role`: variants serialize as `"system"`, `"user"`, `"assistant"`
/// (lowercased variant names) and `"tool"` (explicit rename of `ToolResult`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Absent defaults to empty; `null_or_vec` presumably also maps an explicit `null`
        // to an empty vec (confirm in `json_utils`). Omitted from output when empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        // Reasoning text attached to an assistant turn; omitted from output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    /// Result of a tool invocation, linked to the call via `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
216
217impl Message {
218 pub fn system(content: &str) -> Self {
219 Message::System {
220 content: content.to_owned(),
221 name: None,
222 }
223 }
224}
225
226impl From<message::ToolResult> for Message {
227 fn from(tool_result: message::ToolResult) -> Self {
228 let content = match tool_result.content.first() {
229 message::ToolResultContent::Text(text) => text.text,
230 message::ToolResultContent::Image(_) => String::from("[Image]"),
231 };
232
233 Message::ToolResult {
234 tool_call_id: tool_result.id,
235 content,
236 }
237 }
238}
239
240impl From<message::ToolCall> for ToolCall {
241 fn from(tool_call: message::ToolCall) -> Self {
242 Self {
243 id: tool_call.id,
244 index: 0,
246 r#type: ToolType::Function,
247 function: Function {
248 name: tool_call.function.name,
249 arguments: tool_call.function.arguments,
250 },
251 }
252 }
253}
254
255impl TryFrom<message::Message> for Vec<Message> {
256 type Error = message::MessageError;
257
258 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
259 match message {
260 message::Message::User { content } => {
261 let mut messages = vec![];
263
264 let tool_results = content
265 .clone()
266 .into_iter()
267 .filter_map(|content| match content {
268 message::UserContent::ToolResult(tool_result) => {
269 Some(Message::from(tool_result))
270 }
271 _ => None,
272 })
273 .collect::<Vec<_>>();
274
275 messages.extend(tool_results);
276
277 let text_content: String = content
278 .into_iter()
279 .filter_map(|content| match content {
280 message::UserContent::Text(text) => Some(text.text),
281 message::UserContent::Document(Document {
282 data:
283 DocumentSourceKind::Base64(content)
284 | DocumentSourceKind::String(content),
285 ..
286 }) => Some(content),
287 _ => None,
288 })
289 .collect::<Vec<_>>()
290 .join("\n");
291
292 if !text_content.is_empty() {
293 messages.push(Message::User {
294 content: text_content,
295 name: None,
296 });
297 }
298
299 Ok(messages)
300 }
301 message::Message::Assistant { content, .. } => {
302 let mut text_content = String::new();
303 let mut reasoning_content = String::new();
304 let mut tool_calls = Vec::new();
305
306 for item in content.iter() {
307 match item {
308 message::AssistantContent::Text(text) => {
309 text_content.push_str(text.text());
310 }
311 message::AssistantContent::Reasoning(reasoning) => {
312 reasoning_content.push_str(&reasoning.display_text());
313 }
314 message::AssistantContent::ToolCall(tool_call) => {
315 tool_calls.push(ToolCall::from(tool_call.clone()));
316 }
317 _ => {}
318 }
319 }
320
321 let reasoning = if reasoning_content.is_empty() {
322 None
323 } else {
324 Some(reasoning_content)
325 };
326
327 Ok(vec![Message::Assistant {
328 content: text_content,
329 name: None,
330 tool_calls,
331 reasoning_content: reasoning,
332 }])
333 }
334 }
335 }
336}
337
/// A tool call as it appears on an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    pub index: usize,
    // Missing `type` deserializes as `ToolType::Function`.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // On the wire the arguments are a JSON-encoded *string*; here they are held as parsed JSON.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

/// Tool kind; `function` is the only variant (and the default).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}

/// Tool definition sent in a request's `tools` array.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
366
367impl From<crate::completion::ToolDefinition> for ToolDefinition {
368 fn from(tool: crate::completion::ToolDefinition) -> Self {
369 Self {
370 r#type: "function".into(),
371 function: tool,
372 }
373 }
374}
375
376impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
377 type Error = CompletionError;
378
379 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
380 let choice = response.choices.first().ok_or_else(|| {
381 CompletionError::ResponseError("Response contained no choices".to_owned())
382 })?;
383 let content = match &choice.message {
384 Message::Assistant {
385 content,
386 tool_calls,
387 reasoning_content,
388 ..
389 } => {
390 let mut content = if content.trim().is_empty() {
391 vec![]
392 } else {
393 vec![completion::AssistantContent::text(content)]
394 };
395
396 content.extend(
397 tool_calls
398 .iter()
399 .map(|call| {
400 completion::AssistantContent::tool_call(
401 &call.id,
402 &call.function.name,
403 call.function.arguments.clone(),
404 )
405 })
406 .collect::<Vec<_>>(),
407 );
408
409 if let Some(reasoning_content) = reasoning_content {
410 content.push(completion::AssistantContent::reasoning(reasoning_content));
411 }
412
413 Ok(content)
414 }
415 _ => Err(CompletionError::ResponseError(
416 "Response did not contain a valid message or tool call".into(),
417 )),
418 }?;
419
420 let choice = OneOrMany::many(content).map_err(|_| {
421 CompletionError::ResponseError(
422 "Response contained no message or tool call (empty)".to_owned(),
423 )
424 })?;
425
426 let usage = completion::Usage {
427 input_tokens: response.usage.prompt_tokens as u64,
428 output_tokens: response.usage.completion_tokens as u64,
429 total_tokens: response.usage.total_tokens as u64,
430 cached_input_tokens: response
431 .usage
432 .prompt_tokens_details
433 .as_ref()
434 .and_then(|d| d.cached_tokens)
435 .map(|c| c as u64)
436 .unwrap_or(0),
437 };
438
439 Ok(completion::CompletionResponse {
440 choice,
441 usage,
442 raw_response: response,
443 message_id: None,
444 })
445 }
446}
447
/// Request body for `POST /chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct DeepseekCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    // Reuses the OpenRouter provider's tool-choice representation.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    // Flattened into the top-level JSON object, so arbitrary extra keys (e.g. `stream`) land
    // beside the fields above.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
461
462impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
463 type Error = CompletionError;
464
465 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
466 if req.output_schema.is_some() {
467 tracing::warn!("Structured outputs currently not supported for DeepSeek");
468 }
469 let model = req.model.clone().unwrap_or_else(|| model.to_string());
470 let mut full_history: Vec<Message> = match &req.preamble {
471 Some(preamble) => vec![Message::system(preamble)],
472 None => vec![],
473 };
474
475 if let Some(docs) = req.normalized_documents() {
476 let docs: Vec<Message> = docs.try_into()?;
477 full_history.extend(docs);
478 }
479
480 let chat_history: Vec<Message> = req
481 .chat_history
482 .clone()
483 .into_iter()
484 .map(|message| message.try_into())
485 .collect::<Result<Vec<Vec<Message>>, _>>()?
486 .into_iter()
487 .flatten()
488 .collect();
489
490 full_history.extend(chat_history);
491
492 let tool_choice = req
493 .tool_choice
494 .clone()
495 .map(crate::providers::openrouter::ToolChoice::try_from)
496 .transpose()?;
497
498 Ok(Self {
499 model: model.to_string(),
500 messages: full_history,
501 temperature: req.temperature,
502 tools: req
503 .tools
504 .clone()
505 .into_iter()
506 .map(ToolDefinition::from)
507 .collect::<Vec<_>>(),
508 tool_choice,
509 additional_params: req.additional_params,
510 })
511 }
512}
513
/// DeepSeek chat-completion model handle: a client plus the model id to request.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub client: Client<T>,
    pub model: String,
}
520
521impl<T> completion::CompletionModel for CompletionModel<T>
522where
523 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
524{
525 type Response = CompletionResponse;
526 type StreamingResponse = StreamingCompletionResponse;
527
528 type Client = Client<T>;
529
530 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
531 Self {
532 client: client.clone(),
533 model: model.into().to_string(),
534 }
535 }
536
537 async fn completion(
538 &self,
539 completion_request: CompletionRequest,
540 ) -> Result<
541 completion::CompletionResponse<CompletionResponse>,
542 crate::completion::CompletionError,
543 > {
544 let span = if tracing::Span::current().is_disabled() {
545 info_span!(
546 target: "rig::completions",
547 "chat",
548 gen_ai.operation.name = "chat",
549 gen_ai.provider.name = "deepseek",
550 gen_ai.request.model = self.model,
551 gen_ai.system_instructions = tracing::field::Empty,
552 gen_ai.response.id = tracing::field::Empty,
553 gen_ai.response.model = tracing::field::Empty,
554 gen_ai.usage.output_tokens = tracing::field::Empty,
555 gen_ai.usage.input_tokens = tracing::field::Empty,
556 )
557 } else {
558 tracing::Span::current()
559 };
560
561 span.record("gen_ai.system_instructions", &completion_request.preamble);
562
563 let request =
564 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
565
566 if enabled!(Level::TRACE) {
567 tracing::trace!(target: "rig::completions",
568 "DeepSeek completion request: {}",
569 serde_json::to_string_pretty(&request)?
570 );
571 }
572
573 let body = serde_json::to_vec(&request)?;
574 let req = self
575 .client
576 .post("/chat/completions")?
577 .body(body)
578 .map_err(|e| CompletionError::HttpError(e.into()))?;
579
580 async move {
581 let response = self.client.send::<_, Bytes>(req).await?;
582 let status = response.status();
583 let response_body = response.into_body().into_future().await?.to_vec();
584
585 if status.is_success() {
586 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
587 ApiResponse::Ok(response) => {
588 let span = tracing::Span::current();
589 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
590 span.record(
591 "gen_ai.usage.output_tokens",
592 response.usage.completion_tokens,
593 );
594 if enabled!(Level::TRACE) {
595 tracing::trace!(target: "rig::completions",
596 "DeepSeek completion response: {}",
597 serde_json::to_string_pretty(&response)?
598 );
599 }
600 response.try_into()
601 }
602 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
603 }
604 } else {
605 Err(CompletionError::ProviderError(
606 String::from_utf8_lossy(&response_body).to_string(),
607 ))
608 }
609 }
610 .instrument(span)
611 .await
612 }
613
614 async fn stream(
615 &self,
616 completion_request: CompletionRequest,
617 ) -> Result<
618 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
619 CompletionError,
620 > {
621 let preamble = completion_request.preamble.clone();
622 let mut request =
623 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
624
625 let params = json_utils::merge(
626 request.additional_params.unwrap_or(serde_json::json!({})),
627 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
628 );
629
630 request.additional_params = Some(params);
631
632 if enabled!(Level::TRACE) {
633 tracing::trace!(target: "rig::completions",
634 "DeepSeek streaming completion request: {}",
635 serde_json::to_string_pretty(&request)?
636 );
637 }
638
639 let body = serde_json::to_vec(&request)?;
640
641 let req = self
642 .client
643 .post("/chat/completions")?
644 .body(body)
645 .map_err(|e| CompletionError::HttpError(e.into()))?;
646
647 let span = if tracing::Span::current().is_disabled() {
648 info_span!(
649 target: "rig::completions",
650 "chat_streaming",
651 gen_ai.operation.name = "chat_streaming",
652 gen_ai.provider.name = "deepseek",
653 gen_ai.request.model = self.model,
654 gen_ai.system_instructions = preamble,
655 gen_ai.response.id = tracing::field::Empty,
656 gen_ai.response.model = tracing::field::Empty,
657 gen_ai.usage.output_tokens = tracing::field::Empty,
658 gen_ai.usage.input_tokens = tracing::field::Empty,
659 )
660 } else {
661 tracing::Span::current()
662 };
663
664 tracing::Instrument::instrument(
665 send_compatible_streaming_request(self.client.clone(), req),
666 span,
667 )
668 .await
669 }
670}
671
/// The `delta` object of one streamed choice.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    // Absent defaults to empty; `null_or_vec` presumably also maps `null` to empty (confirm).
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    // Incremental reasoning text; a missing key deserializes as `None`.
    reasoning_content: Option<String>,
}

/// One entry of `choices` in a streamed chunk; only the `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One SSE `data:` payload of a streaming response.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    // `None` when the key is absent or `null`; the stream keeps the last `Some` it sees.
    usage: Option<Usage>,
}

/// Final payload of a DeepSeek stream: the usage reported by the server.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
696
697impl GetTokenUsage for StreamingCompletionResponse {
698 fn token_usage(&self) -> Option<crate::completion::Usage> {
699 let mut usage = crate::completion::Usage::new();
700 usage.input_tokens = self.usage.prompt_tokens as u64;
701 usage.output_tokens = self.usage.completion_tokens as u64;
702 usage.total_tokens = self.usage.total_tokens as u64;
703 usage.cached_input_tokens = self
704 .usage
705 .prompt_tokens_details
706 .as_ref()
707 .and_then(|d| d.cached_tokens)
708 .map(|c| c as u64)
709 .unwrap_or(0);
710
711 Some(usage)
712 }
713}
714
715pub async fn send_compatible_streaming_request<T>(
716 http_client: T,
717 req: Request<Vec<u8>>,
718) -> Result<
719 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
720 CompletionError,
721>
722where
723 T: HttpClientExt + Clone + 'static,
724{
725 let mut event_source = GenericEventSource::new(http_client, req);
726
727 let stream = stream! {
728 let mut final_usage = Usage::new();
729 let mut text_response = String::new();
730 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
731
732 while let Some(event_result) = event_source.next().await {
733 match event_result {
734 Ok(Event::Open) => {
735 tracing::trace!("SSE connection opened");
736 continue;
737 }
738 Ok(Event::Message(message)) => {
739 if message.data.trim().is_empty() || message.data == "[DONE]" {
740 continue;
741 }
742
743 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
744 let Ok(data) = parsed else {
745 let err = parsed.unwrap_err();
746 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
747 continue;
748 };
749
750 if let Some(choice) = data.choices.first() {
751 let delta = &choice.delta;
752
753 if !delta.tool_calls.is_empty() {
754 for tool_call in &delta.tool_calls {
755 let function = &tool_call.function;
756
757 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
759 && empty_or_none(&function.arguments)
760 {
761 let id = tool_call.id.clone().unwrap_or_default();
762 let name = function.name.clone().unwrap();
763 calls.insert(tool_call.index, (id, name, String::new()));
764 }
765 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
767 && let Some(arguments) = &function.arguments
768 && !arguments.is_empty()
769 {
770 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
771 let combined = format!("{}{}", existing_args, arguments);
772 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
773 } else {
774 tracing::debug!("Partial tool call received but tool call was never started.");
775 }
776 }
777 else {
779 let id = tool_call.id.clone().unwrap_or_default();
780 let name = function.name.clone().unwrap_or_default();
781 let arguments_str = function.arguments.clone().unwrap_or_default();
782
783 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
784 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
785 continue;
786 };
787
788 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
789 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
790 ));
791 }
792 }
793 }
794
795 if let Some(content) = &delta.reasoning_content {
797 yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
798 id: None,
799 reasoning: content.to_string()
800 });
801 }
802
803 if let Some(content) = &delta.content {
804 text_response += content;
805 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
806 }
807 }
808
809 if let Some(usage) = data.usage {
810 final_usage = usage.clone();
811 }
812 }
813 Err(crate::http_client::Error::StreamEnded) => {
814 break;
815 }
816 Err(err) => {
817 tracing::error!(?err, "SSE error");
818 yield Err(CompletionError::ResponseError(err.to_string()));
819 break;
820 }
821 }
822 }
823
824 event_source.close();
825
826 let mut tool_calls = Vec::new();
827 for (index, (id, name, arguments)) in calls {
829 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
830 continue;
831 };
832
833 tool_calls.push(ToolCall {
834 id: id.clone(),
835 index,
836 r#type: ToolType::Function,
837 function: Function {
838 name: name.clone(),
839 arguments: arguments_json.clone()
840 }
841 });
842 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
843 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
844 ));
845 }
846
847 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
848 StreamingCompletionResponse { usage: final_usage.clone() }
849 ));
850 };
851
852 Ok(crate::streaming::StreamingCompletionResponse::stream(
853 Box::pin(stream),
854 ))
855}
856
/// Model id `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model id `deepseek-reasoner`.
// NOTE(review): presumably the model that populates `reasoning_content` — confirm with DeepSeek docs.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
862
// Unit tests for wire-format (de)serialization and rig <-> DeepSeek message conversion.
#[cfg(test)]
mod tests {
    use super::*;

    // A bare `choices` array deserializes into `Vec<Choice>` with an assistant message.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    // Minimal full response (choices + usage counters); `serde_path_to_error` reports the
    // failing JSON path if deserialization breaks.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A realistic response including extra top-level fields and `prompt_tokens_details`,
    // with non-ASCII content that must round-trip unchanged.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // Tool-call choice: stringified `arguments` must deserialize into parsed JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
                reasoning_content: None,
            },
        };

        assert_eq!(choice, expected_choice);
    }
    // Multiple user text parts collapse into one `user` message joined with '\n'.
    #[test]
    fn test_user_message_multiple_text_items_merged() {
        use crate::completion::message::{Message as RigMessage, UserContent};

        let rig_msg = RigMessage::User {
            content: OneOrMany::many(vec![
                UserContent::text("first part"),
                UserContent::text("second part"),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        let user_messages: Vec<&Message> = messages
            .iter()
            .filter(|m| matches!(m, Message::User { .. }))
            .collect();

        assert_eq!(
            user_messages.len(),
            1,
            "multiple text items should produce a single user message"
        );
        match &user_messages[0] {
            Message::User { content, .. } => {
                assert_eq!(content, "first part\nsecond part");
            }
            _ => unreachable!(),
        }
    }

    // Reasoning, text and a tool call on one assistant turn map to one `assistant` message.
    #[test]
    fn test_assistant_message_with_reasoning_and_tool_calls() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                AssistantContent::reasoning("thinking about the problem"),
                AssistantContent::text("I'll call the tool"),
                AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 5}),
                ),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        assert_eq!(messages.len(), 1, "should produce exactly one message");
        match &messages[0] {
            Message::Assistant {
                content,
                tool_calls,
                reasoning_content,
                ..
            } => {
                assert_eq!(content, "I'll call the tool");
                assert_eq!(
                    reasoning_content.as_deref(),
                    Some("thinking about the problem")
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].function.name, "subtract");
            }
            _ => panic!("Expected assistant message"),
        }
    }

    // Without reasoning parts, `reasoning_content` stays `None`.
    #[test]
    fn test_assistant_message_without_reasoning() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                AssistantContent::text("calling tool"),
                AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        assert_eq!(messages.len(), 1);
        match &messages[0] {
            Message::Assistant {
                reasoning_content,
                tool_calls,
                ..
            } => {
                assert!(reasoning_content.is_none());
                assert_eq!(tool_calls.len(), 1);
            }
            _ => panic!("Expected assistant message"),
        }
    }

    // Both construction paths (`new` and the builder) succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::deepseek::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::deepseek::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}