1use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, info_span};
19
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
/// Root URL of the DeepSeek HTTP API; exposed to the client machinery as the
/// provider builder's `BASE_URL`.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
/// Stateless marker type that identifies the DeepSeek provider.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExt;

/// Stateless marker type for the builder side of the DeepSeek provider.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExtBuilder;
46
47type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50 type Builder = DeepSeekExtBuilder;
51
52 const VERIFY_PATH: &'static str = "/user/balance";
53
54 fn build<H>(
55 _: &crate::client::ClientBuilder<
56 Self::Builder,
57 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58 H,
59 >,
60 ) -> http_client::Result<Self> {
61 Ok(Self)
62 }
63}
64
65impl<H> Capabilities<H> for DeepSeekExt {
66 type Completion = Capable<CompletionModel<H>>;
67 type Embeddings = Nothing;
68 type Transcription = Nothing;
69 #[cfg(feature = "image")]
70 type ImageGeneration = Nothing;
71 #[cfg(feature = "audio")]
72 type AudioGeneration = Nothing;
73}
74
75impl DebugExt for DeepSeekExt {}
76
77impl ProviderBuilder for DeepSeekExtBuilder {
78 type Output = DeepSeekExt;
79 type ApiKey = DeepSeekApiKey;
80
81 const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
82}
83
84pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
85pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88 type Input = DeepSeekApiKey;
89
90 fn from_env() -> Self {
92 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93 Self::new(&api_key).unwrap()
94 }
95
96 fn from_val(input: Self::Input) -> Self {
97 Self::new(input).unwrap()
98 }
99}
100
101#[derive(Debug, Deserialize)]
102struct ApiErrorResponse {
103 message: String,
104}
105
106#[derive(Debug, Deserialize)]
107#[serde(untagged)]
108enum ApiResponse<T> {
109 Ok(T),
110 Err(ApiErrorResponse),
111}
112
113impl From<ApiErrorResponse> for CompletionError {
114 fn from(err: ApiErrorResponse) -> Self {
115 CompletionError::ProviderError(err.message)
116 }
117}
118
119#[derive(Clone, Debug, Serialize, Deserialize)]
121pub struct CompletionResponse {
122 pub choices: Vec<Choice>,
124 pub usage: Usage,
125 }
127
128#[derive(Clone, Debug, Serialize, Deserialize, Default)]
129pub struct Usage {
130 pub completion_tokens: u32,
131 pub prompt_tokens: u32,
132 pub prompt_cache_hit_tokens: u32,
133 pub prompt_cache_miss_tokens: u32,
134 pub total_tokens: u32,
135 #[serde(skip_serializing_if = "Option::is_none")]
136 pub completion_tokens_details: Option<CompletionTokensDetails>,
137 #[serde(skip_serializing_if = "Option::is_none")]
138 pub prompt_tokens_details: Option<PromptTokensDetails>,
139}
140
141impl Usage {
142 fn new() -> Self {
143 Self {
144 completion_tokens: 0,
145 prompt_tokens: 0,
146 prompt_cache_hit_tokens: 0,
147 prompt_cache_miss_tokens: 0,
148 total_tokens: 0,
149 completion_tokens_details: None,
150 prompt_tokens_details: None,
151 }
152 }
153}
154
155#[derive(Clone, Debug, Serialize, Deserialize, Default)]
156pub struct CompletionTokensDetails {
157 #[serde(skip_serializing_if = "Option::is_none")]
158 pub reasoning_tokens: Option<u32>,
159}
160
161#[derive(Clone, Debug, Serialize, Deserialize, Default)]
162pub struct PromptTokensDetails {
163 #[serde(skip_serializing_if = "Option::is_none")]
164 pub cached_tokens: Option<u32>,
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
168pub struct Choice {
169 pub index: usize,
170 pub message: Message,
171 pub logprobs: Option<serde_json::Value>,
172 pub finish_reason: String,
173}
174
175#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
176#[serde(tag = "role", rename_all = "lowercase")]
177pub enum Message {
178 System {
179 content: String,
180 #[serde(skip_serializing_if = "Option::is_none")]
181 name: Option<String>,
182 },
183 User {
184 content: String,
185 #[serde(skip_serializing_if = "Option::is_none")]
186 name: Option<String>,
187 },
188 Assistant {
189 content: String,
190 #[serde(skip_serializing_if = "Option::is_none")]
191 name: Option<String>,
192 #[serde(
193 default,
194 deserialize_with = "json_utils::null_or_vec",
195 skip_serializing_if = "Vec::is_empty"
196 )]
197 tool_calls: Vec<ToolCall>,
198 },
199 #[serde(rename = "tool")]
200 ToolResult {
201 tool_call_id: String,
202 content: String,
203 },
204}
205
206impl Message {
207 pub fn system(content: &str) -> Self {
208 Message::System {
209 content: content.to_owned(),
210 name: None,
211 }
212 }
213}
214
215impl From<message::ToolResult> for Message {
216 fn from(tool_result: message::ToolResult) -> Self {
217 let content = match tool_result.content.first() {
218 message::ToolResultContent::Text(text) => text.text,
219 message::ToolResultContent::Image(_) => String::from("[Image]"),
220 };
221
222 Message::ToolResult {
223 tool_call_id: tool_result.id,
224 content,
225 }
226 }
227}
228
229impl From<message::ToolCall> for ToolCall {
230 fn from(tool_call: message::ToolCall) -> Self {
231 Self {
232 id: tool_call.id,
233 index: 0,
235 r#type: ToolType::Function,
236 function: Function {
237 name: tool_call.function.name,
238 arguments: tool_call.function.arguments,
239 },
240 }
241 }
242}
243
244impl TryFrom<message::Message> for Vec<Message> {
245 type Error = message::MessageError;
246
247 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
248 match message {
249 message::Message::User { content } => {
250 let mut messages = vec![];
252
253 let tool_results = content
254 .clone()
255 .into_iter()
256 .filter_map(|content| match content {
257 message::UserContent::ToolResult(tool_result) => {
258 Some(Message::from(tool_result))
259 }
260 _ => None,
261 })
262 .collect::<Vec<_>>();
263
264 messages.extend(tool_results);
265
266 let text_messages = content
268 .into_iter()
269 .filter_map(|content| match content {
270 message::UserContent::Text(text) => Some(Message::User {
271 content: text.text,
272 name: None,
273 }),
274 message::UserContent::Document(Document {
275 data:
276 DocumentSourceKind::Base64(content)
277 | DocumentSourceKind::String(content),
278 ..
279 }) => Some(Message::User {
280 content,
281 name: None,
282 }),
283 _ => None,
284 })
285 .collect::<Vec<_>>();
286 messages.extend(text_messages);
287
288 Ok(messages)
289 }
290 message::Message::Assistant { content, .. } => {
291 let mut messages: Vec<Message> = vec![];
292
293 let text_content = content
295 .clone()
296 .into_iter()
297 .filter_map(|content| match content {
298 message::AssistantContent::Text(text) => Some(Message::Assistant {
299 content: text.text,
300 name: None,
301 tool_calls: vec![],
302 }),
303 _ => None,
304 })
305 .collect::<Vec<_>>();
306
307 messages.extend(text_content);
308
309 let tool_calls = content
311 .clone()
312 .into_iter()
313 .filter_map(|content| match content {
314 message::AssistantContent::ToolCall(tool_call) => {
315 Some(ToolCall::from(tool_call))
316 }
317 _ => None,
318 })
319 .collect::<Vec<_>>();
320
321 if !tool_calls.is_empty() {
323 messages.push(Message::Assistant {
324 content: "".to_string(),
325 name: None,
326 tool_calls,
327 });
328 }
329
330 Ok(messages)
331 }
332 }
333 }
334}
335
336#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
337pub struct ToolCall {
338 pub id: String,
339 pub index: usize,
340 #[serde(default)]
341 pub r#type: ToolType,
342 pub function: Function,
343}
344
345#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
346pub struct Function {
347 pub name: String,
348 #[serde(with = "json_utils::stringified_json")]
349 pub arguments: serde_json::Value,
350}
351
352#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
353#[serde(rename_all = "lowercase")]
354pub enum ToolType {
355 #[default]
356 Function,
357}
358
359#[derive(Clone, Debug, Deserialize, Serialize)]
360pub struct ToolDefinition {
361 pub r#type: String,
362 pub function: completion::ToolDefinition,
363}
364
365impl From<crate::completion::ToolDefinition> for ToolDefinition {
366 fn from(tool: crate::completion::ToolDefinition) -> Self {
367 Self {
368 r#type: "function".into(),
369 function: tool,
370 }
371 }
372}
373
374impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
375 type Error = CompletionError;
376
377 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
378 let choice = response.choices.first().ok_or_else(|| {
379 CompletionError::ResponseError("Response contained no choices".to_owned())
380 })?;
381 let content = match &choice.message {
382 Message::Assistant {
383 content,
384 tool_calls,
385 ..
386 } => {
387 let mut content = if content.trim().is_empty() {
388 vec![]
389 } else {
390 vec![completion::AssistantContent::text(content)]
391 };
392
393 content.extend(
394 tool_calls
395 .iter()
396 .map(|call| {
397 completion::AssistantContent::tool_call(
398 &call.id,
399 &call.function.name,
400 call.function.arguments.clone(),
401 )
402 })
403 .collect::<Vec<_>>(),
404 );
405 Ok(content)
406 }
407 _ => Err(CompletionError::ResponseError(
408 "Response did not contain a valid message or tool call".into(),
409 )),
410 }?;
411
412 let choice = OneOrMany::many(content).map_err(|_| {
413 CompletionError::ResponseError(
414 "Response contained no message or tool call (empty)".to_owned(),
415 )
416 })?;
417
418 let usage = completion::Usage {
419 input_tokens: response.usage.prompt_tokens as u64,
420 output_tokens: response.usage.completion_tokens as u64,
421 total_tokens: response.usage.total_tokens as u64,
422 };
423
424 Ok(completion::CompletionResponse {
425 choice,
426 usage,
427 raw_response: response,
428 })
429 }
430}
431
432#[derive(Debug, Serialize, Deserialize)]
433pub(super) struct DeepseekCompletionRequest {
434 model: String,
435 pub messages: Vec<Message>,
436 #[serde(flatten, skip_serializing_if = "Option::is_none")]
437 temperature: Option<f64>,
438 #[serde(skip_serializing_if = "Vec::is_empty")]
439 tools: Vec<ToolDefinition>,
440 #[serde(flatten, skip_serializing_if = "Option::is_none")]
441 tool_choice: Option<crate::providers::openrouter::ToolChoice>,
442 #[serde(flatten, skip_serializing_if = "Option::is_none")]
443 pub additional_params: Option<serde_json::Value>,
444}
445
446impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
447 type Error = CompletionError;
448
449 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
450 let mut full_history: Vec<Message> = match &req.preamble {
451 Some(preamble) => vec![Message::system(preamble)],
452 None => vec![],
453 };
454
455 if let Some(docs) = req.normalized_documents() {
456 let docs: Vec<Message> = docs.try_into()?;
457 full_history.extend(docs);
458 }
459
460 let chat_history: Vec<Message> = req
461 .chat_history
462 .clone()
463 .into_iter()
464 .map(|message| message.try_into())
465 .collect::<Result<Vec<Vec<Message>>, _>>()?
466 .into_iter()
467 .flatten()
468 .collect();
469
470 full_history.extend(chat_history);
471
472 let tool_choice = req
473 .tool_choice
474 .clone()
475 .map(crate::providers::openrouter::ToolChoice::try_from)
476 .transpose()?;
477
478 Ok(Self {
479 model: model.to_string(),
480 messages: full_history,
481 temperature: req.temperature,
482 tools: req
483 .tools
484 .clone()
485 .into_iter()
486 .map(ToolDefinition::from)
487 .collect::<Vec<_>>(),
488 tool_choice,
489 additional_params: req.additional_params,
490 })
491 }
492}
493
494#[derive(Clone)]
496pub struct CompletionModel<T = reqwest::Client> {
497 pub client: Client<T>,
498 pub model: String,
499}
500
501impl<T> completion::CompletionModel for CompletionModel<T>
502where
503 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
504{
505 type Response = CompletionResponse;
506 type StreamingResponse = StreamingCompletionResponse;
507
508 type Client = Client<T>;
509
510 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
511 Self {
512 client: client.clone(),
513 model: model.into().to_string(),
514 }
515 }
516
517 #[cfg_attr(feature = "worker", worker::send)]
518 async fn completion(
519 &self,
520 completion_request: CompletionRequest,
521 ) -> Result<
522 completion::CompletionResponse<CompletionResponse>,
523 crate::completion::CompletionError,
524 > {
525 let preamble = completion_request.preamble.clone();
526 let request =
527 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
528
529 let span = if tracing::Span::current().is_disabled() {
530 info_span!(
531 target: "rig::completions",
532 "chat",
533 gen_ai.operation.name = "chat",
534 gen_ai.provider.name = "deepseek",
535 gen_ai.request.model = self.model,
536 gen_ai.system_instructions = preamble,
537 gen_ai.response.id = tracing::field::Empty,
538 gen_ai.response.model = tracing::field::Empty,
539 gen_ai.usage.output_tokens = tracing::field::Empty,
540 gen_ai.usage.input_tokens = tracing::field::Empty,
541 gen_ai.input.messages = serde_json::to_string(&request.messages)?,
542 gen_ai.output.messages = tracing::field::Empty,
543 )
544 } else {
545 tracing::Span::current()
546 };
547
548 tracing::debug!("DeepSeek completion request: {request:?}");
549
550 let body = serde_json::to_vec(&request)?;
551 let req = self
552 .client
553 .post("/chat/completions")?
554 .body(body)
555 .map_err(|e| CompletionError::HttpError(e.into()))?;
556
557 async move {
558 let response = self.client.send::<_, Bytes>(req).await?;
559 let status = response.status();
560 let response_body = response.into_body().into_future().await?.to_vec();
561
562 if status.is_success() {
563 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
564 ApiResponse::Ok(response) => {
565 let span = tracing::Span::current();
566 span.record(
567 "gen_ai.output.messages",
568 serde_json::to_string(&response.choices).unwrap(),
569 );
570 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
571 span.record(
572 "gen_ai.usage.output_tokens",
573 response.usage.completion_tokens,
574 );
575 tracing::trace!(
576 target: "rig::completions",
577 "DeepSeek completion output: {}",
578 serde_json::to_string_pretty(&response_body)?
579 );
580 response.try_into()
581 }
582 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
583 }
584 } else {
585 Err(CompletionError::ProviderError(
586 String::from_utf8_lossy(&response_body).to_string(),
587 ))
588 }
589 }
590 .instrument(span)
591 .await
592 }
593
594 #[cfg_attr(feature = "worker", worker::send)]
595 async fn stream(
596 &self,
597 completion_request: CompletionRequest,
598 ) -> Result<
599 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
600 CompletionError,
601 > {
602 let preamble = completion_request.preamble.clone();
603 let mut request =
604 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
605
606 let params = json_utils::merge(
607 request.additional_params.unwrap_or(serde_json::json!({})),
608 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
609 );
610
611 request.additional_params = Some(params);
612
613 let body = serde_json::to_vec(&request)?;
614
615 let req = self
616 .client
617 .post("/chat/completions")?
618 .body(body)
619 .map_err(|e| CompletionError::HttpError(e.into()))?;
620
621 let span = if tracing::Span::current().is_disabled() {
622 info_span!(
623 target: "rig::completions",
624 "chat_streaming",
625 gen_ai.operation.name = "chat_streaming",
626 gen_ai.provider.name = "deepseek",
627 gen_ai.request.model = self.model,
628 gen_ai.system_instructions = preamble,
629 gen_ai.response.id = tracing::field::Empty,
630 gen_ai.response.model = tracing::field::Empty,
631 gen_ai.usage.output_tokens = tracing::field::Empty,
632 gen_ai.usage.input_tokens = tracing::field::Empty,
633 gen_ai.input.messages = serde_json::to_string(&request.messages)?,
634 gen_ai.output.messages = tracing::field::Empty,
635 )
636 } else {
637 tracing::Span::current()
638 };
639
640 tracing::Instrument::instrument(
641 send_compatible_streaming_request(self.client.http_client().clone(), req),
642 span,
643 )
644 .await
645 }
646}
647
648#[derive(Deserialize, Debug)]
649pub struct StreamingDelta {
650 #[serde(default)]
651 content: Option<String>,
652 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
653 tool_calls: Vec<StreamingToolCall>,
654 reasoning_content: Option<String>,
655}
656
657#[derive(Deserialize, Debug)]
658struct StreamingChoice {
659 delta: StreamingDelta,
660}
661
662#[derive(Deserialize, Debug)]
663struct StreamingCompletionChunk {
664 choices: Vec<StreamingChoice>,
665 usage: Option<Usage>,
666}
667
668#[derive(Clone, Deserialize, Serialize, Debug)]
669pub struct StreamingCompletionResponse {
670 pub usage: Usage,
671}
672
673impl GetTokenUsage for StreamingCompletionResponse {
674 fn token_usage(&self) -> Option<crate::completion::Usage> {
675 let mut usage = crate::completion::Usage::new();
676 usage.input_tokens = self.usage.prompt_tokens as u64;
677 usage.output_tokens = self.usage.completion_tokens as u64;
678 usage.total_tokens = self.usage.total_tokens as u64;
679
680 Some(usage)
681 }
682}
683
684pub async fn send_compatible_streaming_request<T>(
685 http_client: T,
686 req: Request<Vec<u8>>,
687) -> Result<
688 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
689 CompletionError,
690>
691where
692 T: HttpClientExt + Clone + 'static,
693{
694 let span = tracing::Span::current();
695 let mut event_source = GenericEventSource::new(http_client, req);
696
697 let stream = stream! {
698 let mut final_usage = Usage::new();
699 let mut text_response = String::new();
700 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
701
702 while let Some(event_result) = event_source.next().await {
703 match event_result {
704 Ok(Event::Open) => {
705 tracing::trace!("SSE connection opened");
706 continue;
707 }
708 Ok(Event::Message(message)) => {
709 if message.data.trim().is_empty() || message.data == "[DONE]" {
710 continue;
711 }
712
713 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
714 let Ok(data) = parsed else {
715 let err = parsed.unwrap_err();
716 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
717 continue;
718 };
719
720 if let Some(choice) = data.choices.first() {
721 let delta = &choice.delta;
722
723 if !delta.tool_calls.is_empty() {
724 for tool_call in &delta.tool_calls {
725 let function = &tool_call.function;
726
727 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
729 && empty_or_none(&function.arguments)
730 {
731 let id = tool_call.id.clone().unwrap_or_default();
732 let name = function.name.clone().unwrap();
733 calls.insert(tool_call.index, (id, name, String::new()));
734 }
735 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
737 && let Some(arguments) = &function.arguments
738 && !arguments.is_empty()
739 {
740 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
741 let combined = format!("{}{}", existing_args, arguments);
742 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
743 } else {
744 tracing::debug!("Partial tool call received but tool call was never started.");
745 }
746 }
747 else {
749 let id = tool_call.id.clone().unwrap_or_default();
750 let name = function.name.clone().unwrap_or_default();
751 let arguments_str = function.arguments.clone().unwrap_or_default();
752
753 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
754 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
755 continue;
756 };
757
758 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
759 id,
760 name,
761 arguments: arguments_json,
762 call_id: None,
763 });
764 }
765 }
766 }
767
768 if let Some(content) = &delta.reasoning_content {
770 yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
771 reasoning: content.to_string(),
772 id: None,
773 signature: None,
774 });
775 }
776
777 if let Some(content) = &delta.content {
778 text_response += content;
779 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
780 }
781 }
782
783 if let Some(usage) = data.usage {
784 final_usage = usage.clone();
785 }
786 }
787 Err(crate::http_client::Error::StreamEnded) => {
788 break;
789 }
790 Err(err) => {
791 tracing::error!(?err, "SSE error");
792 yield Err(CompletionError::ResponseError(err.to_string()));
793 break;
794 }
795 }
796 }
797
798 event_source.close();
799
800 let mut tool_calls = Vec::new();
801 for (index, (id, name, arguments)) in calls {
803 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
804 continue;
805 };
806
807 tool_calls.push(ToolCall {
808 id: id.clone(),
809 index,
810 r#type: ToolType::Function,
811 function: Function {
812 name: name.clone(),
813 arguments: arguments_json.clone()
814 }
815 });
816 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
817 id,
818 name,
819 arguments: arguments_json,
820 call_id: None,
821 });
822 }
823
824 let message = Message::Assistant {
825 content: text_response,
826 name: None,
827 tool_calls
828 };
829
830 span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
831
832 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
833 StreamingCompletionResponse { usage: final_usage.clone() }
834 ));
835 };
836
837 Ok(crate::streaming::StreamingCompletionResponse::stream(
838 Box::pin(stream),
839 ))
840}
841
/// Model name `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";

/// Model name `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
847
#[cfg(test)]
mod tests {

    use super::*;

    /// Returns the text of an assistant message; panics for any other role.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content.as_str(),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Deserializes a full response, reporting the failing JSON path on error.
    fn parse_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => response,
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A bare array of choices deserializes and keeps the assistant text.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_eq!(
            assistant_text(&choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    // A minimal response with only the required usage counters deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    // A full example response, including fields this module does not model.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
        );
    }

    // A choice carrying a tool call deserializes into the expected typed value.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_tool_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_tool_call],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}