1use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
/// Root URL of the DeepSeek HTTP API; exposed as `BASE_URL` by the
/// `ProviderBuilder` implementation for `DeepSeekExtBuilder`.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
/// Zero-sized marker type that carries the DeepSeek implementations of
/// `Provider`, `Capabilities` and `DebugExt`.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExt;
/// Zero-sized builder marker; its `ProviderBuilder` implementation declares
/// `DeepSeekExt` as the output type.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExtBuilder;
46
/// API-key type used by the DeepSeek provider (an alias for `BearerAuth`).
type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50 type Builder = DeepSeekExtBuilder;
51
52 const VERIFY_PATH: &'static str = "/user/balance";
53
54 fn build<H>(
55 _: &crate::client::ClientBuilder<
56 Self::Builder,
57 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58 H,
59 >,
60 ) -> http_client::Result<Self> {
61 Ok(Self)
62 }
63}
64
/// Capability table for the DeepSeek provider: only completions are marked
/// `Capable` (backed by `CompletionModel<H>`); every other capability is
/// declared as `Nothing`.
impl<H> Capabilities<H> for DeepSeekExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // The two capabilities below only exist when the matching cargo feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
75
/// Implements `DebugExt` without overriding anything, so only the trait's
/// default items apply.
impl DebugExt for DeepSeekExt {}
77
/// Builder-side configuration for DeepSeek: the produced provider type, the
/// API-key type, and the default base URL.
impl ProviderBuilder for DeepSeekExtBuilder {
    type Output = DeepSeekExt;
    type ApiKey = DeepSeekApiKey;

    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
}
84
/// DeepSeek client, generic over the HTTP backend `H` (defaults to `reqwest::Client`).
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
/// Client builder alias for DeepSeek, generic over the HTTP backend `H`
/// (defaults to `reqwest::Client`).
// NOTE(review): the API-key type parameter here is `String`, whereas
// `DeepSeekExtBuilder::ApiKey` is `DeepSeekApiKey` — confirm this is intended.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
87
88impl ProviderClient for Client {
89 type Input = DeepSeekApiKey;
90
91 fn from_env() -> Self {
93 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
94 let mut client_builder = Self::builder();
95 client_builder.headers_mut().insert(
96 http::header::CONTENT_TYPE,
97 http::HeaderValue::from_static("application/json"),
98 );
99 let client_builder = client_builder.api_key(&api_key);
100 client_builder.build().unwrap()
101 }
102
103 fn from_val(input: Self::Input) -> Self {
104 Self::new(input).unwrap()
105 }
106}
107
/// Error payload shape: a JSON object carrying a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
112
/// Either a successful payload `T` or an `ApiErrorResponse`.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
119
120impl From<ApiErrorResponse> for CompletionError {
121 fn from(err: ApiErrorResponse) -> Self {
122 CompletionError::ProviderError(err.message)
123 }
124}
125
/// The parts of a chat completion response this module keeps: the returned
/// choices and the token usage. Other JSON keys are not captured.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Choices returned by the API; the conversion into the crate-wide
    /// response type only reads the first one.
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
134
/// Token counters found in the `usage` object of a response.
///
/// The five integer counters are required when deserializing; the two
/// `*_details` objects are optional and are omitted from serialized output
/// when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    pub completion_tokens: u32,
    pub prompt_tokens: u32,
    pub prompt_cache_hit_tokens: u32,
    pub prompt_cache_miss_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
147
148impl Usage {
149 fn new() -> Self {
150 Self {
151 completion_tokens: 0,
152 prompt_tokens: 0,
153 prompt_cache_hit_tokens: 0,
154 prompt_cache_miss_tokens: 0,
155 total_tokens: 0,
156 completion_tokens_details: None,
157 prompt_tokens_details: None,
158 }
159 }
160}
161
/// Optional breakdown attached to `Usage::completion_tokens_details`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
167
/// Optional breakdown attached to `Usage::prompt_tokens_details`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// Read by this module to populate `cached_input_tokens` in the crate-wide
    /// usage type. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
