1use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
/// Base URL of the DeepSeek HTTP API; supplied as `ProviderBuilder::BASE_URL` below.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
/// Zero-sized marker type that plugs DeepSeek into the generic client as a [`Provider`].
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExt;
/// Zero-sized builder counterpart of [`DeepSeekExt`]; see its [`ProviderBuilder`] impl.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExtBuilder;
46
/// API key type used by the DeepSeek builder: an alias for [`BearerAuth`].
type DeepSeekApiKey = BearerAuth;
48
/// Registers [`DeepSeekExt`] as a provider, pairing it with [`DeepSeekExtBuilder`].
impl Provider for DeepSeekExt {
    type Builder = DeepSeekExtBuilder;
    // NOTE(review): presumably the endpoint requested to validate credentials —
    // confirm against the `Provider` trait definition.
    const VERIFY_PATH: &'static str = "/user/balance";
}
53
/// Capability table for DeepSeek: only completions are marked [`Capable`]
/// (through [`CompletionModel`]); every other capability is [`Nothing`].
impl<H> Capabilities<H> for DeepSeekExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // The two generation capabilities only exist when their cargo features are enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
64
// Empty impl: relies entirely on whatever defaults `DebugExt` provides.
impl DebugExt for DeepSeekExt {}
66
/// Builder wiring for DeepSeek: fixes the extension type, the API key type and
/// the default base URL.
impl ProviderBuilder for DeepSeekExtBuilder {
    type Extension<H>
        = DeepSeekExt
    where
        H: HttpClientExt;
    type ApiKey = DeepSeekApiKey;

    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;

    /// Produces the provider extension.
    ///
    /// The builder argument is ignored: [`DeepSeekExt`] carries no state, so this
    /// always returns `Ok(DeepSeekExt)`.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(DeepSeekExt)
    }
}
85
/// DeepSeek client; the HTTP backend defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
/// Builder for [`Client`]; the HTTP backend defaults to `reqwest::Client`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
88
89impl ProviderClient for Client {
90 type Input = DeepSeekApiKey;
91
92 fn from_env() -> Self {
94 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
95 let mut client_builder = Self::builder();
96 client_builder.headers_mut().insert(
97 http::header::CONTENT_TYPE,
98 http::HeaderValue::from_static("application/json"),
99 );
100 let client_builder = client_builder.api_key(&api_key);
101 client_builder.build().unwrap()
102 }
103
104 fn from_val(input: Self::Input) -> Self {
105 Self::new(input).unwrap()
106 }
107}
108
/// Error payload shape: a JSON object carrying a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
113
/// Body of a successful HTTP response, which may still be an API-level error.
///
/// Untagged: serde tries `Ok(T)` first and falls back to `Err` when the payload
/// does not match `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
120
121impl From<ApiErrorResponse> for CompletionError {
122 fn from(err: ApiErrorResponse) -> Self {
123 CompletionError::ProviderError(err.message)
124 }
125}
126
/// The subset of a DeepSeek chat-completion response this module reads:
/// the returned choices and the token usage block. Other JSON keys are ignored
/// on deserialization (see the `id`/`model`/... keys in the test fixtures).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
135
/// Token accounting block (`usage`) of a DeepSeek response.
///
/// Field names mirror the JSON keys seen in this file's test fixtures. The five
/// counters are required on deserialization; the two `*_details` objects are
/// optional and are omitted from serialized output when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    pub completion_tokens: u32,
    pub prompt_tokens: u32,
    pub prompt_cache_hit_tokens: u32,
    pub prompt_cache_miss_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
148
149impl Usage {
150 fn new() -> Self {
151 Self {
152 completion_tokens: 0,
153 prompt_tokens: 0,
154 prompt_cache_hit_tokens: 0,
155 prompt_cache_miss_tokens: 0,
156 total_tokens: 0,
157 completion_tokens_details: None,
158 prompt_tokens_details: None,
159 }
160 }
161}
162
/// Optional `completion_tokens_details` object of [`Usage`].
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    // Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
168
/// Optional `prompt_tokens_details` object of [`Usage`].
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    // Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
174
/// One entry of the response's `choices` array.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // Kept as raw JSON; this module never inspects it.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
182
/// A chat message in DeepSeek's wire format.
///
/// The JSON `role` key selects the variant. Variant names are lowercased
/// (`system`, `user`, `assistant`), except [`Message::ToolResult`], which is
/// renamed to `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Defaults to an empty vec when the key is missing and is not emitted
        // when empty. NOTE(review): `null_or_vec` presumably maps JSON `null`
        // to an empty vec as well — confirm in `json_utils`.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
