1use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
/// Default base URL of the DeepSeek API; exposed as `DeepSeekExtBuilder::BASE_URL`.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
42#[derive(Debug, Default, Clone, Copy)]
43pub struct DeepSeekExt;
44#[derive(Debug, Default, Clone, Copy)]
45pub struct DeepSeekExtBuilder;
46
47type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50 type Builder = DeepSeekExtBuilder;
51 const VERIFY_PATH: &'static str = "/user/balance";
52}
53
54impl<H> Capabilities<H> for DeepSeekExt {
55 type Completion = Capable<CompletionModel<H>>;
56 type Embeddings = Nothing;
57 type Transcription = Nothing;
58 type ModelListing = Nothing;
59 #[cfg(feature = "image")]
60 type ImageGeneration = Nothing;
61 #[cfg(feature = "audio")]
62 type AudioGeneration = Nothing;
63}
64
65impl DebugExt for DeepSeekExt {}
66
67impl ProviderBuilder for DeepSeekExtBuilder {
68 type Extension<H>
69 = DeepSeekExt
70 where
71 H: HttpClientExt;
72 type ApiKey = DeepSeekApiKey;
73
74 const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
75
76 fn build<H>(
77 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
78 ) -> http_client::Result<Self::Extension<H>>
79 where
80 H: HttpClientExt,
81 {
82 Ok(DeepSeekExt)
83 }
84}
85
86pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
87pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
88
89impl ProviderClient for Client {
90 type Input = DeepSeekApiKey;
91
92 fn from_env() -> Self {
94 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
95 let mut client_builder = Self::builder();
96 client_builder.headers_mut().insert(
97 http::header::CONTENT_TYPE,
98 http::HeaderValue::from_static("application/json"),
99 );
100 let client_builder = client_builder.api_key(&api_key);
101 client_builder.build().unwrap()
102 }
103
104 fn from_val(input: Self::Input) -> Self {
105 Self::new(input).unwrap()
106 }
107}
108
109#[derive(Debug, Deserialize)]
110struct ApiErrorResponse {
111 message: String,
112}
113
114#[derive(Debug, Deserialize)]
115#[serde(untagged)]
116enum ApiResponse<T> {
117 Ok(T),
118 Err(ApiErrorResponse),
119}
120
121impl From<ApiErrorResponse> for CompletionError {
122 fn from(err: ApiErrorResponse) -> Self {
123 CompletionError::ProviderError(err.message)
124 }
125}
126
127#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct CompletionResponse {
130 pub choices: Vec<Choice>,
132 pub usage: Usage,
133 }
135
136#[derive(Clone, Debug, Serialize, Deserialize, Default)]
137pub struct Usage {
138 pub completion_tokens: u32,
139 pub prompt_tokens: u32,
140 pub prompt_cache_hit_tokens: u32,
141 pub prompt_cache_miss_tokens: u32,
142 pub total_tokens: u32,
143 #[serde(skip_serializing_if = "Option::is_none")]
144 pub completion_tokens_details: Option<CompletionTokensDetails>,
145 #[serde(skip_serializing_if = "Option::is_none")]
146 pub prompt_tokens_details: Option<PromptTokensDetails>,
147}
148
149impl Usage {
150 fn new() -> Self {
151 Self {
152 completion_tokens: 0,
153 prompt_tokens: 0,
154 prompt_cache_hit_tokens: 0,
155 prompt_cache_miss_tokens: 0,
156 total_tokens: 0,
157 completion_tokens_details: None,
158 prompt_tokens_details: None,
159 }
160 }
161}
162
163#[derive(Clone, Debug, Serialize, Deserialize, Default)]
164pub struct CompletionTokensDetails {
165 #[serde(skip_serializing_if = "Option::is_none")]
166 pub reasoning_tokens: Option<u32>,
167}
168
169#[derive(Clone, Debug, Serialize, Deserialize, Default)]
170pub struct PromptTokensDetails {
171 #[serde(skip_serializing_if = "Option::is_none")]
172 pub cached_tokens: Option<u32>,
173}
174
175#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
176pub struct Choice {
177 pub index: usize,
178 pub message: Message,
179 pub logprobs: Option<serde_json::Value>,
180 pub finish_reason: String,
181}
182
183#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
184#[serde(tag = "role", rename_all = "lowercase")]
185pub enum Message {
186 System {
187 content: String,
188 #[serde(skip_serializing_if = "Option::is_none")]
189 name: Option<String>,
190 },
191 User {
192 content: String,
193 #[serde(skip_serializing_if = "Option::is_none")]
194 name: Option<String>,
195 },
196 Assistant {
197 content: String,
198 #[serde(skip_serializing_if = "Option::is_none")]
199 name: Option<String>,
200 #[serde(
201 default,
202 deserialize_with = "json_utils::null_or_vec",
203 skip_serializing_if = "Vec::is_empty"
204 )]
205 tool_calls: Vec<ToolCall>,
206 #[serde(skip_serializing_if = "Option::is_none")]
208 reasoning_content: Option<String>,
209 },
210 #[serde(rename = "tool")]
211 ToolResult {
212 tool_call_id: String,
213 content: String,
214 },
215}
216
217impl Message {
218 pub fn system(content: &str) -> Self {
219 Message::System {
220 content: content.to_owned(),
221 name: None,
222 }
223 }
224}
225
226impl From<message::ToolResult> for Message {
227 fn from(tool_result: message::ToolResult) -> Self {
228 let content = match tool_result.content.first() {
229 message::ToolResultContent::Text(text) => text.text,
230 message::ToolResultContent::Image(_) => String::from("[Image]"),
231 };
232
233 Message::ToolResult {
234 tool_call_id: tool_result.id,
235 content,
236 }
237 }
238}
239
240impl From<message::ToolCall> for ToolCall {
241 fn from(tool_call: message::ToolCall) -> Self {
242 Self {
243 id: tool_call.id,
244 index: 0,
246 r#type: ToolType::Function,
247 function: Function {
248 name: tool_call.function.name,
249 arguments: tool_call.function.arguments,
250 },
251 }
252 }
253}
254
255impl TryFrom<message::Message> for Vec<Message> {
256 type Error = message::MessageError;
257
258 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
259 match message {
260 message::Message::System { content } => Ok(vec![Message::System {
261 content,
262 name: None,
263 }]),
264 message::Message::User { content } => {
265 let mut messages = vec![];
267
268 let tool_results = content
269 .clone()
270 .into_iter()
271 .filter_map(|content| match content {
272 message::UserContent::ToolResult(tool_result) => {
273 Some(Message::from(tool_result))
274 }
275 _ => None,
276 })
277 .collect::<Vec<_>>();
278
279 messages.extend(tool_results);
280
281 let text_content: String = content
282 .into_iter()
283 .filter_map(|content| match content {
284 message::UserContent::Text(text) => Some(text.text),
285 message::UserContent::Document(Document {
286 data:
287 DocumentSourceKind::Base64(content)
288 | DocumentSourceKind::String(content),
289 ..
290 }) => Some(content),
291 _ => None,
292 })
293 .collect::<Vec<_>>()
294 .join("\n");
295
296 if !text_content.is_empty() {
297 messages.push(Message::User {
298 content: text_content,
299 name: None,
300 });
301 }
302
303 Ok(messages)
304 }
305 message::Message::Assistant { content, .. } => {
306 let mut text_content = String::new();
307 let mut reasoning_content = String::new();
308 let mut tool_calls = Vec::new();
309
310 for item in content.iter() {
311 match item {
312 message::AssistantContent::Text(text) => {
313 text_content.push_str(text.text());
314 }
315 message::AssistantContent::Reasoning(reasoning) => {
316 reasoning_content.push_str(&reasoning.display_text());
317 }
318 message::AssistantContent::ToolCall(tool_call) => {
319 tool_calls.push(ToolCall::from(tool_call.clone()));
320 }
321 _ => {}
322 }
323 }
324
325 let reasoning = if reasoning_content.is_empty() {
326 None
327 } else {
328 Some(reasoning_content)
329 };
330
331 Ok(vec![Message::Assistant {
332 content: text_content,
333 name: None,
334 tool_calls,
335 reasoning_content: reasoning,
336 }])
337 }
338 }
339 }
340}
341
342#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
343pub struct ToolCall {
344 pub id: String,
345 pub index: usize,
346 #[serde(default)]
347 pub r#type: ToolType,
348 pub function: Function,
349}
350
351#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
352pub struct Function {
353 pub name: String,
354 #[serde(with = "json_utils::stringified_json")]
355 pub arguments: serde_json::Value,
356}
357
358#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
359#[serde(rename_all = "lowercase")]
360pub enum ToolType {
361 #[default]
362 Function,
363}
364
365#[derive(Clone, Debug, Deserialize, Serialize)]
366pub struct ToolDefinition {
367 pub r#type: String,
368 pub function: completion::ToolDefinition,
369}
370
371impl From<crate::completion::ToolDefinition> for ToolDefinition {
372 fn from(tool: crate::completion::ToolDefinition) -> Self {
373 Self {
374 r#type: "function".into(),
375 function: tool,
376 }
377 }
378}
379
380impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
381 type Error = CompletionError;
382
383 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
384 let choice = response.choices.first().ok_or_else(|| {
385 CompletionError::ResponseError("Response contained no choices".to_owned())
386 })?;
387 let content = match &choice.message {
388 Message::Assistant {
389 content,
390 tool_calls,
391 reasoning_content,
392 ..
393 } => {
394 let mut content = if content.trim().is_empty() {
395 vec![]
396 } else {
397 vec![completion::AssistantContent::text(content)]
398 };
399
400 content.extend(
401 tool_calls
402 .iter()
403 .map(|call| {
404 completion::AssistantContent::tool_call(
405 &call.id,
406 &call.function.name,
407 call.function.arguments.clone(),
408 )
409 })
410 .collect::<Vec<_>>(),
411 );
412
413 if let Some(reasoning_content) = reasoning_content {
414 content.push(completion::AssistantContent::reasoning(reasoning_content));
415 }
416
417 Ok(content)
418 }
419 _ => Err(CompletionError::ResponseError(
420 "Response did not contain a valid message or tool call".into(),
421 )),
422 }?;
423
424 let choice = OneOrMany::many(content).map_err(|_| {
425 CompletionError::ResponseError(
426 "Response contained no message or tool call (empty)".to_owned(),
427 )
428 })?;
429
430 let usage = completion::Usage {
431 input_tokens: response.usage.prompt_tokens as u64,
432 output_tokens: response.usage.completion_tokens as u64,
433 total_tokens: response.usage.total_tokens as u64,
434 cached_input_tokens: response
435 .usage
436 .prompt_tokens_details
437 .as_ref()
438 .and_then(|d| d.cached_tokens)
439 .map(|c| c as u64)
440 .unwrap_or(0),
441 };
442
443 Ok(completion::CompletionResponse {
444 choice,
445 usage,
446 raw_response: response,
447 message_id: None,
448 })
449 }
450}
451
452#[derive(Debug, Serialize, Deserialize)]
453pub(super) struct DeepseekCompletionRequest {
454 model: String,
455 pub messages: Vec<Message>,
456 #[serde(skip_serializing_if = "Option::is_none")]
457 temperature: Option<f64>,
458 #[serde(skip_serializing_if = "Vec::is_empty")]
459 tools: Vec<ToolDefinition>,
460 #[serde(skip_serializing_if = "Option::is_none")]
461 tool_choice: Option<crate::providers::openrouter::ToolChoice>,
462 #[serde(flatten, skip_serializing_if = "Option::is_none")]
463 pub additional_params: Option<serde_json::Value>,
464}
465
466impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
467 type Error = CompletionError;
468
469 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
470 if req.output_schema.is_some() {
471 tracing::warn!("Structured outputs currently not supported for DeepSeek");
472 }
473 let model = req.model.clone().unwrap_or_else(|| model.to_string());
474 let mut full_history: Vec<Message> = match &req.preamble {
475 Some(preamble) => vec![Message::system(preamble)],
476 None => vec![],
477 };
478
479 if let Some(docs) = req.normalized_documents() {
480 let docs: Vec<Message> = docs.try_into()?;
481 full_history.extend(docs);
482 }
483
484 let chat_history: Vec<Message> = req
485 .chat_history
486 .clone()
487 .into_iter()
488 .map(|message| message.try_into())
489 .collect::<Result<Vec<Vec<Message>>, _>>()?
490 .into_iter()
491 .flatten()
492 .collect();
493
494 full_history.extend(chat_history);
495
496 let tool_choice = req
497 .tool_choice
498 .clone()
499 .map(crate::providers::openrouter::ToolChoice::try_from)
500 .transpose()?;
501
502 Ok(Self {
503 model: model.to_string(),
504 messages: full_history,
505 temperature: req.temperature,
506 tools: req
507 .tools
508 .clone()
509 .into_iter()
510 .map(ToolDefinition::from)
511 .collect::<Vec<_>>(),
512 tool_choice,
513 additional_params: req.additional_params,
514 })
515 }
516}
517
518#[derive(Clone)]
520pub struct CompletionModel<T = reqwest::Client> {
521 pub client: Client<T>,
522 pub model: String,
523}
524
525impl<T> completion::CompletionModel for CompletionModel<T>
526where
527 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
528{
529 type Response = CompletionResponse;
530 type StreamingResponse = StreamingCompletionResponse;
531
532 type Client = Client<T>;
533
534 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
535 Self {
536 client: client.clone(),
537 model: model.into().to_string(),
538 }
539 }
540
541 async fn completion(
542 &self,
543 completion_request: CompletionRequest,
544 ) -> Result<
545 completion::CompletionResponse<CompletionResponse>,
546 crate::completion::CompletionError,
547 > {
548 let span = if tracing::Span::current().is_disabled() {
549 info_span!(
550 target: "rig::completions",
551 "chat",
552 gen_ai.operation.name = "chat",
553 gen_ai.provider.name = "deepseek",
554 gen_ai.request.model = self.model,
555 gen_ai.system_instructions = tracing::field::Empty,
556 gen_ai.response.id = tracing::field::Empty,
557 gen_ai.response.model = tracing::field::Empty,
558 gen_ai.usage.output_tokens = tracing::field::Empty,
559 gen_ai.usage.input_tokens = tracing::field::Empty,
560 gen_ai.usage.cached_tokens = tracing::field::Empty,
561 )
562 } else {
563 tracing::Span::current()
564 };
565
566 span.record("gen_ai.system_instructions", &completion_request.preamble);
567
568 let request =
569 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
570
571 if enabled!(Level::TRACE) {
572 tracing::trace!(target: "rig::completions",
573 "DeepSeek completion request: {}",
574 serde_json::to_string_pretty(&request)?
575 );
576 }
577
578 let body = serde_json::to_vec(&request)?;
579 let req = self
580 .client
581 .post("/chat/completions")?
582 .body(body)
583 .map_err(|e| CompletionError::HttpError(e.into()))?;
584
585 async move {
586 let response = self.client.send::<_, Bytes>(req).await?;
587 let status = response.status();
588 let response_body = response.into_body().into_future().await?.to_vec();
589
590 if status.is_success() {
591 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
592 ApiResponse::Ok(response) => {
593 let span = tracing::Span::current();
594 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
595 span.record(
596 "gen_ai.usage.output_tokens",
597 response.usage.completion_tokens,
598 );
599 span.record(
600 "gen_ai.usage.cached_tokens",
601 response
602 .usage
603 .prompt_tokens_details
604 .as_ref()
605 .and_then(|d| d.cached_tokens)
606 .unwrap_or(0),
607 );
608 if enabled!(Level::TRACE) {
609 tracing::trace!(target: "rig::completions",
610 "DeepSeek completion response: {}",
611 serde_json::to_string_pretty(&response)?
612 );
613 }
614 response.try_into()
615 }
616 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
617 }
618 } else {
619 Err(CompletionError::ProviderError(
620 String::from_utf8_lossy(&response_body).to_string(),
621 ))
622 }
623 }
624 .instrument(span)
625 .await
626 }
627
628 async fn stream(
629 &self,
630 completion_request: CompletionRequest,
631 ) -> Result<
632 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
633 CompletionError,
634 > {
635 let preamble = completion_request.preamble.clone();
636 let mut request =
637 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
638
639 let params = json_utils::merge(
640 request.additional_params.unwrap_or(serde_json::json!({})),
641 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
642 );
643
644 request.additional_params = Some(params);
645
646 if enabled!(Level::TRACE) {
647 tracing::trace!(target: "rig::completions",
648 "DeepSeek streaming completion request: {}",
649 serde_json::to_string_pretty(&request)?
650 );
651 }
652
653 let body = serde_json::to_vec(&request)?;
654
655 let req = self
656 .client
657 .post("/chat/completions")?
658 .body(body)
659 .map_err(|e| CompletionError::HttpError(e.into()))?;
660
661 let span = if tracing::Span::current().is_disabled() {
662 info_span!(
663 target: "rig::completions",
664 "chat_streaming",
665 gen_ai.operation.name = "chat_streaming",
666 gen_ai.provider.name = "deepseek",
667 gen_ai.request.model = self.model,
668 gen_ai.system_instructions = preamble,
669 gen_ai.response.id = tracing::field::Empty,
670 gen_ai.response.model = tracing::field::Empty,
671 gen_ai.usage.output_tokens = tracing::field::Empty,
672 gen_ai.usage.input_tokens = tracing::field::Empty,
673 gen_ai.usage.cached_tokens = tracing::field::Empty,
674 )
675 } else {
676 tracing::Span::current()
677 };
678
679 tracing::Instrument::instrument(
680 send_compatible_streaming_request(self.client.clone(), req),
681 span,
682 )
683 .await
684 }
685}
686
687#[derive(Deserialize, Debug)]
688pub struct StreamingDelta {
689 #[serde(default)]
690 content: Option<String>,
691 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
692 tool_calls: Vec<StreamingToolCall>,
693 reasoning_content: Option<String>,
694}
695
696#[derive(Deserialize, Debug)]
697struct StreamingChoice {
698 delta: StreamingDelta,
699}
700
701#[derive(Deserialize, Debug)]
702struct StreamingCompletionChunk {
703 choices: Vec<StreamingChoice>,
704 usage: Option<Usage>,
705}
706
707#[derive(Clone, Deserialize, Serialize, Debug)]
708pub struct StreamingCompletionResponse {
709 pub usage: Usage,
710}
711
712impl GetTokenUsage for StreamingCompletionResponse {
713 fn token_usage(&self) -> Option<crate::completion::Usage> {
714 let mut usage = crate::completion::Usage::new();
715 usage.input_tokens = self.usage.prompt_tokens as u64;
716 usage.output_tokens = self.usage.completion_tokens as u64;
717 usage.total_tokens = self.usage.total_tokens as u64;
718 usage.cached_input_tokens = self
719 .usage
720 .prompt_tokens_details
721 .as_ref()
722 .and_then(|d| d.cached_tokens)
723 .map(|c| c as u64)
724 .unwrap_or(0);
725
726 Some(usage)
727 }
728}
729
730pub async fn send_compatible_streaming_request<T>(
731 http_client: T,
732 req: Request<Vec<u8>>,
733) -> Result<
734 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
735 CompletionError,
736>
737where
738 T: HttpClientExt + Clone + 'static,
739{
740 let mut event_source = GenericEventSource::new(http_client, req);
741
742 let stream = stream! {
743 let mut final_usage = Usage::new();
744 let mut text_response = String::new();
745 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
746
747 while let Some(event_result) = event_source.next().await {
748 match event_result {
749 Ok(Event::Open) => {
750 tracing::trace!("SSE connection opened");
751 continue;
752 }
753 Ok(Event::Message(message)) => {
754 if message.data.trim().is_empty() || message.data == "[DONE]" {
755 continue;
756 }
757
758 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
759 let Ok(data) = parsed else {
760 let err = parsed.unwrap_err();
761 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
762 continue;
763 };
764
765 if let Some(choice) = data.choices.first() {
766 let delta = &choice.delta;
767
768 if !delta.tool_calls.is_empty() {
769 for tool_call in &delta.tool_calls {
770 let function = &tool_call.function;
771
772 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
774 && empty_or_none(&function.arguments)
775 {
776 let id = tool_call.id.clone().unwrap_or_default();
777 let name = function.name.clone().unwrap();
778 calls.insert(tool_call.index, (id, name, String::new()));
779 }
780 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
782 && let Some(arguments) = &function.arguments
783 && !arguments.is_empty()
784 {
785 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
786 let combined = format!("{}{}", existing_args, arguments);
787 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
788 } else {
789 tracing::debug!("Partial tool call received but tool call was never started.");
790 }
791 }
792 else {
794 let id = tool_call.id.clone().unwrap_or_default();
795 let name = function.name.clone().unwrap_or_default();
796 let arguments_str = function.arguments.clone().unwrap_or_default();
797
798 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
799 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
800 continue;
801 };
802
803 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
804 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
805 ));
806 }
807 }
808 }
809
810 if let Some(content) = &delta.reasoning_content {
812 yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
813 id: None,
814 reasoning: content.to_string()
815 });
816 }
817
818 if let Some(content) = &delta.content {
819 text_response += content;
820 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
821 }
822 }
823
824 if let Some(usage) = data.usage {
825 final_usage = usage.clone();
826 }
827 }
828 Err(crate::http_client::Error::StreamEnded) => {
829 break;
830 }
831 Err(err) => {
832 tracing::error!(?err, "SSE error");
833 yield Err(CompletionError::ResponseError(err.to_string()));
834 break;
835 }
836 }
837 }
838
839 event_source.close();
840
841 let mut tool_calls = Vec::new();
842 for (index, (id, name, arguments)) in calls {
844 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
845 continue;
846 };
847
848 tool_calls.push(ToolCall {
849 id: id.clone(),
850 index,
851 r#type: ToolType::Function,
852 function: Function {
853 name: name.clone(),
854 arguments: arguments_json.clone()
855 }
856 });
857 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
858 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
859 ));
860 }
861
862 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
863 StreamingCompletionResponse { usage: final_usage.clone() }
864 ));
865 };
866
867 Ok(crate::streaming::StreamingCompletionResponse::stream(
868 Box::pin(stream),
869 ))
870}
871
/// Model identifier `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
877
#[cfg(test)]
// Unit tests for the DeepSeek wire types and the rig <-> DeepSeek message conversions.
mod tests {
    use super::*;

