1use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22 ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29 OneOrMany,
30 completion::{self, CompletionError, CompletionRequest},
31 json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
37const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
42#[derive(Debug, Default, Clone, Copy)]
43pub struct DeepSeekExt;
44#[derive(Debug, Default, Clone, Copy)]
45pub struct DeepSeekExtBuilder;
46
47type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50 type Builder = DeepSeekExtBuilder;
51
52 const VERIFY_PATH: &'static str = "/user/balance";
53
54 fn build<H>(
55 _: &crate::client::ClientBuilder<
56 Self::Builder,
57 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58 H,
59 >,
60 ) -> http_client::Result<Self> {
61 Ok(Self)
62 }
63}
64
65impl<H> Capabilities<H> for DeepSeekExt {
66 type Completion = Capable<CompletionModel<H>>;
67 type Embeddings = Nothing;
68 type Transcription = Nothing;
69 #[cfg(feature = "image")]
70 type ImageGeneration = Nothing;
71 #[cfg(feature = "audio")]
72 type AudioGeneration = Nothing;
73}
74
75impl DebugExt for DeepSeekExt {}
76
77impl ProviderBuilder for DeepSeekExtBuilder {
78 type Output = DeepSeekExt;
79 type ApiKey = DeepSeekApiKey;
80
81 const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
82}
83
84pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
85pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88 type Input = DeepSeekApiKey;
89
90 fn from_env() -> Self {
92 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93 Self::new(&api_key).unwrap()
94 }
95
96 fn from_val(input: Self::Input) -> Self {
97 Self::new(input).unwrap()
98 }
99}
100
101#[derive(Debug, Deserialize)]
102struct ApiErrorResponse {
103 message: String,
104}
105
106#[derive(Debug, Deserialize)]
107#[serde(untagged)]
108enum ApiResponse<T> {
109 Ok(T),
110 Err(ApiErrorResponse),
111}
112
113impl From<ApiErrorResponse> for CompletionError {
114 fn from(err: ApiErrorResponse) -> Self {
115 CompletionError::ProviderError(err.message)
116 }
117}
118
119#[derive(Clone, Debug, Serialize, Deserialize)]
121pub struct CompletionResponse {
122 pub choices: Vec<Choice>,
124 pub usage: Usage,
125 }
127
128#[derive(Clone, Debug, Serialize, Deserialize, Default)]
129pub struct Usage {
130 pub completion_tokens: u32,
131 pub prompt_tokens: u32,
132 pub prompt_cache_hit_tokens: u32,
133 pub prompt_cache_miss_tokens: u32,
134 pub total_tokens: u32,
135 #[serde(skip_serializing_if = "Option::is_none")]
136 pub completion_tokens_details: Option<CompletionTokensDetails>,
137 #[serde(skip_serializing_if = "Option::is_none")]
138 pub prompt_tokens_details: Option<PromptTokensDetails>,
139}
140
141impl Usage {
142 fn new() -> Self {
143 Self {
144 completion_tokens: 0,
145 prompt_tokens: 0,
146 prompt_cache_hit_tokens: 0,
147 prompt_cache_miss_tokens: 0,
148 total_tokens: 0,
149 completion_tokens_details: None,
150 prompt_tokens_details: None,
151 }
152 }
153}
154
155#[derive(Clone, Debug, Serialize, Deserialize, Default)]
156pub struct CompletionTokensDetails {
157 #[serde(skip_serializing_if = "Option::is_none")]
158 pub reasoning_tokens: Option<u32>,
159}
160
161#[derive(Clone, Debug, Serialize, Deserialize, Default)]
162pub struct PromptTokensDetails {
163 #[serde(skip_serializing_if = "Option::is_none")]
164 pub cached_tokens: Option<u32>,
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
168pub struct Choice {
169 pub index: usize,
170 pub message: Message,
171 pub logprobs: Option<serde_json::Value>,
172 pub finish_reason: String,
173}
174
175#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
176#[serde(tag = "role", rename_all = "lowercase")]
177pub enum Message {
178 System {
179 content: String,
180 #[serde(skip_serializing_if = "Option::is_none")]
181 name: Option<String>,
182 },
183 User {
184 content: String,
185 #[serde(skip_serializing_if = "Option::is_none")]
186 name: Option<String>,
187 },
188 Assistant {
189 content: String,
190 #[serde(skip_serializing_if = "Option::is_none")]
191 name: Option<String>,
192 #[serde(
193 default,
194 deserialize_with = "json_utils::null_or_vec",
195 skip_serializing_if = "Vec::is_empty"
196 )]
197 tool_calls: Vec<ToolCall>,
198 },
199 #[serde(rename = "tool")]
200 ToolResult {
201 tool_call_id: String,
202 content: String,
203 },
204}
205
206impl Message {
207 pub fn system(content: &str) -> Self {
208 Message::System {
209 content: content.to_owned(),
210 name: None,
211 }
212 }
213}
214
215impl From<message::ToolResult> for Message {
216 fn from(tool_result: message::ToolResult) -> Self {
217 let content = match tool_result.content.first() {
218 message::ToolResultContent::Text(text) => text.text,
219 message::ToolResultContent::Image(_) => String::from("[Image]"),
220 };
221
222 Message::ToolResult {
223 tool_call_id: tool_result.id,
224 content,
225 }
226 }
227}
228
229impl From<message::ToolCall> for ToolCall {
230 fn from(tool_call: message::ToolCall) -> Self {
231 Self {
232 id: tool_call.id,
233 index: 0,
235 r#type: ToolType::Function,
236 function: Function {
237 name: tool_call.function.name,
238 arguments: tool_call.function.arguments,
239 },
240 }
241 }
242}
243
244impl TryFrom<message::Message> for Vec<Message> {
245 type Error = message::MessageError;
246
247 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
248 match message {
249 message::Message::User { content } => {
250 let mut messages = vec![];
252
253 let tool_results = content
254 .clone()
255 .into_iter()
256 .filter_map(|content| match content {
257 message::UserContent::ToolResult(tool_result) => {
258 Some(Message::from(tool_result))
259 }
260 _ => None,
261 })
262 .collect::<Vec<_>>();
263
264 messages.extend(tool_results);
265
266 let text_messages = content
268 .into_iter()
269 .filter_map(|content| match content {
270 message::UserContent::Text(text) => Some(Message::User {
271 content: text.text,
272 name: None,
273 }),
274 message::UserContent::Document(Document {
275 data:
276 DocumentSourceKind::Base64(content)
277 | DocumentSourceKind::String(content),
278 ..
279 }) => Some(Message::User {
280 content,
281 name: None,
282 }),
283 _ => None,
284 })
285 .collect::<Vec<_>>();
286 messages.extend(text_messages);
287
288 Ok(messages)
289 }
290 message::Message::Assistant { content, .. } => {
291 let mut messages: Vec<Message> = vec![];
292
293 let text_content = content
295 .clone()
296 .into_iter()
297 .filter_map(|content| match content {
298 message::AssistantContent::Text(text) => Some(Message::Assistant {
299 content: text.text,
300 name: None,
301 tool_calls: vec![],
302 }),
303 _ => None,
304 })
305 .collect::<Vec<_>>();
306
307 messages.extend(text_content);
308
309 let tool_calls = content
311 .clone()
312 .into_iter()
313 .filter_map(|content| match content {
314 message::AssistantContent::ToolCall(tool_call) => {
315 Some(ToolCall::from(tool_call))
316 }
317 _ => None,
318 })
319 .collect::<Vec<_>>();
320
321 if !tool_calls.is_empty() {
323 messages.push(Message::Assistant {
324 content: "".to_string(),
325 name: None,
326 tool_calls,
327 });
328 }
329
330 Ok(messages)
331 }
332 }
333 }
334}
335
336#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
337pub struct ToolCall {
338 pub id: String,
339 pub index: usize,
340 #[serde(default)]
341 pub r#type: ToolType,
342 pub function: Function,
343}
344
345#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
346pub struct Function {
347 pub name: String,
348 #[serde(with = "json_utils::stringified_json")]
349 pub arguments: serde_json::Value,
350}
351
352#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
353#[serde(rename_all = "lowercase")]
354pub enum ToolType {
355 #[default]
356 Function,
357}
358
359#[derive(Clone, Debug, Deserialize, Serialize)]
360pub struct ToolDefinition {
361 pub r#type: String,
362 pub function: completion::ToolDefinition,
363}
364
365impl From<crate::completion::ToolDefinition> for ToolDefinition {
366 fn from(tool: crate::completion::ToolDefinition) -> Self {
367 Self {
368 r#type: "function".into(),
369 function: tool,
370 }
371 }
372}
373
374impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
375 type Error = CompletionError;
376
377 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
378 let choice = response.choices.first().ok_or_else(|| {
379 CompletionError::ResponseError("Response contained no choices".to_owned())
380 })?;
381 let content = match &choice.message {
382 Message::Assistant {
383 content,
384 tool_calls,
385 ..
386 } => {
387 let mut content = if content.trim().is_empty() {
388 vec![]
389 } else {
390 vec![completion::AssistantContent::text(content)]
391 };
392
393 content.extend(
394 tool_calls
395 .iter()
396 .map(|call| {
397 completion::AssistantContent::tool_call(
398 &call.id,
399 &call.function.name,
400 call.function.arguments.clone(),
401 )
402 })
403 .collect::<Vec<_>>(),
404 );
405 Ok(content)
406 }
407 _ => Err(CompletionError::ResponseError(
408 "Response did not contain a valid message or tool call".into(),
409 )),
410 }?;
411
412 let choice = OneOrMany::many(content).map_err(|_| {
413 CompletionError::ResponseError(
414 "Response contained no message or tool call (empty)".to_owned(),
415 )
416 })?;
417
418 let usage = completion::Usage {
419 input_tokens: response.usage.prompt_tokens as u64,
420 output_tokens: response.usage.completion_tokens as u64,
421 total_tokens: response.usage.total_tokens as u64,
422 };
423
424 Ok(completion::CompletionResponse {
425 choice,
426 usage,
427 raw_response: response,
428 })
429 }
430}
431
432#[derive(Debug, Serialize, Deserialize)]
433pub(super) struct DeepseekCompletionRequest {
434 model: String,
435 pub messages: Vec<Message>,
436 #[serde(skip_serializing_if = "Option::is_none")]
437 temperature: Option<f64>,
438 #[serde(skip_serializing_if = "Vec::is_empty")]
439 tools: Vec<ToolDefinition>,
440 #[serde(skip_serializing_if = "Option::is_none")]
441 tool_choice: Option<crate::providers::openrouter::ToolChoice>,
442 #[serde(flatten, skip_serializing_if = "Option::is_none")]
443 pub additional_params: Option<serde_json::Value>,
444}
445
446impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
447 type Error = CompletionError;
448
449 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
450 let mut full_history: Vec<Message> = match &req.preamble {
451 Some(preamble) => vec![Message::system(preamble)],
452 None => vec![],
453 };
454
455 if let Some(docs) = req.normalized_documents() {
456 let docs: Vec<Message> = docs.try_into()?;
457 full_history.extend(docs);
458 }
459
460 let chat_history: Vec<Message> = req
461 .chat_history
462 .clone()
463 .into_iter()
464 .map(|message| message.try_into())
465 .collect::<Result<Vec<Vec<Message>>, _>>()?
466 .into_iter()
467 .flatten()
468 .collect();
469
470 full_history.extend(chat_history);
471
472 let tool_choice = req
473 .tool_choice
474 .clone()
475 .map(crate::providers::openrouter::ToolChoice::try_from)
476 .transpose()?;
477
478 Ok(Self {
479 model: model.to_string(),
480 messages: full_history,
481 temperature: req.temperature,
482 tools: req
483 .tools
484 .clone()
485 .into_iter()
486 .map(ToolDefinition::from)
487 .collect::<Vec<_>>(),
488 tool_choice,
489 additional_params: req.additional_params,
490 })
491 }
492}
493
494#[derive(Clone)]
496pub struct CompletionModel<T = reqwest::Client> {
497 pub client: Client<T>,
498 pub model: String,
499}
500
501impl<T> completion::CompletionModel for CompletionModel<T>
502where
503 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
504{
505 type Response = CompletionResponse;
506 type StreamingResponse = StreamingCompletionResponse;
507
508 type Client = Client<T>;
509
510 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
511 Self {
512 client: client.clone(),
513 model: model.into().to_string(),
514 }
515 }
516
517 async fn completion(
518 &self,
519 completion_request: CompletionRequest,
520 ) -> Result<
521 completion::CompletionResponse<CompletionResponse>,
522 crate::completion::CompletionError,
523 > {
524 let span = if tracing::Span::current().is_disabled() {
525 info_span!(
526 target: "rig::completions",
527 "chat",
528 gen_ai.operation.name = "chat",
529 gen_ai.provider.name = "deepseek",
530 gen_ai.request.model = self.model,
531 gen_ai.system_instructions = tracing::field::Empty,
532 gen_ai.response.id = tracing::field::Empty,
533 gen_ai.response.model = tracing::field::Empty,
534 gen_ai.usage.output_tokens = tracing::field::Empty,
535 gen_ai.usage.input_tokens = tracing::field::Empty,
536 )
537 } else {
538 tracing::Span::current()
539 };
540
541 span.record("gen_ai.system_instructions", &completion_request.preamble);
542
543 let request =
544 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
545
546 if enabled!(Level::TRACE) {
547 tracing::trace!(target: "rig::completions",
548 "DeepSeek completion request: {}",
549 serde_json::to_string_pretty(&request)?
550 );
551 }
552
553 let body = serde_json::to_vec(&request)?;
554 let req = self
555 .client
556 .post("/chat/completions")?
557 .body(body)
558 .map_err(|e| CompletionError::HttpError(e.into()))?;
559
560 async move {
561 let response = self.client.send::<_, Bytes>(req).await?;
562 let status = response.status();
563 let response_body = response.into_body().into_future().await?.to_vec();
564
565 if status.is_success() {
566 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
567 ApiResponse::Ok(response) => {
568 let span = tracing::Span::current();
569 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
570 span.record(
571 "gen_ai.usage.output_tokens",
572 response.usage.completion_tokens,
573 );
574 if enabled!(Level::TRACE) {
575 tracing::trace!(target: "rig::completions",
576 "DeepSeek completion response: {}",
577 serde_json::to_string_pretty(&response)?
578 );
579 }
580 response.try_into()
581 }
582 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
583 }
584 } else {
585 Err(CompletionError::ProviderError(
586 String::from_utf8_lossy(&response_body).to_string(),
587 ))
588 }
589 }
590 .instrument(span)
591 .await
592 }
593
594 async fn stream(
595 &self,
596 completion_request: CompletionRequest,
597 ) -> Result<
598 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
599 CompletionError,
600 > {
601 let preamble = completion_request.preamble.clone();
602 let mut request =
603 DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
604
605 let params = json_utils::merge(
606 request.additional_params.unwrap_or(serde_json::json!({})),
607 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
608 );
609
610 request.additional_params = Some(params);
611
612 if enabled!(Level::TRACE) {
613 tracing::trace!(target: "rig::completions",
614 "DeepSeek streaming completion request: {}",
615 serde_json::to_string_pretty(&request)?
616 );
617 }
618
619 let body = serde_json::to_vec(&request)?;
620
621 let req = self
622 .client
623 .post("/chat/completions")?
624 .body(body)
625 .map_err(|e| CompletionError::HttpError(e.into()))?;
626
627 let span = if tracing::Span::current().is_disabled() {
628 info_span!(
629 target: "rig::completions",
630 "chat_streaming",
631 gen_ai.operation.name = "chat_streaming",
632 gen_ai.provider.name = "deepseek",
633 gen_ai.request.model = self.model,
634 gen_ai.system_instructions = preamble,
635 gen_ai.response.id = tracing::field::Empty,
636 gen_ai.response.model = tracing::field::Empty,
637 gen_ai.usage.output_tokens = tracing::field::Empty,
638 gen_ai.usage.input_tokens = tracing::field::Empty,
639 )
640 } else {
641 tracing::Span::current()
642 };
643
644 tracing::Instrument::instrument(
645 send_compatible_streaming_request(self.client.clone(), req),
646 span,
647 )
648 .await
649 }
650}
651
652#[derive(Deserialize, Debug)]
653pub struct StreamingDelta {
654 #[serde(default)]
655 content: Option<String>,
656 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
657 tool_calls: Vec<StreamingToolCall>,
658 reasoning_content: Option<String>,
659}
660
661#[derive(Deserialize, Debug)]
662struct StreamingChoice {
663 delta: StreamingDelta,
664}
665
666#[derive(Deserialize, Debug)]
667struct StreamingCompletionChunk {
668 choices: Vec<StreamingChoice>,
669 usage: Option<Usage>,
670}
671
672#[derive(Clone, Deserialize, Serialize, Debug)]
673pub struct StreamingCompletionResponse {
674 pub usage: Usage,
675}
676
677impl GetTokenUsage for StreamingCompletionResponse {
678 fn token_usage(&self) -> Option<crate::completion::Usage> {
679 let mut usage = crate::completion::Usage::new();
680 usage.input_tokens = self.usage.prompt_tokens as u64;
681 usage.output_tokens = self.usage.completion_tokens as u64;
682 usage.total_tokens = self.usage.total_tokens as u64;
683
684 Some(usage)
685 }
686}
687
688pub async fn send_compatible_streaming_request<T>(
689 http_client: T,
690 req: Request<Vec<u8>>,
691) -> Result<
692 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
693 CompletionError,
694>
695where
696 T: HttpClientExt + Clone + 'static,
697{
698 let span = tracing::Span::current();
699 let mut event_source = GenericEventSource::new(http_client, req);
700
701 let stream = stream! {
702 let mut final_usage = Usage::new();
703 let mut text_response = String::new();
704 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
705
706 while let Some(event_result) = event_source.next().await {
707 match event_result {
708 Ok(Event::Open) => {
709 tracing::trace!("SSE connection opened");
710 continue;
711 }
712 Ok(Event::Message(message)) => {
713 if message.data.trim().is_empty() || message.data == "[DONE]" {
714 continue;
715 }
716
717 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
718 let Ok(data) = parsed else {
719 let err = parsed.unwrap_err();
720 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
721 continue;
722 };
723
724 if let Some(choice) = data.choices.first() {
725 let delta = &choice.delta;
726
727 if !delta.tool_calls.is_empty() {
728 for tool_call in &delta.tool_calls {
729 let function = &tool_call.function;
730
731 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
733 && empty_or_none(&function.arguments)
734 {
735 let id = tool_call.id.clone().unwrap_or_default();
736 let name = function.name.clone().unwrap();
737 calls.insert(tool_call.index, (id, name, String::new()));
738 }
739 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
741 && let Some(arguments) = &function.arguments
742 && !arguments.is_empty()
743 {
744 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
745 let combined = format!("{}{}", existing_args, arguments);
746 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
747 } else {
748 tracing::debug!("Partial tool call received but tool call was never started.");
749 }
750 }
751 else {
753 let id = tool_call.id.clone().unwrap_or_default();
754 let name = function.name.clone().unwrap_or_default();
755 let arguments_str = function.arguments.clone().unwrap_or_default();
756
757 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
758 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
759 continue;
760 };
761
762 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
763 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
764 ));
765 }
766 }
767 }
768
769 if let Some(content) = &delta.reasoning_content {
771 yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
772 id: None,
773 reasoning: content.to_string()
774 });
775 }
776
777 if let Some(content) = &delta.content {
778 text_response += content;
779 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
780 }
781 }
782
783 if let Some(usage) = data.usage {
784 final_usage = usage.clone();
785 }
786 }
787 Err(crate::http_client::Error::StreamEnded) => {
788 break;
789 }
790 Err(err) => {
791 tracing::error!(?err, "SSE error");
792 yield Err(CompletionError::ResponseError(err.to_string()));
793 break;
794 }
795 }
796 }
797
798 event_source.close();
799
800 let mut tool_calls = Vec::new();
801 for (index, (id, name, arguments)) in calls {
803 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
804 continue;
805 };
806
807 tool_calls.push(ToolCall {
808 id: id.clone(),
809 index,
810 r#type: ToolType::Function,
811 function: Function {
812 name: name.clone(),
813 arguments: arguments_json.clone()
814 }
815 });
816 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
817 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
818 ));
819 }
820
821 let message = Message::Assistant {
822 content: text_response,
823 name: None,
824 tool_calls
825 };
826
827 span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
828
829 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
830 StreamingCompletionResponse { usage: final_usage.clone() }
831 ));
832 };
833
834 Ok(crate::streaming::StreamingCompletionResponse::stream(
835 Box::pin(stream),
836 ))
837}
838
/// Model identifier `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
844
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of an assistant message, failing the test for any other role.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content.as_str(),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Deserializes a `CompletionResponse`, reporting the failing JSON path on error.
    fn parse_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        serde_path_to_error::deserialize(jd)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_eq!(assistant_text(&choices[0].message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
        {
            "finish_reason": "tool_calls",
            "index": 0,
            "logprobs": null,
            "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                    {
                        "function": {
                            "arguments": "{\"x\":2,\"y\":5}",
                            "name": "subtract"
                        },
                        "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                        "index": 0,
                        "type": "function"
                    }
                ]
            }
        }
        "#;

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
        };
        let expected_choice = Choice {
            index: 0,
            logprobs: None,
            finish_reason: "tool_calls".to_string(),
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
        };

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
        assert_eq!(choice, expected_choice);
    }
}