216
217impl Message {
218 pub fn system(content: &str) -> Self {
219 Message::System {
220 content: content.to_owned(),
221 name: None,
222 }
223 }
224}
225
226impl From<message::ToolResult> for Message {
227 fn from(tool_result: message::ToolResult) -> Self {
228 let content = match tool_result.content.first() {
229 message::ToolResultContent::Text(text) => text.text,
230 message::ToolResultContent::Image(_) => String::from("[Image]"),
231 };
232
233 Message::ToolResult {
234 tool_call_id: tool_result.id,
235 content,
236 }
237 }
238}
239
240impl From<message::ToolCall> for ToolCall {
241 fn from(tool_call: message::ToolCall) -> Self {
242 Self {
243 id: tool_call.id,
244 index: 0,
246 r#type: ToolType::Function,
247 function: Function {
248 name: tool_call.function.name,
249 arguments: tool_call.function.arguments,
250 },
251 }
252 }
253}
254
255impl TryFrom<message::Message> for Vec<Message> {
256 type Error = message::MessageError;
257
258 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
259 match message {
260 message::Message::System { content } => Ok(vec![Message::System {
261 content,
262 name: None,
263 }]),
264 message::Message::User { content } => {
265 let mut messages = vec![];
267
268 let tool_results = content
269 .clone()
270 .into_iter()
271 .filter_map(|content| match content {
272 message::UserContent::ToolResult(tool_result) => {
273 Some(Message::from(tool_result))
274 }
275 _ => None,
276 })
277 .collect::<Vec<_>>();
278
279 messages.extend(tool_results);
280
281 let text_content: String = content
282 .into_iter()
283 .filter_map(|content| match content {
284 message::UserContent::Text(text) => Some(text.text),
285 message::UserContent::Document(Document {
286 data:
287 DocumentSourceKind::Base64(content)
288 | DocumentSourceKind::String(content),
289 ..
290 }) => Some(content),
291 _ => None,
292 })
293 .collect::<Vec<_>>()
294 .join("\n");
295
296 if !text_content.is_empty() {
297 messages.push(Message::User {
298 content: text_content,
299 name: None,
300 });
301 }
302
303 Ok(messages)
304 }
305 message::Message::Assistant { content, .. } => {
306 let mut text_content = String::new();
307 let mut reasoning_content = String::new();
308 let mut tool_calls = Vec::new();
309
310 for item in content.iter() {
311 match item {
312 message::AssistantContent::Text(text) => {
313 text_content.push_str(text.text());
314 }
315 message::AssistantContent::Reasoning(reasoning) => {
316 reasoning_content.push_str(&reasoning.display_text());
317 }
318 message::AssistantContent::ToolCall(tool_call) => {
319 tool_calls.push(ToolCall::from(tool_call.clone()));
320 }
321 _ => {}
322 }
323 }
324
325 let reasoning = if reasoning_content.is_empty() {
326 None
327 } else {
328 Some(reasoning_content)
329 };
330
331 Ok(vec![Message::Assistant {
332 content: text_content,
333 name: None,
334 tool_calls,
335 reasoning_content: reasoning,
336 }])
337 }
338 }
339 }
340}
341
/// A tool invocation attached to an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    pub index: usize,
    // Falls back to `ToolType::Function` when the key is missing.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
350
/// Name and arguments of a called function.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // On the wire the arguments are a JSON document encoded as a string
    // (see the `"arguments": "{\"x\":2,\"y\":5}"` test fixture); in memory they
    // are held as a parsed `serde_json::Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
