1use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
37const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
42#[derive(Debug, Default, Clone, Copy)]
43pub struct DeepSeekExt;
44#[derive(Debug, Default, Clone, Copy)]
45pub struct DeepSeekExtBuilder;
46
47type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50 type Builder = DeepSeekExtBuilder;
51
52 const VERIFY_PATH: &'static str = "/user/balance";
53
54 fn build<H>(
55 _: &crate::client::ClientBuilder<
56 Self::Builder,
57 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58 H,
59 >,
60 ) -> http_client::Result<Self> {
61 Ok(Self)
62 }
63}
64
65impl<H> Capabilities<H> for DeepSeekExt {
66 type Completion = Capable<CompletionModel<H>>;
67 type Embeddings = Nothing;
68 type Transcription = Nothing;
69 #[cfg(feature = "image")]
70 type ImageGeneration = Nothing;
71 #[cfg(feature = "audio")]
72 type AudioGeneration = Nothing;
73}
74
75impl DebugExt for DeepSeekExt {}
76
77impl ProviderBuilder for DeepSeekExtBuilder {
78 type Output = DeepSeekExt;
79 type ApiKey = DeepSeekApiKey;
80
81 const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
82}
83
84pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
85pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88 type Input = DeepSeekApiKey;
89
90 fn from_env() -> Self {
92 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93 let mut client_builder = Self::builder();
94 client_builder.headers_mut().insert(
95 http::header::CONTENT_TYPE,
96 http::HeaderValue::from_static("application/json"),
97 );
98 let client_builder = client_builder.api_key(&api_key);
99 client_builder.build().unwrap()
100 }
101
102 fn from_val(input: Self::Input) -> Self {
103 Self::new(input).unwrap()
104 }
105}
106
107#[derive(Debug, Deserialize)]
108struct ApiErrorResponse {
109 message: String,
110}
111
112#[derive(Debug, Deserialize)]
113#[serde(untagged)]
114enum ApiResponse<T> {
115 Ok(T),
116 Err(ApiErrorResponse),
117}
118
119impl From<ApiErrorResponse> for CompletionError {
120 fn from(err: ApiErrorResponse) -> Self {
121 CompletionError::ProviderError(err.message)
122 }
123}
124
125#[derive(Clone, Debug, Serialize, Deserialize)]
127pub struct CompletionResponse {
128 pub choices: Vec<Choice>,
130 pub usage: Usage,
131 }
133
134#[derive(Clone, Debug, Serialize, Deserialize, Default)]
135pub struct Usage {
136 pub completion_tokens: u32,
137 pub prompt_tokens: u32,
138 pub prompt_cache_hit_tokens: u32,
139 pub prompt_cache_miss_tokens: u32,
140 pub total_tokens: u32,
141 #[serde(skip_serializing_if = "Option::is_none")]
142 pub completion_tokens_details: Option<CompletionTokensDetails>,
143 #[serde(skip_serializing_if = "Option::is_none")]
144 pub prompt_tokens_details: Option<PromptTokensDetails>,
145}
146
147impl Usage {
148 fn new() -> Self {
149 Self {
150 completion_tokens: 0,
151 prompt_tokens: 0,
152 prompt_cache_hit_tokens: 0,
153 prompt_cache_miss_tokens: 0,
154 total_tokens: 0,
155 completion_tokens_details: None,
156 prompt_tokens_details: None,
157 }
158 }
159}
160
161#[derive(Clone, Debug, Serialize, Deserialize, Default)]
162pub struct CompletionTokensDetails {
163 #[serde(skip_serializing_if = "Option::is_none")]
164 pub reasoning_tokens: Option<u32>,
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize, Default)]
168pub struct PromptTokensDetails {
169 #[serde(skip_serializing_if = "Option::is_none")]
170 pub cached_tokens: Option<u32>,
171}
172
173#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
174pub struct Choice {
175 pub index: usize,
176 pub message: Message,
177 pub logprobs: Option<serde_json::Value>,
178 pub finish_reason: String,
179}
180
181#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
182#[serde(tag = "role", rename_all = "lowercase")]
183pub enum Message {
184 System {
185 content: String,
186 #[serde(skip_serializing_if = "Option::is_none")]
187 name: Option<String>,
188 },
189 User {
190 content: String,
191 #[serde(skip_serializing_if = "Option::is_none")]
192 name: Option<String>,
193 },
194 Assistant {
195 content: String,
196 #[serde(skip_serializing_if = "Option::is_none")]
197 name: Option<String>,
198 #[serde(
199 default,
200 deserialize_with = "json_utils::null_or_vec",
201 skip_serializing_if = "Vec::is_empty"
202 )]
203 tool_calls: Vec<ToolCall>,
204 },
205 #[serde(rename = "tool")]
206 ToolResult {
207 tool_call_id: String,
208 content: String,
209 },
210}
211
212impl Message {
213 pub fn system(content: &str) -> Self {
214 Message::System {
215 content: content.to_owned(),
216 name: None,
217 }
218 }
219}
220
221impl From<message::ToolResult> for Message {
222 fn from(tool_result: message::ToolResult) -> Self {
223 let content = match tool_result.content.first() {
224 message::ToolResultContent::Text(text) => text.text,
225 message::ToolResultContent::Image(_) => String::from("[Image]"),
226 };
227
228 Message::ToolResult {
229 tool_call_id: tool_result.id,
230 content,
231 }
232 }
233}
234
235impl From<message::ToolCall> for ToolCall {
236 fn from(tool_call: message::ToolCall) -> Self {
237 Self {
238 id: tool_call.id,
239 index: 0,
241 r#type: ToolType::Function,
242 function: Function {
243 name: tool_call.function.name,
244 arguments: tool_call.function.arguments,
245 },
246 }
247 }
248}
249
250impl TryFrom<message::Message> for Vec<Message> {
251 type Error = message::MessageError;
252
253 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
254 match message {
255 message::Message::User { content } => {
256 let mut messages = vec![];
258
259 let tool_results = content
260 .clone()
261 .into_iter()
262 .filter_map(|content| match content {
263 message::UserContent::ToolResult(tool_result) => {
264 Some(Message::from(tool_result))
265 }
266 _ => None,
267 })
268 .collect::<Vec<_>>();
269
270 messages.extend(tool_results);
271
272 let text_messages = content
274 .into_iter()
275 .filter_map(|content| match content {
276 message::UserContent::Text(text) => Some(Message::User {
277 content: text.text,
278 name: None,
279 }),
280 message::UserContent::Document(Document {
281 data:
282 DocumentSourceKind::Base64(content)
283 | DocumentSourceKind::String(content),
284 ..
285 }) => Some(Message::User {
286 content,
287 name: None,
288 }),
289 _ => None,
290 })
291 .collect::<Vec<_>>();
292 messages.extend(text_messages);
293
294 Ok(messages)
295 }
296 message::Message::Assistant { content, .. } => {
297 let mut messages: Vec<Message> = vec![];
298
299 let text_content = content
301 .clone()
302 .into_iter()
303 .filter_map(|content| match content {
304 message::AssistantContent::Text(text) => Some(Message::Assistant {
305 content: text.text,
306 name: None,
307 tool_calls: vec![],
308 }),
309 _ => None,
310 })
311 .collect::<Vec<_>>();
312
313 messages.extend(text_content);
314
315 let tool_calls = content
317 .clone()
318 .into_iter()
319 .filter_map(|content| match content {
320 message::AssistantContent::ToolCall(tool_call) => {
321 Some(ToolCall::from(tool_call))
322 }
323 _ => None,
324 })
325 .collect::<Vec<_>>();
326
327 if !tool_calls.is_empty() {
329 messages.push(Message::Assistant {
330 content: "".to_string(),
331 name: None,
332 tool_calls,
333 });
334 }
335
336 Ok(messages)
337 }
338 }
339 }
340}
341
342#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
343pub struct ToolCall {
344 pub id: String,
345 pub index: usize,
346 #[serde(default)]
347 pub r#type: ToolType,
348 pub function: Function,
349}
350
351#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
352pub struct Function {
353 pub name: String,
354 #[serde(with = "json_utils::stringified_json")]
355 pub arguments: serde_json::Value,
356}
357
358#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
359#[serde(rename_all = "lowercase")]
360pub enum ToolType {
361 #[default]
362 Function,
363}
364
365#[derive(Clone, Debug, Deserialize, Serialize)]
366pub struct ToolDefinition {
367 pub r#type: String,
368 pub function: completion::ToolDefinition,
369}
370
371impl From<crate::completion::ToolDefinition> for ToolDefinition {
372 fn from(tool: crate::completion::ToolDefinition) -> Self {
373 Self {
374 r#type: "function".into(),
375 function: tool,
376 }
377 }
378}
379
380impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
381 type Error = CompletionError;
382
383 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
384 let choice = response.choices.first().ok_or_else(|| {
385 CompletionError::ResponseError("Response contained no choices".to_owned())
386 })?;
387 let content = match &choice.message {
388 Message::Assistant {
389 content,
390 tool_calls,
391 ..
392 } => {
393 let mut content = if content.trim().is_empty() {
394 vec![]
395 } else {
396 vec![completion::AssistantContent::text(content)]
397 };
398
399 content.extend(
400 tool_calls
401 .iter()
402 .map(|call| {
403 completion::AssistantContent::tool_call(
404 &call.id,
405 &call.function.name,
406 call.function.arguments.clone(),
407 )
408 })
409 .collect::<Vec<_>>(),
410 );
411 Ok(content)
412 }
413 _ => Err(CompletionError::ResponseError(
414 "Response did not contain a valid message or tool call".into(),
415 )),
416 }?;
417
418 let choice = OneOrMany::many(content).map_err(|_| {
419 CompletionError::ResponseError(
420 "Response contained no message or tool call (empty)".to_owned(),
421 )
422 })?;
423
424 let usage = completion::Usage {
425 input_tokens: response.usage.prompt_tokens as u64,
426 output_tokens: response.usage.completion_tokens as u64,
427 total_tokens: response.usage.total_tokens as u64,
428 };
429
430 Ok(completion::CompletionResponse {
431 choice,
432 usage,
433 raw_response: response,
434 })
435 }
436}
437
438#[derive(Debug, Serialize, Deserialize)]
439pub(super) struct DeepseekCompletionRequest {
440 model: String,
441 pub messages: Vec<Message>,
442 #[serde(skip_serializing_if = "Option::is_none")]
443 temperature: Option<f64>,
444 #[serde(skip_serializing_if = "Vec::is_empty")]
445 tools: Vec<ToolDefinition>,
446 #[serde(skip_serializing_if = "Option::is_none")]
447 tool_choice: Option<crate::providers::openrouter::ToolChoice>,
448 #[serde(flatten, skip_serializing_if = "Option::is_none")]
449 pub additional_params: Option<serde_json::Value>,
450}
451
452impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
453 type Error = CompletionError;
454
455 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
456 let mut full_history: Vec<Message> = match &req.preamble {
457 Some(preamble) => vec![Message::system(preamble)],
458 None => vec![],
459 };
460
461 if let Some(docs) = req.normalized_documents() {
462 let docs: Vec<Message> = docs.try_into()?;
463 full_history.extend(docs);
464 }
465
466 let chat_history: Vec<Message> = req
467 .chat_history
468 .clone()
469 .into_iter()
470 .map(|message| message.try_into())
471 .collect::<Result<Vec<Vec<Message>>, _>>()?
472 .into_iter()
473 .flatten()
474 .collect();
475
476 full_history.extend(chat_history);
477
478 let tool_choice = req
479 .tool_choice
480 .clone()
481 .map(crate::providers::openrouter::ToolChoice::try_from)
482 .transpose()?;
483
484 Ok(Self {
485 model: model.to_string(),
486 messages: full_history,
487 temperature: req.temperature,
488 tools: req
489 .tools
490 .clone()
491 .into_iter()
492 .map(ToolDefinition::from)
493 .collect::<Vec<_>>(),
494 tool_choice,
495 additional_params: req.additional_params,
496 })
497 }
498}
499
500#[derive(Clone)]
502pub struct CompletionModel<T = reqwest::Client> {
503 pub client: Client<T>,
504 pub model: String,
505}
506
507impl<T> completion::CompletionModel for CompletionModel<T>
508where
509 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
510{
511 type Response = CompletionResponse;
512 type StreamingResponse = StreamingCompletionResponse;
513
514 type Client = Client<T>;
515
516 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
517 Self {
518 client: client.clone(),
519 model: model.into().to_string(),
520 }
521 }
522
523 async fn completion(
524 &self,
525 completion_request: CompletionRequest,
526 ) -> Result<
527 completion::CompletionResponse<CompletionResponse>,
528 crate::completion::CompletionError,
529 > {
530 let span = if tracing::Span::current().is_disabled() {
531 info_span!(
532 target: "rig::completions",
533 "chat",
534 gen_ai.operation.name = "chat",
535 gen_ai.provider.name = "deepseek",
536 gen_ai.request.model = self.model,
537 gen_ai.system_instructions = tracing::field::Empty,
538 gen_ai.response.id = tracing::field::Empty,
539 gen_ai.response.model = tracing::field::Empty,
540 gen_ai.usage.output_tokens = tracing::field::Empty,
541 gen_ai.usage.input_tokens = tracing::field::Empty,
542 )
543 } else {
544 tracing::Span::current()
545 };
546
547 span.record("gen_ai.system_instructions", &completion_request.preamble);
548
549 let request =
550 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
551
552 if enabled!(Level::TRACE) {
553 tracing::trace!(target: "rig::completions",
554 "DeepSeek completion request: {}",
555 serde_json::to_string_pretty(&request)?
556 );
557 }
558
559 let body = serde_json::to_vec(&request)?;
560 let req = self
561 .client
562 .post("/chat/completions")?
563 .body(body)
564 .map_err(|e| CompletionError::HttpError(e.into()))?;
565
566 async move {
567 let response = self.client.send::<_, Bytes>(req).await?;
568 let status = response.status();
569 let response_body = response.into_body().into_future().await?.to_vec();
570
571 if status.is_success() {
572 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
573 ApiResponse::Ok(response) => {
574 let span = tracing::Span::current();
575 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
576 span.record(
577 "gen_ai.usage.output_tokens",
578 response.usage.completion_tokens,
579 );
580 if enabled!(Level::TRACE) {
581 tracing::trace!(target: "rig::completions",
582 "DeepSeek completion response: {}",
583 serde_json::to_string_pretty(&response)?
584 );
585 }
586 response.try_into()
587 }
588 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
589 }
590 } else {
591 Err(CompletionError::ProviderError(
592 String::from_utf8_lossy(&response_body).to_string(),
593 ))
594 }
595 }
596 .instrument(span)
597 .await
598 }
599
600 async fn stream(
601 &self,
602 completion_request: CompletionRequest,
603 ) -> Result<
604 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
605 CompletionError,
606 > {
607 let preamble = completion_request.preamble.clone();
608 let mut request =
609 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
610
611 let params = json_utils::merge(
612 request.additional_params.unwrap_or(serde_json::json!({})),
613 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
614 );
615
616 request.additional_params = Some(params);
617
618 if enabled!(Level::TRACE) {
619 tracing::trace!(target: "rig::completions",
620 "DeepSeek streaming completion request: {}",
621 serde_json::to_string_pretty(&request)?
622 );
623 }
624
625 let body = serde_json::to_vec(&request)?;
626
627 let req = self
628 .client
629 .post("/chat/completions")?
630 .body(body)
631 .map_err(|e| CompletionError::HttpError(e.into()))?;
632
633 let span = if tracing::Span::current().is_disabled() {
634 info_span!(
635 target: "rig::completions",
636 "chat_streaming",
637 gen_ai.operation.name = "chat_streaming",
638 gen_ai.provider.name = "deepseek",
639 gen_ai.request.model = self.model,
640 gen_ai.system_instructions = preamble,
641 gen_ai.response.id = tracing::field::Empty,
642 gen_ai.response.model = tracing::field::Empty,
643 gen_ai.usage.output_tokens = tracing::field::Empty,
644 gen_ai.usage.input_tokens = tracing::field::Empty,
645 )
646 } else {
647 tracing::Span::current()
648 };
649
650 tracing::Instrument::instrument(
651 send_compatible_streaming_request(self.client.clone(), req),
652 span,
653 )
654 .await
655 }
656}
657
658#[derive(Deserialize, Debug)]
659pub struct StreamingDelta {
660 #[serde(default)]
661 content: Option<String>,
662 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
663 tool_calls: Vec<StreamingToolCall>,
664 reasoning_content: Option<String>,
665}
666
667#[derive(Deserialize, Debug)]
668struct StreamingChoice {
669 delta: StreamingDelta,
670}
671
672#[derive(Deserialize, Debug)]
673struct StreamingCompletionChunk {
674 choices: Vec<StreamingChoice>,
675 usage: Option<Usage>,
676}
677
678#[derive(Clone, Deserialize, Serialize, Debug)]
679pub struct StreamingCompletionResponse {
680 pub usage: Usage,
681}
682
683impl GetTokenUsage for StreamingCompletionResponse {
684 fn token_usage(&self) -> Option<crate::completion::Usage> {
685 let mut usage = crate::completion::Usage::new();
686 usage.input_tokens = self.usage.prompt_tokens as u64;
687 usage.output_tokens = self.usage.completion_tokens as u64;
688 usage.total_tokens = self.usage.total_tokens as u64;
689
690 Some(usage)
691 }
692}
693
694pub async fn send_compatible_streaming_request<T>(
695 http_client: T,
696 req: Request<Vec<u8>>,
697) -> Result<
698 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
699 CompletionError,
700>
701where
702 T: HttpClientExt + Clone + 'static,
703{
704 let span = tracing::Span::current();
705 let mut event_source = GenericEventSource::new(http_client, req);
706
707 let stream = stream! {
708 let mut final_usage = Usage::new();
709 let mut text_response = String::new();
710 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
711
712 while let Some(event_result) = event_source.next().await {
713 match event_result {
714 Ok(Event::Open) => {
715 tracing::trace!("SSE connection opened");
716 continue;
717 }
718 Ok(Event::Message(message)) => {
719 if message.data.trim().is_empty() || message.data == "[DONE]" {
720 continue;
721 }
722
723 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
724 let Ok(data) = parsed else {
725 let err = parsed.unwrap_err();
726 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
727 continue;
728 };
729
730 if let Some(choice) = data.choices.first() {
731 let delta = &choice.delta;
732
733 if !delta.tool_calls.is_empty() {
734 for tool_call in &delta.tool_calls {
735 let function = &tool_call.function;
736
737 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
739 && empty_or_none(&function.arguments)
740 {
741 let id = tool_call.id.clone().unwrap_or_default();
742 let name = function.name.clone().unwrap();
743 calls.insert(tool_call.index, (id, name, String::new()));
744 }
745 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
747 && let Some(arguments) = &function.arguments
748 && !arguments.is_empty()
749 {
750 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
751 let combined = format!("{}{}", existing_args, arguments);
752 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
753 } else {
754 tracing::debug!("Partial tool call received but tool call was never started.");
755 }
756 }
757 else {
759 let id = tool_call.id.clone().unwrap_or_default();
760 let name = function.name.clone().unwrap_or_default();
761 let arguments_str = function.arguments.clone().unwrap_or_default();
762
763 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
764 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
765 continue;
766 };
767
768 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
769 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
770 ));
771 }
772 }
773 }
774
775 if let Some(content) = &delta.reasoning_content {
777 yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
778 id: None,
779 reasoning: content.to_string()
780 });
781 }
782
783 if let Some(content) = &delta.content {
784 text_response += content;
785 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
786 }
787 }
788
789 if let Some(usage) = data.usage {
790 final_usage = usage.clone();
791 }
792 }
793 Err(crate::http_client::Error::StreamEnded) => {
794 break;
795 }
796 Err(err) => {
797 tracing::error!(?err, "SSE error");
798 yield Err(CompletionError::ResponseError(err.to_string()));
799 break;
800 }
801 }
802 }
803
804 event_source.close();
805
806 let mut tool_calls = Vec::new();
807 for (index, (id, name, arguments)) in calls {
809 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
810 continue;
811 };
812
813 tool_calls.push(ToolCall {
814 id: id.clone(),
815 index,
816 r#type: ToolType::Function,
817 function: Function {
818 name: name.clone(),
819 arguments: arguments_json.clone()
820 }
821 });
822 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
823 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
824 ));
825 }
826
827 let message = Message::Assistant {
828 content: text_response,
829 name: None,
830 tool_calls
831 };
832
833 span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
834
835 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
836 StreamingCompletionResponse { usage: final_usage.clone() }
837 ));
838 };
839
840 Ok(crate::streaming::StreamingCompletionResponse::stream(
841 Box::pin(stream),
842 ))
843}
844
/// Model identifier `deepseek-chat`, for use as the `model` of a [`CompletionModel`].
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier `deepseek-reasoner`, for use as the `model` of a [`CompletionModel`].
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
850
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of an assistant message, panicking for any other role.
    fn assistant_content(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content.as_str(),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Parses `data` as a [`CompletionResponse`], reporting the failing JSON path on error.
    fn parse_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        serde_path_to_error::deserialize(jd)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_eq!(assistant_content(&choices[0].message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);
        assert_eq!(
            assistant_content(&response.choices[0].message),
            "Hello, world!"
        );
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        assert_eq!(
            assistant_content(&response.choices[0].message),
            "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
        {
            "finish_reason": "tool_calls",
            "index": 0,
            "logprobs": null,
            "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                    {
                        "function": {
                            "arguments": "{\"x\":2,\"y\":5}",
                            "name": "subtract"
                        },
                        "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                        "index": 0,
                        "type": "function"
                    }
                ]
            }
        }
        "#;

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
        };
        let expected_choice = Choice {
            index: 0,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
            logprobs: None,
            finish_reason: "tool_calls".to_string(),
        };

        let parsed: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
        assert_eq!(parsed, expected_choice);
    }
}