1use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
/// Base URL of the DeepSeek HTTP API; used as the default `BASE_URL` of the provider builder.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
/// Marker type identifying the DeepSeek provider. It is a unit struct and carries no state.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExt;
/// Marker type for the builder side of the DeepSeek provider. It is a unit struct and carries no
/// state.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExtBuilder;
46
/// API-key type used by the DeepSeek client: an alias for [`BearerAuth`].
type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50 type Builder = DeepSeekExtBuilder;
51
52 const VERIFY_PATH: &'static str = "/user/balance";
53
54 fn build<H>(
55 _: &crate::client::ClientBuilder<
56 Self::Builder,
57 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58 H,
59 >,
60 ) -> http_client::Result<Self> {
61 Ok(Self)
62 }
63}
64
/// Declares which capabilities the DeepSeek provider exposes: only completions are `Capable`
/// (through [`CompletionModel`]); every other capability is `Nothing`.
impl<H> Capabilities<H> for DeepSeekExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    // The two capabilities below only exist when the matching crate feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
74
// Empty impl: `DeepSeekExt` overrides nothing and relies on whatever `DebugExt` provides by default.
impl DebugExt for DeepSeekExt {}
76
/// Builder configuration for DeepSeek: produces a [`DeepSeekExt`], authenticates with a
/// [`DeepSeekApiKey`], and targets [`DEEPSEEK_API_BASE_URL`] by default.
impl ProviderBuilder for DeepSeekExtBuilder {
    type Output = DeepSeekExt;
    type ApiKey = DeepSeekApiKey;

    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
}
83
/// DeepSeek client, generic over the HTTP backend `H` (defaults to `reqwest::Client`).
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
/// Builder for the DeepSeek [`Client`], generic over the HTTP backend `H` (defaults to
/// `reqwest::Client`).
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88 type Input = DeepSeekApiKey;
89
90 fn from_env() -> Self {
92 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93 let mut client_builder = Self::builder();
94 client_builder.headers_mut().insert(
95 http::header::CONTENT_TYPE,
96 http::HeaderValue::from_static("application/json"),
97 );
98 let client_builder = client_builder.api_key(&api_key);
99 client_builder.build().unwrap()
100 }
101
102 fn from_val(input: Self::Input) -> Self {
103 Self::new(input).unwrap()
104 }
105}
106
/// Error payload returned by the API; only its `message` field is deserialized.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
111
/// Response body that is either the expected payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde tries the variants in declaration order: `Ok(T)` first,
/// then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
118
119impl From<ApiErrorResponse> for CompletionError {
120 fn from(err: ApiErrorResponse) -> Self {
121 CompletionError::ProviderError(err.message)
122 }
123}
124
/// The parts of a DeepSeek chat-completion response body that this module reads.
///
/// Other keys present in the JSON (for example `id` or `model`) are not captured.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Generated choices; the conversion into the generic response uses the first one.
    pub choices: Vec<Choice>,
    /// Token accounting reported for the request.
    pub usage: Usage,
}
133
/// Token accounting block (`usage`) of a DeepSeek response.
///
/// Field names mirror the JSON keys. The five counters are required when deserializing; the two
/// `*_details` objects are optional and are omitted from serialized output when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    pub completion_tokens: u32,
    pub prompt_tokens: u32,
    pub prompt_cache_hit_tokens: u32,
    pub prompt_cache_miss_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
146
147impl Usage {
148 fn new() -> Self {
149 Self {
150 completion_tokens: 0,
151 prompt_tokens: 0,
152 prompt_cache_hit_tokens: 0,
153 prompt_cache_miss_tokens: 0,
154 total_tokens: 0,
155 completion_tokens_details: None,
156 prompt_tokens_details: None,
157 }
158 }
159}
160
/// Optional `completion_tokens_details` object of the usage block.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
166
/// Optional `prompt_tokens_details` object of the usage block.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// Source of the generic `cached_input_tokens` figure; omitted from serialized output when
    /// `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
172
/// One entry of the response's `choices` array.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    /// Kept as raw JSON; this module does not interpret it.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
180
/// A chat message in DeepSeek's wire format.
///
/// The enum is internally tagged by the `role` key. Variant names are lowercased for the tag,
/// except [`Message::ToolResult`], which is serialized with the role `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // A missing key yields an empty list (`default`); decoding otherwise goes through
        // `json_utils::null_or_vec`. An empty list is not serialized at all.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        // Reasoning text attached to the assistant message; left out of the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