173
/// One entry of the `choices` array in a completion response.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    /// Kept as raw JSON; this module does not interpret it.
    pub logprobs: Option<serde_json::Value>,
    /// Free-form string; the module's tests exercise `"stop"` and `"tool_calls"`.
    pub finish_reason: String,
}
181
/// A chat message in DeepSeek's wire format.
///
/// The variant is selected by the `role` key, using the lowercase variant
/// name (`system`, `user`, `assistant`), except `ToolResult`, which is
/// renamed to `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // A missing key falls back to the default (empty vec); decoding goes
        // through `json_utils::null_or_vec`; an empty vec is not serialized.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
215
216impl Message {
217 pub fn system(content: &str) -> Self {
218 Message::System {
219 content: content.to_owned(),
220 name: None,
221 }
222 }
223}
224
225impl From<message::ToolResult> for Message {
226 fn from(tool_result: message::ToolResult) -> Self {
227 let content = match tool_result.content.first() {
228 message::ToolResultContent::Text(text) => text.text,
229 message::ToolResultContent::Image(_) => String::from("[Image]"),
230 };
231
232 Message::ToolResult {
233 tool_call_id: tool_result.id,
234 content,
235 }
236 }
237}
238
239impl From<message::ToolCall> for ToolCall {
240 fn from(tool_call: message::ToolCall) -> Self {
241 Self {
242 id: tool_call.id,
243 index: 0,
245 r#type: ToolType::Function,
246 function: Function {
247 name: tool_call.function.name,
248 arguments: tool_call.function.arguments,
249 },
250 }
251 }
252}
253
254impl TryFrom<message::Message> for Vec<Message> {
255 type Error = message::MessageError;
256
257 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
258 match message {
259 message::Message::User { content } => {
260 let mut messages = vec![];
262
263 let tool_results = content
264 .clone()
265 .into_iter()
266 .filter_map(|content| match content {
267 message::UserContent::ToolResult(tool_result) => {
268 Some(Message::from(tool_result))
269 }
270 _ => None,
271 })
272 .collect::<Vec<_>>();
273
274 messages.extend(tool_results);
275
276 let text_messages = content
278 .into_iter()
279 .filter_map(|content| match content {
280 message::UserContent::Text(text) => Some(Message::User {
281 content: text.text,
282 name: None,
283 }),
284 message::UserContent::Document(Document {
285 data:
286 DocumentSourceKind::Base64(content)
287 | DocumentSourceKind::String(content),
288 ..
289 }) => Some(Message::User {
290 content,
291 name: None,
292 }),
293 _ => None,
294 })
295 .collect::<Vec<_>>();
296 messages.extend(text_messages);
297
298 Ok(messages)
299 }
300 message::Message::Assistant { content, .. } => {
301 let mut messages: Vec<Message> = vec![];
302 let mut text_content = String::new();
303 let mut reasoning_content = String::new();
304
305 content.iter().for_each(|content| match content {
306 message::AssistantContent::Text(text) => {
307 text_content.push_str(text.text());
308 }
309 message::AssistantContent::Reasoning(reasoning) => {
310 reasoning_content.push_str(&reasoning.display_text());
311 }
312 _ => {}
313 });
314
315 messages.push(Message::Assistant {
316 content: text_content,
317 name: None,
318 tool_calls: vec![],
319 reasoning_content: if reasoning_content.is_empty() {
320 None
321 } else {
322 Some(reasoning_content)
323 },
324 });
325
326 let tool_calls = content
328 .clone()
329 .into_iter()
330 .filter_map(|content| match content {
331 message::AssistantContent::ToolCall(tool_call) => {
332 Some(ToolCall::from(tool_call))
333 }
334 _ => None,
335 })
336 .collect::<Vec<_>>();
337
338 if !tool_calls.is_empty() {
340 messages.push(Message::Assistant {
341 content: "".to_string(),
342 name: None,
343 tool_calls,
344 reasoning_content: None,
345 });
346 }
347
348 Ok(messages)
349 }
350 }
351 }
352}
353
/// A tool invocation as it appears inside an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    pub index: usize,
    /// Falls back to `ToolType::Function` when the key is absent.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
