1use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
/// Default base URL for the DeepSeek HTTP API.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";

/// Provider extension marker for DeepSeek; it holds no state of its own.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExt;
/// Builder marker whose output is [`DeepSeekExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExtBuilder;

/// API-key type used by the DeepSeek builder: the crate's shared `BearerAuth`.
type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50 type Builder = DeepSeekExtBuilder;
51
52 const VERIFY_PATH: &'static str = "/user/balance";
53
54 fn build<H>(
55 _: &crate::client::ClientBuilder<
56 Self::Builder,
57 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58 H,
59 >,
60 ) -> http_client::Result<Self> {
61 Ok(Self)
62 }
63}
64
/// Declares the capabilities of the DeepSeek provider: chat completion only. Embeddings,
/// transcription and (feature-gated) image/audio generation are marked `Nothing`.
impl<H> Capabilities<H> for DeepSeekExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
74
// Relies entirely on `DebugExt`'s default methods.
impl DebugExt for DeepSeekExt {}
76
/// Connects the DeepSeek builder to its output extension, its API-key type, and the default
/// base URL.
impl ProviderBuilder for DeepSeekExtBuilder {
    type Output = DeepSeekExt;
    type ApiKey = DeepSeekApiKey;

    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
}
83
/// DeepSeek client; the HTTP backend `H` defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
// NOTE(review): this alias passes `String` as the API-key type parameter, whereas
// `DeepSeekExtBuilder::ApiKey` is `BearerAuth`; confirm this is intentional.
/// Builder for [`Client`]; the HTTP backend `H` defaults to `reqwest::Client`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88 type Input = DeepSeekApiKey;
89
90 fn from_env() -> Self {
92 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93 let mut client_builder = Self::builder();
94 client_builder.headers_mut().insert(
95 http::header::CONTENT_TYPE,
96 http::HeaderValue::from_static("application/json"),
97 );
98 let client_builder = client_builder.api_key(&api_key);
99 client_builder.build().unwrap()
100 }
101
102 fn from_val(input: Self::Input) -> Self {
103 Self::new(input).unwrap()
104 }
105}
106
/// Error object returned by the DeepSeek API; only its `message` is kept.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
111
/// Envelope for an API body that is either a payload `T` or an error object.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order, so a body that
/// matches `T` is always treated as success.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
118
119impl From<ApiErrorResponse> for CompletionError {
120 fn from(err: ApiErrorResponse) -> Self {
121 CompletionError::ProviderError(err.message)
122 }
123}
124
/// Body of a non-streaming chat-completion response.
///
/// Only `choices` and `usage` are modelled. Other fields the API sends (e.g. `id`, `model`,
/// `system_fingerprint`) are ignored during deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
133
/// Token accounting reported by DeepSeek.
///
/// The five counters are required when deserializing (they carry no `#[serde(default)]`).
/// The detail objects are optional and are omitted from serialized output when absent.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    pub completion_tokens: u32,
    pub prompt_tokens: u32,
    pub prompt_cache_hit_tokens: u32,
    pub prompt_cache_miss_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
146
147impl Usage {
148 fn new() -> Self {
149 Self {
150 completion_tokens: 0,
151 prompt_tokens: 0,
152 prompt_cache_hit_tokens: 0,
153 prompt_cache_miss_tokens: 0,
154 total_tokens: 0,
155 completion_tokens_details: None,
156 prompt_tokens_details: None,
157 }
158 }
159}
160
/// Optional breakdown of completion tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
166
/// Optional breakdown of prompt tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
172
/// One candidate completion in a non-streaming response.
///
/// `finish_reason` is kept as the raw string the API sends (e.g. `"stop"` or `"tool_calls"`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
180
/// A chat message in DeepSeek's wire format.
///
/// The JSON `role` field selects the variant: `system`, `user`, `assistant`, or `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System prompt.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// User turn.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Assistant turn, optionally carrying tool calls and reasoning text.
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // A missing field yields an empty vec via `default`.
        // `null_or_vec` presumably also maps a JSON `null` to an empty vec — confirm in json_utils.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    /// Output of a tool invocation, linked to its call via `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
