1use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
/// Base URL of the DeepSeek HTTP API; used as the provider builder's `BASE_URL`.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
/// Zero-sized marker type identifying DeepSeek as a provider for the generic client.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExt;
/// Zero-sized marker type used as the builder counterpart of the DeepSeek provider.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExtBuilder;
46
/// API key type for DeepSeek: an alias for the client's `BearerAuth`.
type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50 type Builder = DeepSeekExtBuilder;
51
52 const VERIFY_PATH: &'static str = "/user/balance";
53
54 fn build<H>(
55 _: &crate::client::ClientBuilder<
56 Self::Builder,
57 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58 H,
59 >,
60 ) -> http_client::Result<Self> {
61 Ok(Self)
62 }
63}
64
/// Capability table for DeepSeek: completions are provided by `CompletionModel`;
/// every other capability is declared as `Nothing`.
impl<H> Capabilities<H> for DeepSeekExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    // Only present when the crate is built with the `image` feature.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    // Only present when the crate is built with the `audio` feature.
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
74
// Empty impl: relies entirely on whatever `DebugExt` provides by default.
impl DebugExt for DeepSeekExt {}
76
/// Builder configuration for DeepSeek: produces a `DeepSeekExt`, authenticates with
/// `DeepSeekApiKey`, and targets `DEEPSEEK_API_BASE_URL`.
impl ProviderBuilder for DeepSeekExtBuilder {
    type Output = DeepSeekExt;
    type ApiKey = DeepSeekApiKey;

    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
}
83
/// DeepSeek client, generic over the HTTP backend `H` (defaults to `reqwest::Client`).
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
/// Builder for the DeepSeek [`Client`], generic over the HTTP backend `H`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88 type Input = DeepSeekApiKey;
89
90 fn from_env() -> Self {
92 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93 Self::new(&api_key).unwrap()
94 }
95
96 fn from_val(input: Self::Input) -> Self {
97 Self::new(input).unwrap()
98 }
99}
100
101#[derive(Debug, Deserialize)]
102struct ApiErrorResponse {
103 message: String,
104}
105
106#[derive(Debug, Deserialize)]
107#[serde(untagged)]
108enum ApiResponse<T> {
109 Ok(T),
110 Err(ApiErrorResponse),
111}
112
113impl From<ApiErrorResponse> for CompletionError {
114 fn from(err: ApiErrorResponse) -> Self {
115 CompletionError::ProviderError(err.message)
116 }
117}
118
119#[derive(Clone, Debug, Serialize, Deserialize)]
121pub struct CompletionResponse {
122 pub choices: Vec<Choice>,
124 pub usage: Usage,
125 }
127
128#[derive(Clone, Debug, Serialize, Deserialize, Default)]
129pub struct Usage {
130 pub completion_tokens: u32,
131 pub prompt_tokens: u32,
132 pub prompt_cache_hit_tokens: u32,
133 pub prompt_cache_miss_tokens: u32,
134 pub total_tokens: u32,
135 #[serde(skip_serializing_if = "Option::is_none")]
136 pub completion_tokens_details: Option<CompletionTokensDetails>,
137 #[serde(skip_serializing_if = "Option::is_none")]
138 pub prompt_tokens_details: Option<PromptTokensDetails>,
139}
140
141impl Usage {
142 fn new() -> Self {
143 Self {
144 completion_tokens: 0,
145 prompt_tokens: 0,
146 prompt_cache_hit_tokens: 0,
147 prompt_cache_miss_tokens: 0,
148 total_tokens: 0,
149 completion_tokens_details: None,
150 prompt_tokens_details: None,
151 }
152 }
153}
154
155#[derive(Clone, Debug, Serialize, Deserialize, Default)]
156pub struct CompletionTokensDetails {
157 #[serde(skip_serializing_if = "Option::is_none")]
158 pub reasoning_tokens: Option<u32>,
159}
160
161#[derive(Clone, Debug, Serialize, Deserialize, Default)]
162pub struct PromptTokensDetails {
163 #[serde(skip_serializing_if = "Option::is_none")]
164 pub cached_tokens: Option<u32>,
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
168pub struct Choice {
169 pub index: usize,
170 pub message: Message,
171 pub logprobs: Option<serde_json::Value>,
172 pub finish_reason: String,
173}
174
175#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
176#[serde(tag = "role", rename_all = "lowercase")]
177pub enum Message {
178 System {
179 content: String,
180 #[serde(skip_serializing_if = "Option::is_none")]
181 name: Option<String>,
182 },
183 User {
184 content: String,
185 #[serde(skip_serializing_if = "Option::is_none")]
186 name: Option<String>,
187 },
188 Assistant {
189 content: String,
190 #[serde(skip_serializing_if = "Option::is_none")]
191 name: Option<String>,
192 #[serde(
193 default,
194 deserialize_with = "json_utils::null_or_vec",
195 skip_serializing_if = "Vec::is_empty"
196 )]
197 tool_calls: Vec<ToolCall>,
198 },
199 #[serde(rename = "tool")]
200 ToolResult {
201 tool_call_id: String,
202 content: String,
203 },
204}
205
206impl Message {
207 pub fn system(content: &str) -> Self {
208 Message::System {
209 content: content.to_owned(),
210 name: None,
211 }
212 }
213}
214
215impl From<message::ToolResult> for Message {
216 fn from(tool_result: message::ToolResult) -> Self {
217 let content = match tool_result.content.first() {
218 message::ToolResultContent::Text(text) => text.text,
219 message::ToolResultContent::Image(_) => String::from("[Image]"),
220 };
221
222 Message::ToolResult {
223 tool_call_id: tool_result.id,
224 content,
225 }
226 }
227}
228
229impl From<message::ToolCall> for ToolCall {
230 fn from(tool_call: message::ToolCall) -> Self {
231 Self {
232 id: tool_call.id,
233 index: 0,
235 r#type: ToolType::Function,
236 function: Function {
237 name: tool_call.function.name,
238 arguments: tool_call.function.arguments,
239 },
240 }
241 }
242}
243
244impl TryFrom<message::Message> for Vec<Message> {
245 type Error = message::MessageError;
246
247 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
248 match message {
249 message::Message::User { content } => {
250 let mut messages = vec![];
252
253 let tool_results = content
254 .clone()
255 .into_iter()
256 .filter_map(|content| match content {
257 message::UserContent::ToolResult(tool_result) => {
258 Some(Message::from(tool_result))
259 }
260 _ => None,
261 })
262 .collect::<Vec<_>>();
263
264 messages.extend(tool_results);
265
266 let text_messages = content
268 .into_iter()
269 .filter_map(|content| match content {
270 message::UserContent::Text(text) => Some(Message::User {
271 content: text.text,
272 name: None,
273 }),
274 message::UserContent::Document(Document {
275 data:
276 DocumentSourceKind::Base64(content)
277 | DocumentSourceKind::String(content),
278 ..
279 }) => Some(Message::User {
280 content,
281 name: None,
282 }),
283 _ => None,
284 })
285 .collect::<Vec<_>>();
286 messages.extend(text_messages);
287
288 Ok(messages)
289 }
290 message::Message::Assistant { content, .. } => {
291 let mut messages: Vec<Message> = vec![];
292
293 let text_content = content
295 .clone()
296 .into_iter()
297 .filter_map(|content| match content {
298 message::AssistantContent::Text(text) => Some(Message::Assistant {
299 content: text.text,
300 name: None,
301 tool_calls: vec![],
302 }),
303 _ => None,
304 })
305 .collect::<Vec<_>>();
306
307 messages.extend(text_content);
308
309 let tool_calls = content
311 .clone()
312 .into_iter()
313 .filter_map(|content| match content {
314 message::AssistantContent::ToolCall(tool_call) => {
315 Some(ToolCall::from(tool_call))
316 }
317 _ => None,
318 })
319 .collect::<Vec<_>>();
320
321 if !tool_calls.is_empty() {
323 messages.push(Message::Assistant {
324 content: "".to_string(),
325 name: None,
326 tool_calls,
327 });
328 }
329
330 Ok(messages)
331 }
332 }
333 }
334}
335
336#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
337pub struct ToolCall {
338 pub id: String,
339 pub index: usize,
340 #[serde(default)]
341 pub r#type: ToolType,
342 pub function: Function,
343}
344
345#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
346pub struct Function {
347 pub name: String,
348 #[serde(with = "json_utils::stringified_json")]
349 pub arguments: serde_json::Value,
350}
351
352#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
353#[serde(rename_all = "lowercase")]
354pub enum ToolType {
355 #[default]
356 Function,
357}
358
359#[derive(Clone, Debug, Deserialize, Serialize)]
360pub struct ToolDefinition {
361 pub r#type: String,
362 pub function: completion::ToolDefinition,
363}
364
365impl From<crate::completion::ToolDefinition> for ToolDefinition {
366 fn from(tool: crate::completion::ToolDefinition) -> Self {
367 Self {
368 r#type: "function".into(),
369 function: tool,
370 }
371 }
372}
373
374impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
375 type Error = CompletionError;
376
377 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
378 let choice = response.choices.first().ok_or_else(|| {
379 CompletionError::ResponseError("Response contained no choices".to_owned())
380 })?;
381 let content = match &choice.message {
382 Message::Assistant {
383 content,
384 tool_calls,
385 ..
386 } => {
387 let mut content = if content.trim().is_empty() {
388 vec![]
389 } else {
390 vec![completion::AssistantContent::text(content)]
391 };
392
393 content.extend(
394 tool_calls
395 .iter()
396 .map(|call| {
397 completion::AssistantContent::tool_call(
398 &call.id,
399 &call.function.name,
400 call.function.arguments.clone(),
401 )
402 })
403 .collect::<Vec<_>>(),
404 );
405 Ok(content)
406 }
407 _ => Err(CompletionError::ResponseError(
408 "Response did not contain a valid message or tool call".into(),
409 )),
410 }?;
411
412 let choice = OneOrMany::many(content).map_err(|_| {
413 CompletionError::ResponseError(
414 "Response contained no message or tool call (empty)".to_owned(),
415 )
416 })?;
417
418 let usage = completion::Usage {
419 input_tokens: response.usage.prompt_tokens as u64,
420 output_tokens: response.usage.completion_tokens as u64,
421 total_tokens: response.usage.total_tokens as u64,
422 };
423
424 Ok(completion::CompletionResponse {
425 choice,
426 usage,
427 raw_response: response,
428 })
429 }
430}
431
432#[derive(Debug, Serialize, Deserialize)]
433pub(super) struct DeepseekCompletionRequest {
434 model: String,
435 pub messages: Vec<Message>,
436 #[serde(flatten, skip_serializing_if = "Option::is_none")]
437 temperature: Option<f64>,
438 #[serde(skip_serializing_if = "Vec::is_empty")]
439 tools: Vec<ToolDefinition>,
440 #[serde(flatten, skip_serializing_if = "Option::is_none")]
441 tool_choice: Option<crate::providers::openrouter::ToolChoice>,
442 #[serde(flatten, skip_serializing_if = "Option::is_none")]
443 pub additional_params: Option<serde_json::Value>,
444}
445
446impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
447 type Error = CompletionError;
448
449 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
450 let mut full_history: Vec<Message> = match &req.preamble {
451 Some(preamble) => vec![Message::system(preamble)],
452 None => vec![],
453 };
454
455 if let Some(docs) = req.normalized_documents() {
456 let docs: Vec<Message> = docs.try_into()?;
457 full_history.extend(docs);
458 }
459
460 let chat_history: Vec<Message> = req
461 .chat_history
462 .clone()
463 .into_iter()
464 .map(|message| message.try_into())
465 .collect::<Result<Vec<Vec<Message>>, _>>()?
466 .into_iter()
467 .flatten()
468 .collect();
469
470 full_history.extend(chat_history);
471
472 let tool_choice = req
473 .tool_choice
474 .clone()
475 .map(crate::providers::openrouter::ToolChoice::try_from)
476 .transpose()?;
477
478 Ok(Self {
479 model: model.to_string(),
480 messages: full_history,
481 temperature: req.temperature,
482 tools: req
483 .tools
484 .clone()
485 .into_iter()
486 .map(ToolDefinition::from)
487 .collect::<Vec<_>>(),
488 tool_choice,
489 additional_params: req.additional_params,
490 })
491 }
492}
493
494#[derive(Clone)]
496pub struct CompletionModel<T = reqwest::Client> {
497 pub client: Client<T>,
498 pub model: String,
499}
500
501impl<T> completion::CompletionModel for CompletionModel<T>
502where
503 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
504{
505 type Response = CompletionResponse;
506 type StreamingResponse = StreamingCompletionResponse;
507
508 type Client = Client<T>;
509
510 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
511 Self {
512 client: client.clone(),
513 model: model.into().to_string(),
514 }
515 }
516
517 #[cfg_attr(feature = "worker", worker::send)]
518 async fn completion(
519 &self,
520 completion_request: CompletionRequest,
521 ) -> Result<
522 completion::CompletionResponse<CompletionResponse>,
523 crate::completion::CompletionError,
524 > {
525 let span = if tracing::Span::current().is_disabled() {
526 info_span!(
527 target: "rig::completions",
528 "chat",
529 gen_ai.operation.name = "chat",
530 gen_ai.provider.name = "deepseek",
531 gen_ai.request.model = self.model,
532 gen_ai.system_instructions = tracing::field::Empty,
533 gen_ai.response.id = tracing::field::Empty,
534 gen_ai.response.model = tracing::field::Empty,
535 gen_ai.usage.output_tokens = tracing::field::Empty,
536 gen_ai.usage.input_tokens = tracing::field::Empty,
537 )
538 } else {
539 tracing::Span::current()
540 };
541
542 span.record("gen_ai.system_instructions", &completion_request.preamble);
543
544 let request =
545 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
546
547 if enabled!(Level::TRACE) {
548 tracing::trace!(target: "rig::completions",
549 "DeepSeek completion request: {}",
550 serde_json::to_string_pretty(&request)?
551 );
552 }
553
554 let body = serde_json::to_vec(&request)?;
555 let req = self
556 .client
557 .post("/chat/completions")?
558 .body(body)
559 .map_err(|e| CompletionError::HttpError(e.into()))?;
560
561 async move {
562 let response = self.client.send::<_, Bytes>(req).await?;
563 let status = response.status();
564 let response_body = response.into_body().into_future().await?.to_vec();
565
566 if status.is_success() {
567 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
568 ApiResponse::Ok(response) => {
569 let span = tracing::Span::current();
570 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
571 span.record(
572 "gen_ai.usage.output_tokens",
573 response.usage.completion_tokens,
574 );
575 if enabled!(Level::TRACE) {
576 tracing::trace!(target: "rig::completions",
577 "DeepSeek completion response: {}",
578 serde_json::to_string_pretty(&response)?
579 );
580 }
581 response.try_into()
582 }
583 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
584 }
585 } else {
586 Err(CompletionError::ProviderError(
587 String::from_utf8_lossy(&response_body).to_string(),
588 ))
589 }
590 }
591 .instrument(span)
592 .await
593 }
594
595 #[cfg_attr(feature = "worker", worker::send)]
596 async fn stream(
597 &self,
598 completion_request: CompletionRequest,
599 ) -> Result<
600 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
601 CompletionError,
602 > {
603 let preamble = completion_request.preamble.clone();
604 let mut request =
605 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
606
607 let params = json_utils::merge(
608 request.additional_params.unwrap_or(serde_json::json!({})),
609 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
610 );
611
612 request.additional_params = Some(params);
613
614 if enabled!(Level::TRACE) {
615 tracing::trace!(target: "rig::completions",
616 "DeepSeek streaming completion request: {}",
617 serde_json::to_string_pretty(&request)?
618 );
619 }
620
621 let body = serde_json::to_vec(&request)?;
622
623 let req = self
624 .client
625 .post("/chat/completions")?
626 .body(body)
627 .map_err(|e| CompletionError::HttpError(e.into()))?;
628
629 let span = if tracing::Span::current().is_disabled() {
630 info_span!(
631 target: "rig::completions",
632 "chat_streaming",
633 gen_ai.operation.name = "chat_streaming",
634 gen_ai.provider.name = "deepseek",
635 gen_ai.request.model = self.model,
636 gen_ai.system_instructions = preamble,
637 gen_ai.response.id = tracing::field::Empty,
638 gen_ai.response.model = tracing::field::Empty,
639 gen_ai.usage.output_tokens = tracing::field::Empty,
640 gen_ai.usage.input_tokens = tracing::field::Empty,
641 )
642 } else {
643 tracing::Span::current()
644 };
645
646 tracing::Instrument::instrument(
647 send_compatible_streaming_request(self.client.clone(), req),
648 span,
649 )
650 .await
651 }
652}
653
654#[derive(Deserialize, Debug)]
655pub struct StreamingDelta {
656 #[serde(default)]
657 content: Option<String>,
658 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
659 tool_calls: Vec<StreamingToolCall>,
660 reasoning_content: Option<String>,
661}
662
663#[derive(Deserialize, Debug)]
664struct StreamingChoice {
665 delta: StreamingDelta,
666}
667
668#[derive(Deserialize, Debug)]
669struct StreamingCompletionChunk {
670 choices: Vec<StreamingChoice>,
671 usage: Option<Usage>,
672}
673
674#[derive(Clone, Deserialize, Serialize, Debug)]
675pub struct StreamingCompletionResponse {
676 pub usage: Usage,
677}
678
679impl GetTokenUsage for StreamingCompletionResponse {
680 fn token_usage(&self) -> Option<crate::completion::Usage> {
681 let mut usage = crate::completion::Usage::new();
682 usage.input_tokens = self.usage.prompt_tokens as u64;
683 usage.output_tokens = self.usage.completion_tokens as u64;
684 usage.total_tokens = self.usage.total_tokens as u64;
685
686 Some(usage)
687 }
688}
689
690pub async fn send_compatible_streaming_request<T>(
691 http_client: T,
692 req: Request<Vec<u8>>,
693) -> Result<
694 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
695 CompletionError,
696>
697where
698 T: HttpClientExt + Clone + 'static,
699{
700 let span = tracing::Span::current();
701 let mut event_source = GenericEventSource::new(http_client, req);
702
703 let stream = stream! {
704 let mut final_usage = Usage::new();
705 let mut text_response = String::new();
706 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
707
708 while let Some(event_result) = event_source.next().await {
709 match event_result {
710 Ok(Event::Open) => {
711 tracing::trace!("SSE connection opened");
712 continue;
713 }
714 Ok(Event::Message(message)) => {
715 if message.data.trim().is_empty() || message.data == "[DONE]" {
716 continue;
717 }
718
719 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
720 let Ok(data) = parsed else {
721 let err = parsed.unwrap_err();
722 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
723 continue;
724 };
725
726 if let Some(choice) = data.choices.first() {
727 let delta = &choice.delta;
728
729 if !delta.tool_calls.is_empty() {
730 for tool_call in &delta.tool_calls {
731 let function = &tool_call.function;
732
733 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
735 && empty_or_none(&function.arguments)
736 {
737 let id = tool_call.id.clone().unwrap_or_default();
738 let name = function.name.clone().unwrap();
739 calls.insert(tool_call.index, (id, name, String::new()));
740 }
741 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
743 && let Some(arguments) = &function.arguments
744 && !arguments.is_empty()
745 {
746 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
747 let combined = format!("{}{}", existing_args, arguments);
748 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
749 } else {
750 tracing::debug!("Partial tool call received but tool call was never started.");
751 }
752 }
753 else {
755 let id = tool_call.id.clone().unwrap_or_default();
756 let name = function.name.clone().unwrap_or_default();
757 let arguments_str = function.arguments.clone().unwrap_or_default();
758
759 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
760 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
761 continue;
762 };
763
764 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
765 id,
766 name,
767 arguments: arguments_json,
768 call_id: None,
769 });
770 }
771 }
772 }
773
774 if let Some(content) = &delta.reasoning_content {
776 yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
777 reasoning: content.to_string(),
778 id: None,
779 signature: None,
780 });
781 }
782
783 if let Some(content) = &delta.content {
784 text_response += content;
785 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
786 }
787 }
788
789 if let Some(usage) = data.usage {
790 final_usage = usage.clone();
791 }
792 }
793 Err(crate::http_client::Error::StreamEnded) => {
794 break;
795 }
796 Err(err) => {
797 tracing::error!(?err, "SSE error");
798 yield Err(CompletionError::ResponseError(err.to_string()));
799 break;
800 }
801 }
802 }
803
804 event_source.close();
805
806 let mut tool_calls = Vec::new();
807 for (index, (id, name, arguments)) in calls {
809 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
810 continue;
811 };
812
813 tool_calls.push(ToolCall {
814 id: id.clone(),
815 index,
816 r#type: ToolType::Function,
817 function: Function {
818 name: name.clone(),
819 arguments: arguments_json.clone()
820 }
821 });
822 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
823 id,
824 name,
825 arguments: arguments_json,
826 call_id: None,
827 });
828 }
829
830 let message = Message::Assistant {
831 content: text_response,
832 name: None,
833 tool_calls
834 };
835
836 span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
837
838 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
839 StreamingCompletionResponse { usage: final_usage.clone() }
840 ));
841 };
842
843 Ok(crate::streaming::StreamingCompletionResponse::stream(
844 Box::pin(stream),
845 ))
846}
847
/// Model identifier string `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier string `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
853
#[cfg(test)]
mod tests {

    use super::*;

    /// Returns the text of an assistant message, panicking for any other role.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content.as_str(),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Deserializes a full response, reporting the failing JSON path on error.
    fn parse_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        result.unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err);
        })
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_eq!(
            assistant_text(&choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let parsed: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let subtract_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };

        let expected = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![subtract_call],
            },
        };

        assert_eq!(parsed, expected);
    }
}