357
/// Kind of tool being called; serialized in lowercase (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
364
/// Request-side tool declaration: a generic tool definition wrapped with a
/// `type` discriminator.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
370
371impl From<crate::completion::ToolDefinition> for ToolDefinition {
372 fn from(tool: crate::completion::ToolDefinition) -> Self {
373 Self {
374 r#type: "function".into(),
375 function: tool,
376 }
377 }
378}
379
380impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
381 type Error = CompletionError;
382
383 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
384 let choice = response.choices.first().ok_or_else(|| {
385 CompletionError::ResponseError("Response contained no choices".to_owned())
386 })?;
387 let content = match &choice.message {
388 Message::Assistant {
389 content,
390 tool_calls,
391 reasoning_content,
392 ..
393 } => {
394 let mut content = if content.trim().is_empty() {
395 vec![]
396 } else {
397 vec![completion::AssistantContent::text(content)]
398 };
399
400 content.extend(
401 tool_calls
402 .iter()
403 .map(|call| {
404 completion::AssistantContent::tool_call(
405 &call.id,
406 &call.function.name,
407 call.function.arguments.clone(),
408 )
409 })
410 .collect::<Vec<_>>(),
411 );
412
413 if let Some(reasoning_content) = reasoning_content {
414 content.push(completion::AssistantContent::reasoning(reasoning_content));
415 }
416
417 Ok(content)
418 }
419 _ => Err(CompletionError::ResponseError(
420 "Response did not contain a valid message or tool call".into(),
421 )),
422 }?;
423
424 let choice = OneOrMany::many(content).map_err(|_| {
425 CompletionError::ResponseError(
426 "Response contained no message or tool call (empty)".to_owned(),
427 )
428 })?;
429
430 let usage = completion::Usage {
431 input_tokens: response.usage.prompt_tokens as u64,
432 output_tokens: response.usage.completion_tokens as u64,
433 total_tokens: response.usage.total_tokens as u64,
434 cached_input_tokens: response
435 .usage
436 .prompt_tokens_details
437 .as_ref()
438 .and_then(|d| d.cached_tokens)
439 .map(|c| c as u64)
440 .unwrap_or(0),
441 cache_creation_input_tokens: 0,
442 };
443
444 Ok(completion::CompletionResponse {
445 choice,
446 usage,
447 raw_response: response,
448 message_id: None,
449 })
450 }
451}
452
/// JSON body sent to `/chat/completions`.
///
/// `temperature`, `tools` and `tool_choice` are left out of the serialized body
/// when unset/empty. `additional_params` is flattened, so its keys appear at the
/// top level of the body next to `model` and `messages`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct DeepseekCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
466
467impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
468 type Error = CompletionError;
469
470 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
471 if req.output_schema.is_some() {
472 tracing::warn!("Structured outputs currently not supported for DeepSeek");
473 }
474 let model = req.model.clone().unwrap_or_else(|| model.to_string());
475 let mut full_history: Vec<Message> = match &req.preamble {
476 Some(preamble) => vec![Message::system(preamble)],
477 None => vec![],
478 };
479
480 if let Some(docs) = req.normalized_documents() {
481 let docs: Vec<Message> = docs.try_into()?;
482 full_history.extend(docs);
483 }
484
485 let chat_history: Vec<Message> = req
486 .chat_history
487 .clone()
488 .into_iter()
489 .map(|message| message.try_into())
490 .collect::<Result<Vec<Vec<Message>>, _>>()?
491 .into_iter()
492 .flatten()
493 .collect();
494
495 full_history.extend(chat_history);
496
497 let tool_choice = req
498 .tool_choice
499 .clone()
500 .map(crate::providers::openrouter::ToolChoice::try_from)
501 .transpose()?;
502
503 Ok(Self {
504 model: model.to_string(),
505 messages: full_history,
506 temperature: req.temperature,
507 tools: req
508 .tools
509 .clone()
510 .into_iter()
511 .map(ToolDefinition::from)
512 .collect::<Vec<_>>(),
513 tool_choice,
514 additional_params: req.additional_params,
515 })
516 }
517}
518
/// A DeepSeek completion model: a client handle plus the model name used when a
/// request does not name one itself.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub client: Client<T>,
    pub model: String,
}
525
526impl<T> completion::CompletionModel for CompletionModel<T>
527where
528 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
529{
530 type Response = CompletionResponse;
531 type StreamingResponse = StreamingCompletionResponse;
532
533 type Client = Client<T>;
534
535 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
536 Self {
537 client: client.clone(),
538 model: model.into().to_string(),
539 }
540 }
541
542 async fn completion(
543 &self,
544 completion_request: CompletionRequest,
545 ) -> Result<
546 completion::CompletionResponse<CompletionResponse>,
547 crate::completion::CompletionError,
548 > {
549 let span = if tracing::Span::current().is_disabled() {
550 info_span!(
551 target: "rig::completions",
552 "chat",
553 gen_ai.operation.name = "chat",
554 gen_ai.provider.name = "deepseek",
555 gen_ai.request.model = self.model,
556 gen_ai.system_instructions = tracing::field::Empty,
557 gen_ai.response.id = tracing::field::Empty,
558 gen_ai.response.model = tracing::field::Empty,
559 gen_ai.usage.output_tokens = tracing::field::Empty,
560 gen_ai.usage.input_tokens = tracing::field::Empty,
561 gen_ai.usage.cached_tokens = tracing::field::Empty,
562 )
563 } else {
564 tracing::Span::current()
565 };
566
567 span.record("gen_ai.system_instructions", &completion_request.preamble);
568
569 let request =
570 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
571
572 if enabled!(Level::TRACE) {
573 tracing::trace!(target: "rig::completions",
574 "DeepSeek completion request: {}",
575 serde_json::to_string_pretty(&request)?
576 );
577 }
578
579 let body = serde_json::to_vec(&request)?;
580 let req = self
581 .client
582 .post("/chat/completions")?
583 .body(body)
584 .map_err(|e| CompletionError::HttpError(e.into()))?;
585
586 async move {
587 let response = self.client.send::<_, Bytes>(req).await?;
588 let status = response.status();
589 let response_body = response.into_body().into_future().await?.to_vec();
590
591 if status.is_success() {
592 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
593 ApiResponse::Ok(response) => {
594 let span = tracing::Span::current();
595 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
596 span.record(
597 "gen_ai.usage.output_tokens",
598 response.usage.completion_tokens,
599 );
600 span.record(
601 "gen_ai.usage.cached_tokens",
602 response
603 .usage
604 .prompt_tokens_details
605 .as_ref()
606 .and_then(|d| d.cached_tokens)
607 .unwrap_or(0),
608 );
609 if enabled!(Level::TRACE) {
610 tracing::trace!(target: "rig::completions",
611 "DeepSeek completion response: {}",
612 serde_json::to_string_pretty(&response)?
613 );
614 }
615 response.try_into()
616 }
617 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
618 }
619 } else {
620 Err(CompletionError::ProviderError(
621 String::from_utf8_lossy(&response_body).to_string(),
622 ))
623 }
624 }
625 .instrument(span)
626 .await
627 }
628
629 async fn stream(
630 &self,
631 completion_request: CompletionRequest,
632 ) -> Result<
633 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
634 CompletionError,
635 > {
636 let preamble = completion_request.preamble.clone();
637 let mut request =
638 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
639
640 let params = json_utils::merge(
641 request.additional_params.unwrap_or(serde_json::json!({})),
642 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
643 );
644
645 request.additional_params = Some(params);
646
647 if enabled!(Level::TRACE) {
648 tracing::trace!(target: "rig::completions",
649 "DeepSeek streaming completion request: {}",
650 serde_json::to_string_pretty(&request)?
651 );
652 }
653
654 let body = serde_json::to_vec(&request)?;
655
656 let req = self
657 .client
658 .post("/chat/completions")?
659 .body(body)
660 .map_err(|e| CompletionError::HttpError(e.into()))?;
661
662 let span = if tracing::Span::current().is_disabled() {
663 info_span!(
664 target: "rig::completions",
665 "chat_streaming",
666 gen_ai.operation.name = "chat_streaming",
667 gen_ai.provider.name = "deepseek",
668 gen_ai.request.model = self.model,
669 gen_ai.system_instructions = preamble,
670 gen_ai.response.id = tracing::field::Empty,
671 gen_ai.response.model = tracing::field::Empty,
672 gen_ai.usage.output_tokens = tracing::field::Empty,
673 gen_ai.usage.input_tokens = tracing::field::Empty,
674 gen_ai.usage.cached_tokens = tracing::field::Empty,
675 )
676 } else {
677 tracing::Span::current()
678 };
679
680 tracing::Instrument::instrument(
681 send_compatible_streaming_request(self.client.clone(), req),
682 span,
683 )
684 .await
685 }
686}
687
/// Incremental payload (`delta`) of one streamed choice. Every part is optional;
/// a chunk may carry text, reasoning text, tool-call fragments, or none of them.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    // Missing key -> empty vec. NOTE(review): `null_or_vec` presumably also maps
    // JSON `null` to an empty vec — confirm in `json_utils`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    reasoning_content: Option<String>,
}
696
/// One entry of a streamed chunk's `choices` array; only its `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
701
/// A single SSE data payload of a streamed completion: the choices plus an
/// optional usage block.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
707
/// Final item of a streamed completion: the last usage block seen on the stream
/// (or a zeroed one when no chunk carried usage).
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
712
713impl GetTokenUsage for StreamingCompletionResponse {
714 fn token_usage(&self) -> Option<crate::completion::Usage> {
715 let mut usage = crate::completion::Usage::new();
716 usage.input_tokens = self.usage.prompt_tokens as u64;
717 usage.output_tokens = self.usage.completion_tokens as u64;
718 usage.total_tokens = self.usage.total_tokens as u64;
719 usage.cached_input_tokens = self
720 .usage
721 .prompt_tokens_details
722 .as_ref()
723 .and_then(|d| d.cached_tokens)
724 .map(|c| c as u64)
725 .unwrap_or(0);
726
727 Some(usage)
728 }
729}
730
731pub async fn send_compatible_streaming_request<T>(
732 http_client: T,
733 req: Request<Vec<u8>>,
734) -> Result<
735 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
736 CompletionError,
737>
738where
739 T: HttpClientExt + Clone + 'static,
740{
741 let mut event_source = GenericEventSource::new(http_client, req);
742
743 let stream = stream! {
744 let mut final_usage = Usage::new();
745 let mut text_response = String::new();
746 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
747
748 while let Some(event_result) = event_source.next().await {
749 match event_result {
750 Ok(Event::Open) => {
751 tracing::trace!("SSE connection opened");
752 continue;
753 }
754 Ok(Event::Message(message)) => {
755 if message.data.trim().is_empty() || message.data == "[DONE]" {
756 continue;
757 }
758
759 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
760 let Ok(data) = parsed else {
761 let err = parsed.unwrap_err();
762 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
763 continue;
764 };
765
766 if let Some(choice) = data.choices.first() {
767 let delta = &choice.delta;
768
769 if !delta.tool_calls.is_empty() {
770 for tool_call in &delta.tool_calls {
771 let function = &tool_call.function;
772
773 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
775 && empty_or_none(&function.arguments)
776 {
777 let id = tool_call.id.clone().unwrap_or_default();
778 let name = function.name.clone().unwrap();
779 calls.insert(tool_call.index, (id, name, String::new()));
780 }
781 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
783 && let Some(arguments) = &function.arguments
784 && !arguments.is_empty()
785 {
786 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
787 let combined = format!("{}{}", existing_args, arguments);
788 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
789 } else {
790 tracing::debug!("Partial tool call received but tool call was never started.");
791 }
792 }
793 else {
795 let id = tool_call.id.clone().unwrap_or_default();
796 let name = function.name.clone().unwrap_or_default();
797 let arguments_str = function.arguments.clone().unwrap_or_default();
798
799 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
800 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
801 continue;
802 };
803
804 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
805 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
806 ));
807 }
808 }
809 }
810
811 if let Some(content) = &delta.reasoning_content {
813 yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
814 id: None,
815 reasoning: content.to_string()
816 });
817 }
818
819 if let Some(content) = &delta.content {
820 text_response += content;
821 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
822 }
823 }
824
825 if let Some(usage) = data.usage {
826 final_usage = usage.clone();
827 }
828 }
829 Err(crate::http_client::Error::StreamEnded) => {
830 break;
831 }
832 Err(err) => {
833 tracing::error!(?err, "SSE error");
834 yield Err(CompletionError::ResponseError(err.to_string()));
835 break;
836 }
837 }
838 }
839
840 event_source.close();
841
842 let mut tool_calls = Vec::new();
843 for (index, (id, name, arguments)) in calls {
845 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
846 continue;
847 };
848
849 tool_calls.push(ToolCall {
850 id: id.clone(),
851 index,
852 r#type: ToolType::Function,
853 function: Function {
854 name: name.clone(),
855 arguments: arguments_json.clone()
856 }
857 });
858 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
859 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
860 ));
861 }
862
863 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
864 StreamingCompletionResponse { usage: final_usage.clone() }
865 ));
866 };
867
868 Ok(crate::streaming::StreamingCompletionResponse::stream(
869 Box::pin(stream),
870 ))
871}
872
/// Model identifier string `"deepseek-chat"`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier string `"deepseek-reasoner"`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
878
// Unit tests for wire-format (de)serialization, message conversion and client
// construction. None of them performs network I/O.
#[cfg(test)]
mod tests {
    use super::*;