362
/// Function name and arguments of a `ToolCall`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// (De)serialized through `json_utils::stringified_json`; the module's
    /// test feeds it a JSON document encoded as a string.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
369
/// Kind of tool referenced by a `ToolCall`; serialized in lowercase
/// (`"function"`). `Function` is the only variant and also the default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
376
/// Wire wrapper around a crate-wide tool definition: the definition is nested
/// under `function`, next to a `type` string.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
382
383impl From<crate::completion::ToolDefinition> for ToolDefinition {
384 fn from(tool: crate::completion::ToolDefinition) -> Self {
385 Self {
386 r#type: "function".into(),
387 function: tool,
388 }
389 }
390}
391
392impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
393 type Error = CompletionError;
394
395 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
396 let choice = response.choices.first().ok_or_else(|| {
397 CompletionError::ResponseError("Response contained no choices".to_owned())
398 })?;
399 let content = match &choice.message {
400 Message::Assistant {
401 content,
402 tool_calls,
403 reasoning_content,
404 ..
405 } => {
406 let mut content = if content.trim().is_empty() {
407 vec![]
408 } else {
409 vec![completion::AssistantContent::text(content)]
410 };
411
412 content.extend(
413 tool_calls
414 .iter()
415 .map(|call| {
416 completion::AssistantContent::tool_call(
417 &call.id,
418 &call.function.name,
419 call.function.arguments.clone(),
420 )
421 })
422 .collect::<Vec<_>>(),
423 );
424
425 if let Some(reasoning_content) = reasoning_content {
426 content.push(completion::AssistantContent::reasoning(reasoning_content));
427 }
428
429 Ok(content)
430 }
431 _ => Err(CompletionError::ResponseError(
432 "Response did not contain a valid message or tool call".into(),
433 )),
434 }?;
435
436 let choice = OneOrMany::many(content).map_err(|_| {
437 CompletionError::ResponseError(
438 "Response contained no message or tool call (empty)".to_owned(),
439 )
440 })?;
441
442 let usage = completion::Usage {
443 input_tokens: response.usage.prompt_tokens as u64,
444 output_tokens: response.usage.completion_tokens as u64,
445 total_tokens: response.usage.total_tokens as u64,
446 cached_input_tokens: response
447 .usage
448 .prompt_tokens_details
449 .as_ref()
450 .and_then(|d| d.cached_tokens)
451 .map(|c| c as u64)
452 .unwrap_or(0),
453 };
454
455 Ok(completion::CompletionResponse {
456 choice,
457 usage,
458 raw_response: response,
459 message_id: None,
460 })
461 }
462}
463
/// JSON body sent to the `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct DeepseekCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    /// Flattened into the top level of the body; the streaming path merges
    /// its `stream` options in here.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