214
215impl Message {
216 pub fn system(content: &str) -> Self {
217 Message::System {
218 content: content.to_owned(),
219 name: None,
220 }
221 }
222}
223
224impl From<message::ToolResult> for Message {
225 fn from(tool_result: message::ToolResult) -> Self {
226 let content = match tool_result.content.first() {
227 message::ToolResultContent::Text(text) => text.text,
228 message::ToolResultContent::Image(_) => String::from("[Image]"),
229 };
230
231 Message::ToolResult {
232 tool_call_id: tool_result.id,
233 content,
234 }
235 }
236}
237
238impl From<message::ToolCall> for ToolCall {
239 fn from(tool_call: message::ToolCall) -> Self {
240 Self {
241 id: tool_call.id,
242 index: 0,
244 r#type: ToolType::Function,
245 function: Function {
246 name: tool_call.function.name,
247 arguments: tool_call.function.arguments,
248 },
249 }
250 }
251}
252
253impl TryFrom<message::Message> for Vec<Message> {
254 type Error = message::MessageError;
255
256 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
257 match message {
258 message::Message::User { content } => {
259 let mut messages = vec![];
261
262 let tool_results = content
263 .clone()
264 .into_iter()
265 .filter_map(|content| match content {
266 message::UserContent::ToolResult(tool_result) => {
267 Some(Message::from(tool_result))
268 }
269 _ => None,
270 })
271 .collect::<Vec<_>>();
272
273 messages.extend(tool_results);
274
275 let text_messages = content
277 .into_iter()
278 .filter_map(|content| match content {
279 message::UserContent::Text(text) => Some(Message::User {
280 content: text.text,
281 name: None,
282 }),
283 message::UserContent::Document(Document {
284 data:
285 DocumentSourceKind::Base64(content)
286 | DocumentSourceKind::String(content),
287 ..
288 }) => Some(Message::User {
289 content,
290 name: None,
291 }),
292 _ => None,
293 })
294 .collect::<Vec<_>>();
295 messages.extend(text_messages);
296
297 Ok(messages)
298 }
299 message::Message::Assistant { content, .. } => {
300 let mut messages: Vec<Message> = vec![];
301 let mut text_content = String::new();
302 let mut reasoning_content = String::new();
303
304 content.iter().for_each(|content| match content {
305 message::AssistantContent::Text(text) => {
306 text_content.push_str(text.text());
307 }
308 message::AssistantContent::Reasoning(reasoning) => {
309 reasoning_content.push_str(&reasoning.reasoning.join("\n"));
310 }
311 _ => {}
312 });
313
314 messages.push(Message::Assistant {
315 content: text_content,
316 name: None,
317 tool_calls: vec![],
318 reasoning_content: if reasoning_content.is_empty() {
319 None
320 } else {
321 Some(reasoning_content)
322 },
323 });
324
325 let tool_calls = content
327 .clone()
328 .into_iter()
329 .filter_map(|content| match content {
330 message::AssistantContent::ToolCall(tool_call) => {
331 Some(ToolCall::from(tool_call))
332 }
333 _ => None,
334 })
335 .collect::<Vec<_>>();
336
337 if !tool_calls.is_empty() {
339 messages.push(Message::Assistant {
340 content: "".to_string(),
341 name: None,
342 tool_calls,
343 reasoning_content: None,
344 });
345 }
346
347 Ok(messages)
348 }
349 }
350 }
351}
352
/// A tool call in DeepSeek's wire format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    pub index: usize,
    /// Falls back to [`ToolType::Function`] when the key is missing.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