    // A bare `choices` array deserializes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    // A minimal response (no `prompt_tokens_details`) deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A fuller response, including keys this module does not model
    // (`id`, `object`, `created`, `model`, `system_fingerprint`), deserializes.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A tool-call choice deserializes into the expected typed value, with the
    // string-encoded `arguments` parsed into JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
        {
            "finish_reason": "tool_calls",
            "index": 0,
            "logprobs": null,
            "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                    {
                        "function": {
                            "arguments": "{\"x\":2,\"y\":5}",
                            "name": "subtract"
                        },
                        "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                        "index": 0,
                        "type": "function"
                    }
                ]
            }
        }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
                reasoning_content: None,
            },
        };

        assert_eq!(choice, expected_choice);
    }

    // Several text items in one user message collapse into a single user
    // message joined with "\n".
    #[test]
    fn test_user_message_multiple_text_items_merged() {
        use crate::completion::message::{Message as RigMessage, UserContent};

        let rig_msg = RigMessage::User {
            content: OneOrMany::many(vec![
                UserContent::text("first part"),
                UserContent::text("second part"),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        let user_messages: Vec<&Message> = messages
            .iter()
            .filter(|m| matches!(m, Message::User { .. }))
            .collect();

        assert_eq!(
            user_messages.len(),
            1,
            "multiple text items should produce a single user message"
        );
        match &user_messages[0] {
            Message::User { content, .. } => {
                assert_eq!(content, "first part\nsecond part");
            }
            _ => unreachable!(),
        }
    }

    // Reasoning, text and a tool call end up in one assistant message.
    #[test]
    fn test_assistant_message_with_reasoning_and_tool_calls() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                AssistantContent::reasoning("thinking about the problem"),
                AssistantContent::text("I'll call the tool"),
                AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 5}),
                ),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        assert_eq!(messages.len(), 1, "should produce exactly one message");
        match &messages[0] {
            Message::Assistant {
                content,
                tool_calls,
                reasoning_content,
                ..
            } => {
                assert_eq!(content, "I'll call the tool");
                assert_eq!(
                    reasoning_content.as_deref(),
                    Some("thinking about the problem")
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].function.name, "subtract");
            }
            _ => panic!("Expected assistant message"),
        }
    }

    // Without any reasoning item, `reasoning_content` is `None` (not `Some("")`).
    #[test]
    fn test_assistant_message_without_reasoning() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                AssistantContent::text("calling tool"),
                AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        assert_eq!(messages.len(), 1);
        match &messages[0] {
            Message::Assistant {
                reasoning_content,
                tool_calls,
                ..
            } => {
                assert!(reasoning_content.is_none());
                assert_eq!(tool_calls.len(), 1);
            }
            _ => panic!("Expected assistant message"),
        }
    }

    // Both construction paths (`Client::new` and the builder) succeed with a
    // dummy key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::deepseek::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::deepseek::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}