477
478impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
479 type Error = CompletionError;
480
481 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
482 if req.output_schema.is_some() {
483 tracing::warn!("Structured outputs currently not supported for DeepSeek");
484 }
485 let model = req.model.clone().unwrap_or_else(|| model.to_string());
486 let mut full_history: Vec<Message> = match &req.preamble {
487 Some(preamble) => vec![Message::system(preamble)],
488 None => vec![],
489 };
490
491 if let Some(docs) = req.normalized_documents() {
492 let docs: Vec<Message> = docs.try_into()?;
493 full_history.extend(docs);
494 }
495
496 let chat_history: Vec<Message> = req
497 .chat_history
498 .clone()
499 .into_iter()
500 .map(|message| message.try_into())
501 .collect::<Result<Vec<Vec<Message>>, _>>()?
502 .into_iter()
503 .flatten()
504 .collect();
505
506 let mut last_reasoning_content = None;
507 let mut last_assistant_idx = None;
508
509 for message in chat_history {
510 if let Message::Assistant {
511 reasoning_content, ..
512 } = &message
513 {
514 if let Some(content) = reasoning_content {
515 last_reasoning_content = Some(content.clone());
516 } else {
517 last_assistant_idx = Some(full_history.len());
518 full_history.push(message);
519 }
520 } else {
521 full_history.push(message);
522 }
523 }
524
525 if let (Some(idx), Some(reasoning)) = (last_assistant_idx, last_reasoning_content)
528 && let Message::Assistant {
529 ref mut reasoning_content,
530 ..
531 } = full_history[idx]
532 {
533 *reasoning_content = Some(reasoning);
534 }
535
536 let tool_choice = req
537 .tool_choice
538 .clone()
539 .map(crate::providers::openrouter::ToolChoice::try_from)
540 .transpose()?;
541
542 Ok(Self {
543 model: model.to_string(),
544 messages: full_history,
545 temperature: req.temperature,
546 tools: req
547 .tools
548 .clone()
549 .into_iter()
550 .map(ToolDefinition::from)
551 .collect::<Vec<_>>(),
552 tool_choice,
553 additional_params: req.additional_params,
554 })
555 }
556}
557
/// A DeepSeek completion model: a client plus the model name to send with
/// each request. `T` is the HTTP backend (defaults to `reqwest::Client`).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub client: Client<T>,
    pub model: String,
}
564
565impl<T> completion::CompletionModel for CompletionModel<T>
566where
567 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
568{
569 type Response = CompletionResponse;
570 type StreamingResponse = StreamingCompletionResponse;
571
572 type Client = Client<T>;
573
574 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
575 Self {
576 client: client.clone(),
577 model: model.into().to_string(),
578 }
579 }
580
581 async fn completion(
582 &self,
583 completion_request: CompletionRequest,
584 ) -> Result<
585 completion::CompletionResponse<CompletionResponse>,
586 crate::completion::CompletionError,
587 > {
588 let span = if tracing::Span::current().is_disabled() {
589 info_span!(
590 target: "rig::completions",
591 "chat",
592 gen_ai.operation.name = "chat",
593 gen_ai.provider.name = "deepseek",
594 gen_ai.request.model = self.model,
595 gen_ai.system_instructions = tracing::field::Empty,
596 gen_ai.response.id = tracing::field::Empty,
597 gen_ai.response.model = tracing::field::Empty,
598 gen_ai.usage.output_tokens = tracing::field::Empty,
599 gen_ai.usage.input_tokens = tracing::field::Empty,
600 )
601 } else {
602 tracing::Span::current()
603 };
604
605 span.record("gen_ai.system_instructions", &completion_request.preamble);
606
607 let request =
608 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
609
610 if enabled!(Level::TRACE) {
611 tracing::trace!(target: "rig::completions",
612 "DeepSeek completion request: {}",
613 serde_json::to_string_pretty(&request)?
614 );
615 }
616
617 let body = serde_json::to_vec(&request)?;
618 let req = self
619 .client
620 .post("/chat/completions")?
621 .body(body)
622 .map_err(|e| CompletionError::HttpError(e.into()))?;
623
624 async move {
625 let response = self.client.send::<_, Bytes>(req).await?;
626 let status = response.status();
627 let response_body = response.into_body().into_future().await?.to_vec();
628
629 if status.is_success() {
630 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
631 ApiResponse::Ok(response) => {
632 let span = tracing::Span::current();
633 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
634 span.record(
635 "gen_ai.usage.output_tokens",
636 response.usage.completion_tokens,
637 );
638 if enabled!(Level::TRACE) {
639 tracing::trace!(target: "rig::completions",
640 "DeepSeek completion response: {}",
641 serde_json::to_string_pretty(&response)?
642 );
643 }
644 response.try_into()
645 }
646 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
647 }
648 } else {
649 Err(CompletionError::ProviderError(
650 String::from_utf8_lossy(&response_body).to_string(),
651 ))
652 }
653 }
654 .instrument(span)
655 .await
656 }
657
658 async fn stream(
659 &self,
660 completion_request: CompletionRequest,
661 ) -> Result<
662 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
663 CompletionError,
664 > {
665 let preamble = completion_request.preamble.clone();
666 let mut request =
667 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
668
669 let params = json_utils::merge(
670 request.additional_params.unwrap_or(serde_json::json!({})),
671 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
672 );
673
674 request.additional_params = Some(params);
675
676 if enabled!(Level::TRACE) {
677 tracing::trace!(target: "rig::completions",
678 "DeepSeek streaming completion request: {}",
679 serde_json::to_string_pretty(&request)?
680 );
681 }
682
683 let body = serde_json::to_vec(&request)?;
684
685 let req = self
686 .client
687 .post("/chat/completions")?
688 .body(body)
689 .map_err(|e| CompletionError::HttpError(e.into()))?;
690
691 let span = if tracing::Span::current().is_disabled() {
692 info_span!(
693 target: "rig::completions",
694 "chat_streaming",
695 gen_ai.operation.name = "chat_streaming",
696 gen_ai.provider.name = "deepseek",
697 gen_ai.request.model = self.model,
698 gen_ai.system_instructions = preamble,
699 gen_ai.response.id = tracing::field::Empty,
700 gen_ai.response.model = tracing::field::Empty,
701 gen_ai.usage.output_tokens = tracing::field::Empty,
702 gen_ai.usage.input_tokens = tracing::field::Empty,
703 )
704 } else {
705 tracing::Span::current()
706 };
707
708 tracing::Instrument::instrument(
709 send_compatible_streaming_request(self.client.clone(), req),
710 span,
711 )
712 .await
713 }
714}
715
/// Incremental payload of one streamed choice: optional text, zero or more
/// tool-call fragments, and optional reasoning text.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    // A missing key falls back to the default (empty vec); decoding goes
    // through `json_utils::null_or_vec`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    reasoning_content: Option<String>,
}
724
/// One entry of the `choices` array in a streamed chunk; only its `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
729
/// A single SSE data payload of a streamed completion: the choices, plus the
/// token usage when the chunk carries it.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
735
/// Final item of a streamed completion: the last usage object seen on the
/// stream, or an all-zero `Usage` if none was received.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
740
741impl GetTokenUsage for StreamingCompletionResponse {
742 fn token_usage(&self) -> Option<crate::completion::Usage> {
743 let mut usage = crate::completion::Usage::new();
744 usage.input_tokens = self.usage.prompt_tokens as u64;
745 usage.output_tokens = self.usage.completion_tokens as u64;
746 usage.total_tokens = self.usage.total_tokens as u64;
747 usage.cached_input_tokens = self
748 .usage
749 .prompt_tokens_details
750 .as_ref()
751 .and_then(|d| d.cached_tokens)
752 .map(|c| c as u64)
753 .unwrap_or(0);
754
755 Some(usage)
756 }
757}
758
759pub async fn send_compatible_streaming_request<T>(
760 http_client: T,
761 req: Request<Vec<u8>>,
762) -> Result<
763 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
764 CompletionError,
765>
766where
767 T: HttpClientExt + Clone + 'static,
768{
769 let mut event_source = GenericEventSource::new(http_client, req);
770
771 let stream = stream! {
772 let mut final_usage = Usage::new();
773 let mut text_response = String::new();
774 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
775
776 while let Some(event_result) = event_source.next().await {
777 match event_result {
778 Ok(Event::Open) => {
779 tracing::trace!("SSE connection opened");
780 continue;
781 }
782 Ok(Event::Message(message)) => {
783 if message.data.trim().is_empty() || message.data == "[DONE]" {
784 continue;
785 }
786
787 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
788 let Ok(data) = parsed else {
789 let err = parsed.unwrap_err();
790 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
791 continue;
792 };
793
794 if let Some(choice) = data.choices.first() {
795 let delta = &choice.delta;
796
797 if !delta.tool_calls.is_empty() {
798 for tool_call in &delta.tool_calls {
799 let function = &tool_call.function;
800
801 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
803 && empty_or_none(&function.arguments)
804 {
805 let id = tool_call.id.clone().unwrap_or_default();
806 let name = function.name.clone().unwrap();
807 calls.insert(tool_call.index, (id, name, String::new()));
808 }
809 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
811 && let Some(arguments) = &function.arguments
812 && !arguments.is_empty()
813 {
814 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
815 let combined = format!("{}{}", existing_args, arguments);
816 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
817 } else {
818 tracing::debug!("Partial tool call received but tool call was never started.");
819 }
820 }
821 else {
823 let id = tool_call.id.clone().unwrap_or_default();
824 let name = function.name.clone().unwrap_or_default();
825 let arguments_str = function.arguments.clone().unwrap_or_default();
826
827 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
828 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
829 continue;
830 };
831
832 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
833 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
834 ));
835 }
836 }
837 }
838
839 if let Some(content) = &delta.reasoning_content {
841 yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
842 id: None,
843 reasoning: content.to_string()
844 });
845 }
846
847 if let Some(content) = &delta.content {
848 text_response += content;
849 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
850 }
851 }
852
853 if let Some(usage) = data.usage {
854 final_usage = usage.clone();
855 }
856 }
857 Err(crate::http_client::Error::StreamEnded) => {
858 break;
859 }
860 Err(err) => {
861 tracing::error!(?err, "SSE error");
862 yield Err(CompletionError::ResponseError(err.to_string()));
863 break;
864 }
865 }
866 }
867
868 event_source.close();
869
870 let mut tool_calls = Vec::new();
871 for (index, (id, name, arguments)) in calls {
873 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
874 continue;
875 };
876
877 tool_calls.push(ToolCall {
878 id: id.clone(),
879 index,
880 r#type: ToolType::Function,
881 function: Function {
882 name: name.clone(),
883 arguments: arguments_json.clone()
884 }
885 });
886 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
887 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
888 ));
889 }
890
891 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
892 StreamingCompletionResponse { usage: final_usage.clone() }
893 ));
894 };
895
896 Ok(crate::streaming::StreamingCompletionResponse::stream(
897 Box::pin(stream),
898 ))
899}
900
/// Model identifier string `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier string `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
906
// Unit tests for (de)serialization of the DeepSeek wire types and for client construction.
#[cfg(test)]
mod tests {
    use super::*;