214
215impl Message {
216 pub fn system(content: &str) -> Self {
217 Message::System {
218 content: content.to_owned(),
219 name: None,
220 }
221 }
222}
223
224impl From<message::ToolResult> for Message {
225 fn from(tool_result: message::ToolResult) -> Self {
226 let content = match tool_result.content.first() {
227 message::ToolResultContent::Text(text) => text.text,
228 message::ToolResultContent::Image(_) => String::from("[Image]"),
229 };
230
231 Message::ToolResult {
232 tool_call_id: tool_result.id,
233 content,
234 }
235 }
236}
237
238impl From<message::ToolCall> for ToolCall {
239 fn from(tool_call: message::ToolCall) -> Self {
240 Self {
241 id: tool_call.id,
242 index: 0,
244 r#type: ToolType::Function,
245 function: Function {
246 name: tool_call.function.name,
247 arguments: tool_call.function.arguments,
248 },
249 }
250 }
251}
252
253impl TryFrom<message::Message> for Vec<Message> {
254 type Error = message::MessageError;
255
256 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
257 match message {
258 message::Message::User { content } => {
259 let mut messages = vec![];
261
262 let tool_results = content
263 .clone()
264 .into_iter()
265 .filter_map(|content| match content {
266 message::UserContent::ToolResult(tool_result) => {
267 Some(Message::from(tool_result))
268 }
269 _ => None,
270 })
271 .collect::<Vec<_>>();
272
273 messages.extend(tool_results);
274
275 let text_messages = content
277 .into_iter()
278 .filter_map(|content| match content {
279 message::UserContent::Text(text) => Some(Message::User {
280 content: text.text,
281 name: None,
282 }),
283 message::UserContent::Document(Document {
284 data:
285 DocumentSourceKind::Base64(content)
286 | DocumentSourceKind::String(content),
287 ..
288 }) => Some(Message::User {
289 content,
290 name: None,
291 }),
292 _ => None,
293 })
294 .collect::<Vec<_>>();
295 messages.extend(text_messages);
296
297 Ok(messages)
298 }
299 message::Message::Assistant { content, .. } => {
300 let mut messages: Vec<Message> = vec![];
301 let mut text_content = String::new();
302 let mut reasoning_content = String::new();
303
304 content.iter().for_each(|content| match content {
305 message::AssistantContent::Text(text) => {
306 text_content.push_str(text.text());
307 }
308 message::AssistantContent::Reasoning(reasoning) => {
309 reasoning_content.push_str(&reasoning.reasoning.join("\n"));
310 }
311 _ => {}
312 });
313
314 messages.push(Message::Assistant {
315 content: text_content,
316 name: None,
317 tool_calls: vec![],
318 reasoning_content: if reasoning_content.is_empty() {
319 None
320 } else {
321 Some(reasoning_content)
322 },
323 });
324
325 let tool_calls = content
327 .clone()
328 .into_iter()
329 .filter_map(|content| match content {
330 message::AssistantContent::ToolCall(tool_call) => {
331 Some(ToolCall::from(tool_call))
332 }
333 _ => None,
334 })
335 .collect::<Vec<_>>();
336
337 if !tool_calls.is_empty() {
339 messages.push(Message::Assistant {
340 content: "".to_string(),
341 name: None,
342 tool_calls,
343 reasoning_content: None,
344 });
345 }
346
347 Ok(messages)
348 }
349 }
350 }
351}
352
/// A tool call in DeepSeek's wire format.
///
/// `type` defaults to `function` when it is absent from the JSON.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    pub index: usize,
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
361
/// The function invoked by a tool call.
///
/// On the wire, `arguments` is a JSON-encoded string, e.g. `"{\"x\":2}"`. It is decoded into /
/// encoded from a `serde_json::Value` via `json_utils::stringified_json`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
368
/// Kind of tool call; `function` is the only (and default) variant.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
375
/// A tool advertised to the model: a `type` tag wrapping rig's tool definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
381
382impl From<crate::completion::ToolDefinition> for ToolDefinition {
383 fn from(tool: crate::completion::ToolDefinition) -> Self {
384 Self {
385 r#type: "function".into(),
386 function: tool,
387 }
388 }
389}
390
391impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
392 type Error = CompletionError;
393
394 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
395 let choice = response.choices.first().ok_or_else(|| {
396 CompletionError::ResponseError("Response contained no choices".to_owned())
397 })?;
398 let content = match &choice.message {
399 Message::Assistant {
400 content,
401 tool_calls,
402 ..
403 } => {
404 let mut content = if content.trim().is_empty() {
405 vec![]
406 } else {
407 vec![completion::AssistantContent::text(content)]
408 };
409
410 content.extend(
411 tool_calls
412 .iter()
413 .map(|call| {
414 completion::AssistantContent::tool_call(
415 &call.id,
416 &call.function.name,
417 call.function.arguments.clone(),
418 )
419 })
420 .collect::<Vec<_>>(),
421 );
422 Ok(content)
423 }
424 _ => Err(CompletionError::ResponseError(
425 "Response did not contain a valid message or tool call".into(),
426 )),
427 }?;
428
429 let choice = OneOrMany::many(content).map_err(|_| {
430 CompletionError::ResponseError(
431 "Response contained no message or tool call (empty)".to_owned(),
432 )
433 })?;
434
435 let usage = completion::Usage {
436 input_tokens: response.usage.prompt_tokens as u64,
437 output_tokens: response.usage.completion_tokens as u64,
438 total_tokens: response.usage.total_tokens as u64,
439 };
440
441 Ok(completion::CompletionResponse {
442 choice,
443 usage,
444 raw_response: response,
445 })
446 }
447}
448
449#[derive(Debug, Serialize, Deserialize)]
450pub(super) struct DeepseekCompletionRequest {
451 model: String,
452 pub messages: Vec<Message>,
453 #[serde(skip_serializing_if = "Option::is_none")]
454 temperature: Option<f64>,
455 #[serde(skip_serializing_if = "Vec::is_empty")]
456 tools: Vec<ToolDefinition>,
457 #[serde(skip_serializing_if = "Option::is_none")]
458 tool_choice: Option<crate::providers::openrouter::ToolChoice>,
459 #[serde(flatten, skip_serializing_if = "Option::is_none")]
460 pub additional_params: Option<serde_json::Value>,
461}
462
463impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
464 type Error = CompletionError;
465
466 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
467 let mut full_history: Vec<Message> = match &req.preamble {
468 Some(preamble) => vec![Message::system(preamble)],
469 None => vec![],
470 };
471
472 if let Some(docs) = req.normalized_documents() {
473 let docs: Vec<Message> = docs.try_into()?;
474 full_history.extend(docs);
475 }
476
477 let chat_history: Vec<Message> = req
478 .chat_history
479 .clone()
480 .into_iter()
481 .map(|message| message.try_into())
482 .collect::<Result<Vec<Vec<Message>>, _>>()?
483 .into_iter()
484 .flatten()
485 .collect();
486
487 full_history.extend(chat_history);
488
489 let tool_choice = req
490 .tool_choice
491 .clone()
492 .map(crate::providers::openrouter::ToolChoice::try_from)
493 .transpose()?;
494
495 Ok(Self {
496 model: model.to_string(),
497 messages: full_history,
498 temperature: req.temperature,
499 tools: req
500 .tools
501 .clone()
502 .into_iter()
503 .map(ToolDefinition::from)
504 .collect::<Vec<_>>(),
505 tool_choice,
506 additional_params: req.additional_params,
507 })
508 }
509}
510
/// A DeepSeek chat model bound to a client.
///
/// `model` is the model id sent with each request, e.g. [`DEEPSEEK_CHAT`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub client: Client<T>,
    pub model: String,
}
517
/// [`completion::CompletionModel`] implementation for DeepSeek.
///
/// Both one-shot and streaming requests are POSTed to `/chat/completions`.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    /// Creates a model handle that owns a clone of `client` and uses `model` as the model id.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self {
            client: client.clone(),
            model: model.into().to_string(),
        }
    }

    /// Sends a single (non-streaming) chat-completion request.
    ///
    /// Reuses the caller's tracing span when one is active. Otherwise it opens a `chat` span
    /// with target `rig::completions`. On success, the prompt and completion token counts are
    /// recorded on that span.
    ///
    /// # Errors
    /// * request conversion, JSON (de)serialization, or HTTP transport failures;
    /// * `CompletionError::ProviderError` carrying the API's `message` when a 2xx body is an
    ///   error object;
    /// * `CompletionError::ProviderError` carrying the lossily-decoded raw body when the
    ///   status is not 2xx;
    /// * any error from converting the parsed response into a `completion::CompletionResponse`.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<
        completion::CompletionResponse<CompletionResponse>,
        crate::completion::CompletionError,
    > {
        // Only create a fresh span when there is no active one to attach to.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "deepseek",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request =
            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printing only happens when TRACE is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "DeepSeek completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body may still be an error object; `ApiResponse` tries the success
                // shape first.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        // This future runs under `.instrument(span)`, so the current span is `span`.
                        let span = tracing::Span::current();
                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
                        span.record(
                            "gen_ai.usage.output_tokens",
                            response.usage.completion_tokens,
                        );
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "DeepSeek completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    /// Sends a streaming chat-completion request.
    ///
    /// The request's `additional_params` are merged with
    /// `{"stream": true, "stream_options": {"include_usage": true}}`.
    /// NOTE(review): this assumes `json_utils::merge` lets the second argument win on
    /// conflicting keys — confirm.
    ///
    /// The SSE stream is consumed by `send_compatible_streaming_request`. It runs inside the
    /// caller's active span or, if none, a new `chat_streaming` span.
    ///
    /// # Errors
    /// Request conversion, JSON serialization and request-building failures are returned
    /// directly, as is any error returned by `send_compatible_streaming_request`.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        // Cloned up front because `completion_request` is consumed by the conversion below.
        let preamble = completion_request.preamble.clone();
        let mut request =
            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "DeepSeek streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "deepseek",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}
668
/// Incremental update carried by one streamed choice.
///
/// Every part is optional: a missing `tool_calls` field deserializes as an empty vec, and
/// missing `content` / `reasoning_content` as `None`.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    reasoning_content: Option<String>,
}
677
/// One choice within a streamed chunk; only its `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
682
/// Payload of one SSE `data:` event.
///
/// `usage` is present only on chunks that report token usage (requested via
/// `stream_options.include_usage`).
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
688
/// Final payload of a streamed completion: the usage block reported by the server, or an
/// all-zero record if none arrived.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
693
694impl GetTokenUsage for StreamingCompletionResponse {
695 fn token_usage(&self) -> Option<crate::completion::Usage> {
696 let mut usage = crate::completion::Usage::new();
697 usage.input_tokens = self.usage.prompt_tokens as u64;
698 usage.output_tokens = self.usage.completion_tokens as u64;
699 usage.total_tokens = self.usage.total_tokens as u64;
700
701 Some(usage)
702 }
703}
704
705pub async fn send_compatible_streaming_request<T>(
706 http_client: T,
707 req: Request<Vec<u8>>,
708) -> Result<
709 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
710 CompletionError,
711>
712where
713 T: HttpClientExt + Clone + 'static,
714{
715 let mut event_source = GenericEventSource::new(http_client, req);
716
717 let stream = stream! {
718 let mut final_usage = Usage::new();
719 let mut text_response = String::new();
720 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
721
722 while let Some(event_result) = event_source.next().await {
723 match event_result {
724 Ok(Event::Open) => {
725 tracing::trace!("SSE connection opened");
726 continue;
727 }
728 Ok(Event::Message(message)) => {
729 if message.data.trim().is_empty() || message.data == "[DONE]" {
730 continue;
731 }
732
733 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
734 let Ok(data) = parsed else {
735 let err = parsed.unwrap_err();
736 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
737 continue;
738 };
739
740 if let Some(choice) = data.choices.first() {
741 let delta = &choice.delta;
742
743 if !delta.tool_calls.is_empty() {
744 for tool_call in &delta.tool_calls {
745 let function = &tool_call.function;
746
747 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
749 && empty_or_none(&function.arguments)
750 {
751 let id = tool_call.id.clone().unwrap_or_default();
752 let name = function.name.clone().unwrap();
753 calls.insert(tool_call.index, (id, name, String::new()));
754 }
755 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
757 && let Some(arguments) = &function.arguments
758 && !arguments.is_empty()
759 {
760 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
761 let combined = format!("{}{}", existing_args, arguments);
762 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
763 } else {
764 tracing::debug!("Partial tool call received but tool call was never started.");
765 }
766 }
767 else {
769 let id = tool_call.id.clone().unwrap_or_default();
770 let name = function.name.clone().unwrap_or_default();
771 let arguments_str = function.arguments.clone().unwrap_or_default();
772
773 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
774 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
775 continue;
776 };
777
778 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
779 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
780 ));
781 }
782 }
783 }
784
785 if let Some(content) = &delta.reasoning_content {
787 yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
788 id: None,
789 reasoning: content.to_string()
790 });
791 }
792
793 if let Some(content) = &delta.content {
794 text_response += content;
795 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
796 }
797 }
798
799 if let Some(usage) = data.usage {
800 final_usage = usage.clone();
801 }
802 }
803 Err(crate::http_client::Error::StreamEnded) => {
804 break;
805 }
806 Err(err) => {
807 tracing::error!(?err, "SSE error");
808 yield Err(CompletionError::ResponseError(err.to_string()));
809 break;
810 }
811 }
812 }
813
814 event_source.close();
815
816 let mut tool_calls = Vec::new();
817 for (index, (id, name, arguments)) in calls {
819 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
820 continue;
821 };
822
823 tool_calls.push(ToolCall {
824 id: id.clone(),
825 index,
826 r#type: ToolType::Function,
827 function: Function {
828 name: name.clone(),
829 arguments: arguments_json.clone()
830 }
831 });
832 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
833 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
834 ));
835 }
836
837 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
838 StreamingCompletionResponse { usage: final_usage.clone() }
839 ));
840 };
841
842 Ok(crate::streaming::StreamingCompletionResponse::stream(
843 Box::pin(stream),
844 ))
845}
846
/// Model id `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model id `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
852
#[cfg(test)]
mod tests {

    use super::*;

    /// A bare array of choices with a plain assistant reply deserializes into `Choice`s.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    /// A minimal response (choices + required usage counters) deserializes.
    /// `serde_path_to_error` reports the failing JSON path if it does not.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    /// A full example response deserializes: extra top-level fields are ignored, and the
    /// optional `prompt_tokens_details` object is accepted.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    /// A tool-call choice deserializes with its stringified `arguments` decoded into JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
                reasoning_content: None,
            },
        };

        assert_eq!(choice, expected_choice);
    }
}