    // A bare `choices` array containing a plain assistant message deserializes.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    // A full response whose `usage` has no `prompt_tokens_details` still deserializes;
    // `serde_path_to_error` reports the failing path if it does not.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A sample response containing fields that `CompletionResponse` does not model (`id`,
    // `object`, `created`, `model`, `system_fingerprint`) is accepted, and the escaped `\n`
    // in the JSON content becomes a real newline.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // Tool-call arguments arrive as a JSON-encoded string and are parsed into a
    // `serde_json::Value` on `Function::arguments`.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
                reasoning_content: None,
            },
        };

        assert_eq!(choice, expected_choice);
    }

    // Several text parts in one rig user message collapse into a single DeepSeek user
    // message joined with "\n".
    #[test]
    fn test_user_message_multiple_text_items_merged() {
        use crate::completion::message::{Message as RigMessage, UserContent};

        let rig_msg = RigMessage::User {
            content: OneOrMany::many(vec![
                UserContent::text("first part"),
                UserContent::text("second part"),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        let user_messages: Vec<&Message> = messages
            .iter()
            .filter(|m| matches!(m, Message::User { .. }))
            .collect();

        assert_eq!(
            user_messages.len(),
            1,
            "multiple text items should produce a single user message"
        );
        match &user_messages[0] {
            Message::User { content, .. } => {
                assert_eq!(content, "first part\nsecond part");
            }
            _ => unreachable!(),
        }
    }

    // Reasoning, text and a tool call in one rig assistant message are aggregated into one
    // DeepSeek assistant message, with reasoning placed in `reasoning_content`.
    #[test]
    fn test_assistant_message_with_reasoning_and_tool_calls() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                AssistantContent::reasoning("thinking about the problem"),
                AssistantContent::text("I'll call the tool"),
                AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 5}),
                ),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        assert_eq!(messages.len(), 1, "should produce exactly one message");
        match &messages[0] {
            Message::Assistant {
                content,
                tool_calls,
                reasoning_content,
                ..
            } => {
                assert_eq!(content, "I'll call the tool");
                assert_eq!(
                    reasoning_content.as_deref(),
                    Some("thinking about the problem")
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].function.name, "subtract");
            }
            _ => panic!("Expected assistant message"),
        }
    }

    // Without any reasoning parts, `reasoning_content` stays `None`.
    #[test]
    fn test_assistant_message_without_reasoning() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                AssistantContent::text("calling tool"),
                AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        assert_eq!(messages.len(), 1);
        match &messages[0] {
            Message::Assistant {
                reasoning_content,
                tool_calls,
                ..
            } => {
                assert!(reasoning_content.is_none());
                assert_eq!(tool_calls.len(), 1);
            }
            _ => panic!("Expected assistant message"),
        }
    }

    // Both construction paths (`Client::new` and the builder) succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::deepseek::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::deepseek::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}
1139}