    // A bare JSON array of choices deserializes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
 "finish_reason": "stop",
 "index": 0,
 "logprobs": null,
 "message":{"role":"assistant","content":"Hello, world!"}
 }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    // A minimal response (choices + usage without detail objects) deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
 "choices":[{
 "finish_reason": "stop",
 "index": 0,
 "logprobs": null,
 "message":{"role":"assistant","content":"Hello, world!"}
 }],
 "usage": {
 "completion_tokens": 0,
 "prompt_tokens": 0,
 "prompt_cache_hit_tokens": 0,
 "prompt_cache_miss_tokens": 0,
 "total_tokens": 0
 }
 }"#;

        // `serde_path_to_error` reports the JSON path of a failure, which the panic message prints.
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A fuller response with extra top-level keys and `prompt_tokens_details` deserializes;
    // keys with no matching struct field are ignored.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
 {
 "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
 "object": "chat.completion",
 "created": 0,
 "model": "deepseek-chat",
 "choices": [
 {
 "index": 0,
 "message": {
 "role": "assistant",
 "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
 },
 "logprobs": null,
 "finish_reason": "stop"
 }
 ],
 "usage": {
 "prompt_tokens": 13,
 "completion_tokens": 32,
 "total_tokens": 45,
 "prompt_tokens_details": {
 "cached_tokens": 0
 },
 "prompt_cache_hit_tokens": 0,
 "prompt_cache_miss_tokens": 13
 },
 "system_fingerprint": "fp_4b6881f2c5"
 }
 "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A tool-call choice deserializes to the expected value, including the
    // string-encoded `arguments` being decoded into JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
 {
 "finish_reason": "tool_calls",
 "index": 0,
 "logprobs": null,
 "message": {
 "content": "",
 "role": "assistant",
 "tool_calls": [
 {
 "function": {
 "arguments": "{\"x\":2,\"y\":5}",
 "name": "subtract"
 },
 "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
 "index": 0,
 "type": "function"
 }
 ]
 }
 }
 "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
                reasoning_content: None,
            },
        };

        assert_eq!(choice, expected_choice);
    }
    // Both construction paths (`Client::new` and the builder) succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client: crate::providers::deepseek::Client =
            crate::providers::deepseek::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: crate::providers::deepseek::Client =
            crate::providers::deepseek::Client::builder()
                .api_key("dummy-key")
                .build()
                .expect("Client::builder() failed");
    }
}