361
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// Held as a JSON value in memory; (de)serialized through `json_utils::stringified_json`,
    /// so on the wire it is a string containing JSON (e.g. `"{\"x\":2,\"y\":5}"`).
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
368
/// Kind of a tool call; serialized in lowercase (`"function"`). `Function` is the only variant
/// and also the default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
375
/// Tool definition as sent in the request's `tools` array: a generic
/// [`completion::ToolDefinition`] wrapped together with a `type` string.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
381
382impl From<crate::completion::ToolDefinition> for ToolDefinition {
383 fn from(tool: crate::completion::ToolDefinition) -> Self {
384 Self {
385 r#type: "function".into(),
386 function: tool,
387 }
388 }
389}
390
391impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
392 type Error = CompletionError;
393
394 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
395 let choice = response.choices.first().ok_or_else(|| {
396 CompletionError::ResponseError("Response contained no choices".to_owned())
397 })?;
398 let content = match &choice.message {
399 Message::Assistant {
400 content,
401 tool_calls,
402 reasoning_content,
403 ..
404 } => {
405 let mut content = if content.trim().is_empty() {
406 vec![]
407 } else {
408 vec![completion::AssistantContent::text(content)]
409 };
410
411 content.extend(
412 tool_calls
413 .iter()
414 .map(|call| {
415 completion::AssistantContent::tool_call(
416 &call.id,
417 &call.function.name,
418 call.function.arguments.clone(),
419 )
420 })
421 .collect::<Vec<_>>(),
422 );
423
424 if let Some(reasoning_content) = reasoning_content {
425 content.push(completion::AssistantContent::reasoning(reasoning_content));
426 }
427
428 Ok(content)
429 }
430 _ => Err(CompletionError::ResponseError(
431 "Response did not contain a valid message or tool call".into(),
432 )),
433 }?;
434
435 let choice = OneOrMany::many(content).map_err(|_| {
436 CompletionError::ResponseError(
437 "Response contained no message or tool call (empty)".to_owned(),
438 )
439 })?;
440
441 let usage = completion::Usage {
442 input_tokens: response.usage.prompt_tokens as u64,
443 output_tokens: response.usage.completion_tokens as u64,
444 total_tokens: response.usage.total_tokens as u64,
445 cached_input_tokens: response
446 .usage
447 .prompt_tokens_details
448 .as_ref()
449 .and_then(|d| d.cached_tokens)
450 .map(|c| c as u64)
451 .unwrap_or(0),
452 };
453
454 Ok(completion::CompletionResponse {
455 choice,
456 usage,
457 raw_response: response,
458 })
459 }
460}
461
/// Request body for the `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct DeepseekCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Left out of the JSON when no tools are defined.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    /// Extra parameters; `flatten` merges their keys into the top level of the request object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
475
476impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
477 type Error = CompletionError;
478
479 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
480 let mut full_history: Vec<Message> = match &req.preamble {
481 Some(preamble) => vec![Message::system(preamble)],
482 None => vec![],
483 };
484
485 if let Some(docs) = req.normalized_documents() {
486 let docs: Vec<Message> = docs.try_into()?;
487 full_history.extend(docs);
488 }
489
490 let chat_history: Vec<Message> = req
491 .chat_history
492 .clone()
493 .into_iter()
494 .map(|message| message.try_into())
495 .collect::<Result<Vec<Vec<Message>>, _>>()?
496 .into_iter()
497 .flatten()
498 .collect();
499
500 let mut last_reasoning_content = None;
501 let mut last_assistant_idx = None;
502
503 for message in chat_history {
504 if let Message::Assistant {
505 reasoning_content, ..
506 } = &message
507 {
508 if let Some(content) = reasoning_content {
509 last_reasoning_content = Some(content.clone());
510 } else {
511 last_assistant_idx = Some(full_history.len());
512 full_history.push(message);
513 }
514 } else {
515 full_history.push(message);
516 }
517 }
518
519 if let (Some(idx), Some(reasoning)) = (last_assistant_idx, last_reasoning_content)
522 && let Message::Assistant {
523 ref mut reasoning_content,
524 ..
525 } = full_history[idx]
526 {
527 *reasoning_content = Some(reasoning);
528 }
529
530 let tool_choice = req
531 .tool_choice
532 .clone()
533 .map(crate::providers::openrouter::ToolChoice::try_from)
534 .transpose()?;
535
536 Ok(Self {
537 model: model.to_string(),
538 messages: full_history,
539 temperature: req.temperature,
540 tools: req
541 .tools
542 .clone()
543 .into_iter()
544 .map(ToolDefinition::from)
545 .collect::<Vec<_>>(),
546 tool_choice,
547 additional_params: req.additional_params,
548 })
549 }
550}
551
/// A DeepSeek completion model: a client handle plus the name of the model to call.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub client: Client<T>,
    /// Model name sent in the request body (e.g. [`DEEPSEEK_CHAT`]).
    pub model: String,
}
558
559impl<T> completion::CompletionModel for CompletionModel<T>
560where
561 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
562{
563 type Response = CompletionResponse;
564 type StreamingResponse = StreamingCompletionResponse;
565
566 type Client = Client<T>;
567
568 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
569 Self {
570 client: client.clone(),
571 model: model.into().to_string(),
572 }
573 }
574
575 async fn completion(
576 &self,
577 completion_request: CompletionRequest,
578 ) -> Result<
579 completion::CompletionResponse<CompletionResponse>,
580 crate::completion::CompletionError,
581 > {
582 let span = if tracing::Span::current().is_disabled() {
583 info_span!(
584 target: "rig::completions",
585 "chat",
586 gen_ai.operation.name = "chat",
587 gen_ai.provider.name = "deepseek",
588 gen_ai.request.model = self.model,
589 gen_ai.system_instructions = tracing::field::Empty,
590 gen_ai.response.id = tracing::field::Empty,
591 gen_ai.response.model = tracing::field::Empty,
592 gen_ai.usage.output_tokens = tracing::field::Empty,
593 gen_ai.usage.input_tokens = tracing::field::Empty,
594 )
595 } else {
596 tracing::Span::current()
597 };
598
599 span.record("gen_ai.system_instructions", &completion_request.preamble);
600
601 let request =
602 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
603
604 if enabled!(Level::TRACE) {
605 tracing::trace!(target: "rig::completions",
606 "DeepSeek completion request: {}",
607 serde_json::to_string_pretty(&request)?
608 );
609 }
610
611 let body = serde_json::to_vec(&request)?;
612 let req = self
613 .client
614 .post("/chat/completions")?
615 .body(body)
616 .map_err(|e| CompletionError::HttpError(e.into()))?;
617
618 async move {
619 let response = self.client.send::<_, Bytes>(req).await?;
620 let status = response.status();
621 let response_body = response.into_body().into_future().await?.to_vec();
622
623 if status.is_success() {
624 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
625 ApiResponse::Ok(response) => {
626 let span = tracing::Span::current();
627 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
628 span.record(
629 "gen_ai.usage.output_tokens",
630 response.usage.completion_tokens,
631 );
632 if enabled!(Level::TRACE) {
633 tracing::trace!(target: "rig::completions",
634 "DeepSeek completion response: {}",
635 serde_json::to_string_pretty(&response)?
636 );
637 }
638 response.try_into()
639 }
640 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
641 }
642 } else {
643 Err(CompletionError::ProviderError(
644 String::from_utf8_lossy(&response_body).to_string(),
645 ))
646 }
647 }
648 .instrument(span)
649 .await
650 }
651
652 async fn stream(
653 &self,
654 completion_request: CompletionRequest,
655 ) -> Result<
656 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
657 CompletionError,
658 > {
659 let preamble = completion_request.preamble.clone();
660 let mut request =
661 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
662
663 let params = json_utils::merge(
664 request.additional_params.unwrap_or(serde_json::json!({})),
665 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
666 );
667
668 request.additional_params = Some(params);
669
670 if enabled!(Level::TRACE) {
671 tracing::trace!(target: "rig::completions",
672 "DeepSeek streaming completion request: {}",
673 serde_json::to_string_pretty(&request)?
674 );
675 }
676
677 let body = serde_json::to_vec(&request)?;
678
679 let req = self
680 .client
681 .post("/chat/completions")?
682 .body(body)
683 .map_err(|e| CompletionError::HttpError(e.into()))?;
684
685 let span = if tracing::Span::current().is_disabled() {
686 info_span!(
687 target: "rig::completions",
688 "chat_streaming",
689 gen_ai.operation.name = "chat_streaming",
690 gen_ai.provider.name = "deepseek",
691 gen_ai.request.model = self.model,
692 gen_ai.system_instructions = preamble,
693 gen_ai.response.id = tracing::field::Empty,
694 gen_ai.response.model = tracing::field::Empty,
695 gen_ai.usage.output_tokens = tracing::field::Empty,
696 gen_ai.usage.input_tokens = tracing::field::Empty,
697 )
698 } else {
699 tracing::Span::current()
700 };
701
702 tracing::Instrument::instrument(
703 send_compatible_streaming_request(self.client.clone(), req),
704 span,
705 )
706 .await
707 }
708}
709
/// Incremental payload (`delta`) of one streamed choice.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    /// Text fragment, if the chunk carries one.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments; a missing key yields an empty list, and decoding otherwise goes
    /// through `json_utils::null_or_vec`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    /// Reasoning fragment, if the chunk carries one.
    reasoning_content: Option<String>,
}
718
/// One entry of a streamed chunk's `choices` array; only its `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
723
/// One server-sent event payload of a streaming completion.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    /// Usage block, when the chunk includes one.
    usage: Option<Usage>,
}
729
/// Final item of a streaming completion: the last usage block seen on the stream, or all zeros
/// if none was received.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
734
735impl GetTokenUsage for StreamingCompletionResponse {
736 fn token_usage(&self) -> Option<crate::completion::Usage> {
737 let mut usage = crate::completion::Usage::new();
738 usage.input_tokens = self.usage.prompt_tokens as u64;
739 usage.output_tokens = self.usage.completion_tokens as u64;
740 usage.total_tokens = self.usage.total_tokens as u64;
741 usage.cached_input_tokens = self
742 .usage
743 .prompt_tokens_details
744 .as_ref()
745 .and_then(|d| d.cached_tokens)
746 .map(|c| c as u64)
747 .unwrap_or(0);
748
749 Some(usage)
750 }
751}
752
753pub async fn send_compatible_streaming_request<T>(
754 http_client: T,
755 req: Request<Vec<u8>>,
756) -> Result<
757 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
758 CompletionError,
759>
760where
761 T: HttpClientExt + Clone + 'static,
762{
763 let mut event_source = GenericEventSource::new(http_client, req);
764
765 let stream = stream! {
766 let mut final_usage = Usage::new();
767 let mut text_response = String::new();
768 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
769
770 while let Some(event_result) = event_source.next().await {
771 match event_result {
772 Ok(Event::Open) => {
773 tracing::trace!("SSE connection opened");
774 continue;
775 }
776 Ok(Event::Message(message)) => {
777 if message.data.trim().is_empty() || message.data == "[DONE]" {
778 continue;
779 }
780
781 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
782 let Ok(data) = parsed else {
783 let err = parsed.unwrap_err();
784 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
785 continue;
786 };
787
788 if let Some(choice) = data.choices.first() {
789 let delta = &choice.delta;
790
791 if !delta.tool_calls.is_empty() {
792 for tool_call in &delta.tool_calls {
793 let function = &tool_call.function;
794
795 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
797 && empty_or_none(&function.arguments)
798 {
799 let id = tool_call.id.clone().unwrap_or_default();
800 let name = function.name.clone().unwrap();
801 calls.insert(tool_call.index, (id, name, String::new()));
802 }
803 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
805 && let Some(arguments) = &function.arguments
806 && !arguments.is_empty()
807 {
808 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
809 let combined = format!("{}{}", existing_args, arguments);
810 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
811 } else {
812 tracing::debug!("Partial tool call received but tool call was never started.");
813 }
814 }
815 else {
817 let id = tool_call.id.clone().unwrap_or_default();
818 let name = function.name.clone().unwrap_or_default();
819 let arguments_str = function.arguments.clone().unwrap_or_default();
820
821 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
822 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
823 continue;
824 };
825
826 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
827 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
828 ));
829 }
830 }
831 }
832
833 if let Some(content) = &delta.reasoning_content {
835 yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
836 id: None,
837 reasoning: content.to_string()
838 });
839 }
840
841 if let Some(content) = &delta.content {
842 text_response += content;
843 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
844 }
845 }
846
847 if let Some(usage) = data.usage {
848 final_usage = usage.clone();
849 }
850 }
851 Err(crate::http_client::Error::StreamEnded) => {
852 break;
853 }
854 Err(err) => {
855 tracing::error!(?err, "SSE error");
856 yield Err(CompletionError::ResponseError(err.to_string()));
857 break;
858 }
859 }
860 }
861
862 event_source.close();
863
864 let mut tool_calls = Vec::new();
865 for (index, (id, name, arguments)) in calls {
867 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
868 continue;
869 };
870
871 tool_calls.push(ToolCall {
872 id: id.clone(),
873 index,
874 r#type: ToolType::Function,
875 function: Function {
876 name: name.clone(),
877 arguments: arguments_json.clone()
878 }
879 });
880 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
881 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
882 ));
883 }
884
885 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
886 StreamingCompletionResponse { usage: final_usage.clone() }
887 ));
888 };
889
890 Ok(crate::streaming::StreamingCompletionResponse::stream(
891 Box::pin(stream),
892 ))
893}
894
/// Model name `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
900
// Deserialization tests for the DeepSeek wire types.
#[cfg(test)]
mod tests {

    use super::*;

    // A bare `choices` array with a single assistant message deserializes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    // A minimal response (choices + usage without the optional detail objects) deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        // `serde_path_to_error` reports the JSON path of a failure, which makes the panic useful.
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A full example response, including keys this module does not model (`id`, `object`,
    // `created`, `model`, `system_fingerprint`), still deserializes.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A tool-call choice deserializes into the expected `Choice`, with the stringified
    // `arguments` decoded into a JSON value.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
        {
            "finish_reason": "tool_calls",
            "index": 0,
            "logprobs": null,
            "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                    {
                        "function": {
                            "arguments": "{\"x\":2,\"y\":5}",
                            "name": "subtract"
                        },
                        "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                        "index": 0,
                        "type": "function"
                    }
                ]
            }
        }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
                reasoning_content: None,
            },
        };

        assert_eq!(choice, expected_choice);
    }
}