1use futures::StreamExt;
13use std::collections::HashMap;
14
15use crate::client::{
16 ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError,
17};
18use crate::completion::GetTokenUsage;
19use crate::json_utils::merge;
20use crate::message::Document;
21use crate::{
22 OneOrMany,
23 completion::{self, CompletionError, CompletionRequest},
24 impl_conversion_traits, json_utils, message,
25};
26use reqwest::Client as HttpClient;
27use serde::{Deserialize, Serialize};
28use serde_json::json;
29
30use super::openai::StreamingToolCall;
31
32const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
36
37pub struct ClientBuilder<'a> {
38 api_key: &'a str,
39 base_url: &'a str,
40 http_client: Option<reqwest::Client>,
41}
42
43impl<'a> ClientBuilder<'a> {
44 pub fn new(api_key: &'a str) -> Self {
45 Self {
46 api_key,
47 base_url: DEEPSEEK_API_BASE_URL,
48 http_client: None,
49 }
50 }
51
52 pub fn base_url(mut self, base_url: &'a str) -> Self {
53 self.base_url = base_url;
54 self
55 }
56
57 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
58 self.http_client = Some(client);
59 self
60 }
61
62 pub fn build(self) -> Result<Client, ClientBuilderError> {
63 let http_client = if let Some(http_client) = self.http_client {
64 http_client
65 } else {
66 reqwest::Client::builder().build()?
67 };
68
69 Ok(Client {
70 base_url: self.base_url.to_string(),
71 api_key: self.api_key.to_string(),
72 http_client,
73 })
74 }
75}
76
77#[derive(Clone)]
78pub struct Client {
79 pub base_url: String,
80 api_key: String,
81 http_client: HttpClient,
82}
83
84impl std::fmt::Debug for Client {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct("Client")
87 .field("base_url", &self.base_url)
88 .field("http_client", &self.http_client)
89 .field("api_key", &"<REDACTED>")
90 .finish()
91 }
92}
93
94impl Client {
95 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
106 ClientBuilder::new(api_key)
107 }
108
109 pub fn new(api_key: &str) -> Self {
114 Self::builder(api_key)
115 .build()
116 .expect("DeepSeek client should build")
117 }
118
119 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
120 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
121 self.http_client.post(url).bearer_auth(&self.api_key)
122 }
123
124 pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
125 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
126 self.http_client.get(url).bearer_auth(&self.api_key)
127 }
128}
129
130impl ProviderClient for Client {
131 fn from_env() -> Self {
133 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
134 Self::new(&api_key)
135 }
136
137 fn from_val(input: crate::client::ProviderValue) -> Self {
138 let crate::client::ProviderValue::Simple(api_key) = input else {
139 panic!("Incorrect provider value type")
140 };
141 Self::new(&api_key)
142 }
143}
144
145impl CompletionClient for Client {
146 type CompletionModel = CompletionModel;
147
148 fn completion_model(&self, model_name: &str) -> CompletionModel {
150 CompletionModel {
151 client: self.clone(),
152 model: model_name.to_string(),
153 }
154 }
155}
156
157impl VerifyClient for Client {
158 #[cfg_attr(feature = "worker", worker::send)]
159 async fn verify(&self) -> Result<(), VerifyError> {
160 let response = self.get("/user/balance").send().await?;
161 match response.status() {
162 reqwest::StatusCode::OK => Ok(()),
163 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
164 reqwest::StatusCode::INTERNAL_SERVER_ERROR
165 | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
166 Err(VerifyError::ProviderError(response.text().await?))
167 }
168 _ => {
169 response.error_for_status()?;
170 Ok(())
171 }
172 }
173 }
174}
175
176impl_conversion_traits!(
177 AsEmbeddings,
178 AsTranscription,
179 AsImageGeneration,
180 AsAudioGeneration for Client
181);
182
183#[derive(Debug, Deserialize)]
184struct ApiErrorResponse {
185 message: String,
186}
187
188#[derive(Debug, Deserialize)]
189#[serde(untagged)]
190enum ApiResponse<T> {
191 Ok(T),
192 Err(ApiErrorResponse),
193}
194
195impl From<ApiErrorResponse> for CompletionError {
196 fn from(err: ApiErrorResponse) -> Self {
197 CompletionError::ProviderError(err.message)
198 }
199}
200
201#[derive(Clone, Debug, Serialize, Deserialize)]
203pub struct CompletionResponse {
204 pub choices: Vec<Choice>,
206 pub usage: Usage,
207 }
209
210#[derive(Clone, Debug, Serialize, Deserialize, Default)]
211pub struct Usage {
212 pub completion_tokens: u32,
213 pub prompt_tokens: u32,
214 pub prompt_cache_hit_tokens: u32,
215 pub prompt_cache_miss_tokens: u32,
216 pub total_tokens: u32,
217 #[serde(skip_serializing_if = "Option::is_none")]
218 pub completion_tokens_details: Option<CompletionTokensDetails>,
219 #[serde(skip_serializing_if = "Option::is_none")]
220 pub prompt_tokens_details: Option<PromptTokensDetails>,
221}
222
223impl Usage {
224 fn new() -> Self {
225 Self {
226 completion_tokens: 0,
227 prompt_tokens: 0,
228 prompt_cache_hit_tokens: 0,
229 prompt_cache_miss_tokens: 0,
230 total_tokens: 0,
231 completion_tokens_details: None,
232 prompt_tokens_details: None,
233 }
234 }
235}
236
237#[derive(Clone, Debug, Serialize, Deserialize, Default)]
238pub struct CompletionTokensDetails {
239 #[serde(skip_serializing_if = "Option::is_none")]
240 pub reasoning_tokens: Option<u32>,
241}
242
243#[derive(Clone, Debug, Serialize, Deserialize, Default)]
244pub struct PromptTokensDetails {
245 #[serde(skip_serializing_if = "Option::is_none")]
246 pub cached_tokens: Option<u32>,
247}
248
249#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
250pub struct Choice {
251 pub index: usize,
252 pub message: Message,
253 pub logprobs: Option<serde_json::Value>,
254 pub finish_reason: String,
255}
256
257#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
258#[serde(tag = "role", rename_all = "lowercase")]
259pub enum Message {
260 System {
261 content: String,
262 #[serde(skip_serializing_if = "Option::is_none")]
263 name: Option<String>,
264 },
265 User {
266 content: String,
267 #[serde(skip_serializing_if = "Option::is_none")]
268 name: Option<String>,
269 },
270 Assistant {
271 content: String,
272 #[serde(skip_serializing_if = "Option::is_none")]
273 name: Option<String>,
274 #[serde(
275 default,
276 deserialize_with = "json_utils::null_or_vec",
277 skip_serializing_if = "Vec::is_empty"
278 )]
279 tool_calls: Vec<ToolCall>,
280 },
281 #[serde(rename = "tool")]
282 ToolResult {
283 tool_call_id: String,
284 content: String,
285 },
286}
287
288impl Message {
289 pub fn system(content: &str) -> Self {
290 Message::System {
291 content: content.to_owned(),
292 name: None,
293 }
294 }
295}
296
297impl From<message::ToolResult> for Message {
298 fn from(tool_result: message::ToolResult) -> Self {
299 let content = match tool_result.content.first() {
300 message::ToolResultContent::Text(text) => text.text,
301 message::ToolResultContent::Image(_) => String::from("[Image]"),
302 };
303
304 Message::ToolResult {
305 tool_call_id: tool_result.id,
306 content,
307 }
308 }
309}
310
311impl From<message::ToolCall> for ToolCall {
312 fn from(tool_call: message::ToolCall) -> Self {
313 Self {
314 id: tool_call.id,
315 index: 0,
317 r#type: ToolType::Function,
318 function: Function {
319 name: tool_call.function.name,
320 arguments: tool_call.function.arguments,
321 },
322 }
323 }
324}
325
326impl TryFrom<message::Message> for Vec<Message> {
327 type Error = message::MessageError;
328
329 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
330 match message {
331 message::Message::User { content } => {
332 let mut messages = vec![];
334
335 let tool_results = content
336 .clone()
337 .into_iter()
338 .filter_map(|content| match content {
339 message::UserContent::ToolResult(tool_result) => {
340 Some(Message::from(tool_result))
341 }
342 _ => None,
343 })
344 .collect::<Vec<_>>();
345
346 messages.extend(tool_results);
347
348 let text_messages = content
350 .into_iter()
351 .filter_map(|content| match content {
352 message::UserContent::Text(text) => Some(Message::User {
353 content: text.text,
354 name: None,
355 }),
356 message::UserContent::Document(Document { data, .. }) => {
357 Some(Message::User {
358 content: data,
359 name: None,
360 })
361 }
362 _ => None,
363 })
364 .collect::<Vec<_>>();
365 messages.extend(text_messages);
366
367 Ok(messages)
368 }
369 message::Message::Assistant { content, .. } => {
370 let mut messages: Vec<Message> = vec![];
371
372 let tool_calls = content
374 .clone()
375 .into_iter()
376 .filter_map(|content| match content {
377 message::AssistantContent::ToolCall(tool_call) => {
378 Some(ToolCall::from(tool_call))
379 }
380 _ => None,
381 })
382 .collect::<Vec<_>>();
383
384 if !tool_calls.is_empty() {
386 messages.push(Message::Assistant {
387 content: "".to_string(),
388 name: None,
389 tool_calls,
390 });
391 }
392
393 let text_content = content
395 .into_iter()
396 .filter_map(|content| match content {
397 message::AssistantContent::Text(text) => Some(Message::Assistant {
398 content: text.text,
399 name: None,
400 tool_calls: vec![],
401 }),
402 _ => None,
403 })
404 .collect::<Vec<_>>();
405
406 messages.extend(text_content);
407
408 Ok(messages)
409 }
410 }
411 }
412}
413
414#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
415pub struct ToolCall {
416 pub id: String,
417 pub index: usize,
418 #[serde(default)]
419 pub r#type: ToolType,
420 pub function: Function,
421}
422
423#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
424pub struct Function {
425 pub name: String,
426 #[serde(with = "json_utils::stringified_json")]
427 pub arguments: serde_json::Value,
428}
429
430#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
431#[serde(rename_all = "lowercase")]
432pub enum ToolType {
433 #[default]
434 Function,
435}
436
437#[derive(Clone, Debug, Deserialize, Serialize)]
438pub struct ToolDefinition {
439 pub r#type: String,
440 pub function: completion::ToolDefinition,
441}
442
443impl From<crate::completion::ToolDefinition> for ToolDefinition {
444 fn from(tool: crate::completion::ToolDefinition) -> Self {
445 Self {
446 r#type: "function".into(),
447 function: tool,
448 }
449 }
450}
451
452impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
453 type Error = CompletionError;
454
455 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
456 let choice = response.choices.first().ok_or_else(|| {
457 CompletionError::ResponseError("Response contained no choices".to_owned())
458 })?;
459 let content = match &choice.message {
460 Message::Assistant {
461 content,
462 tool_calls,
463 ..
464 } => {
465 let mut content = if content.trim().is_empty() {
466 vec![]
467 } else {
468 vec![completion::AssistantContent::text(content)]
469 };
470
471 content.extend(
472 tool_calls
473 .iter()
474 .map(|call| {
475 completion::AssistantContent::tool_call(
476 &call.id,
477 &call.function.name,
478 call.function.arguments.clone(),
479 )
480 })
481 .collect::<Vec<_>>(),
482 );
483 Ok(content)
484 }
485 _ => Err(CompletionError::ResponseError(
486 "Response did not contain a valid message or tool call".into(),
487 )),
488 }?;
489
490 let choice = OneOrMany::many(content).map_err(|_| {
491 CompletionError::ResponseError(
492 "Response contained no message or tool call (empty)".to_owned(),
493 )
494 })?;
495
496 let usage = completion::Usage {
497 input_tokens: response.usage.prompt_tokens as u64,
498 output_tokens: response.usage.completion_tokens as u64,
499 total_tokens: response.usage.total_tokens as u64,
500 };
501
502 Ok(completion::CompletionResponse {
503 choice,
504 usage,
505 raw_response: response,
506 })
507 }
508}
509
510#[derive(Clone)]
512pub struct CompletionModel {
513 pub client: Client,
514 pub model: String,
515}
516
517impl CompletionModel {
518 fn create_completion_request(
519 &self,
520 completion_request: CompletionRequest,
521 ) -> Result<serde_json::Value, CompletionError> {
522 let mut partial_history = vec![];
524
525 if let Some(docs) = completion_request.normalized_documents() {
526 partial_history.push(docs);
527 }
528
529 partial_history.extend(completion_request.chat_history);
530
531 let mut full_history: Vec<Message> = completion_request
533 .preamble
534 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
535
536 full_history.extend(
538 partial_history
539 .into_iter()
540 .map(message::Message::try_into)
541 .collect::<Result<Vec<Vec<Message>>, _>>()?
542 .into_iter()
543 .flatten()
544 .collect::<Vec<_>>(),
545 );
546
547 let request = if completion_request.tools.is_empty() {
548 json!({
549 "model": self.model,
550 "messages": full_history,
551 "temperature": completion_request.temperature,
552 })
553 } else {
554 json!({
555 "model": self.model,
556 "messages": full_history,
557 "temperature": completion_request.temperature,
558 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
559 "tool_choice": "auto",
560 })
561 };
562
563 let request = if let Some(params) = completion_request.additional_params {
564 json_utils::merge(request, params)
565 } else {
566 request
567 };
568
569 Ok(request)
570 }
571}
572
573impl completion::CompletionModel for CompletionModel {
574 type Response = CompletionResponse;
575 type StreamingResponse = StreamingCompletionResponse;
576
577 #[cfg_attr(feature = "worker", worker::send)]
578 async fn completion(
579 &self,
580 completion_request: CompletionRequest,
581 ) -> Result<
582 completion::CompletionResponse<CompletionResponse>,
583 crate::completion::CompletionError,
584 > {
585 let request = self.create_completion_request(completion_request)?;
586
587 tracing::debug!("DeepSeek completion request: {request:?}");
588
589 let response = self
590 .client
591 .post("/chat/completions")
592 .json(&request)
593 .send()
594 .await?;
595
596 if response.status().is_success() {
597 let t = response.text().await?;
598 tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
599
600 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
601 ApiResponse::Ok(response) => response.try_into(),
602 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
603 }
604 } else {
605 Err(CompletionError::ProviderError(response.text().await?))
606 }
607 }
608
609 #[cfg_attr(feature = "worker", worker::send)]
610 async fn stream(
611 &self,
612 completion_request: CompletionRequest,
613 ) -> Result<
614 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
615 CompletionError,
616 > {
617 let mut request = self.create_completion_request(completion_request)?;
618
619 request = merge(
620 request,
621 json!({"stream": true, "stream_options": {"include_usage": true}}),
622 );
623
624 let builder = self.client.post("/chat/completions").json(&request);
625 send_compatible_streaming_request(builder).await
626 }
627}
628
629#[derive(Deserialize, Debug)]
630pub struct StreamingDelta {
631 #[serde(default)]
632 content: Option<String>,
633 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
634 tool_calls: Vec<StreamingToolCall>,
635 reasoning_content: Option<String>,
636}
637
638#[derive(Deserialize, Debug)]
639struct StreamingChoice {
640 delta: StreamingDelta,
641}
642
643#[derive(Deserialize, Debug)]
644struct StreamingCompletionChunk {
645 choices: Vec<StreamingChoice>,
646 usage: Option<Usage>,
647}
648
649#[derive(Clone, Deserialize, Serialize, Debug)]
650pub struct StreamingCompletionResponse {
651 pub usage: Usage,
652}
653
654impl GetTokenUsage for StreamingCompletionResponse {
655 fn token_usage(&self) -> Option<crate::completion::Usage> {
656 let mut usage = crate::completion::Usage::new();
657 usage.input_tokens = self.usage.prompt_tokens as u64;
658 usage.output_tokens = self.usage.completion_tokens as u64;
659 usage.total_tokens = self.usage.total_tokens as u64;
660
661 Some(usage)
662 }
663}
664
665pub async fn send_compatible_streaming_request(
666 request_builder: reqwest::RequestBuilder,
667) -> Result<
668 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
669 CompletionError,
670> {
671 let response = request_builder.send().await?;
672
673 if !response.status().is_success() {
674 return Err(CompletionError::ProviderError(format!(
675 "{}: {}",
676 response.status(),
677 response.text().await?
678 )));
679 }
680
681 let inner = Box::pin(async_stream::stream! {
683 let mut stream = response.bytes_stream();
684
685 let mut final_usage = Usage::new();
686 let mut partial_data = None;
687 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
688
689 while let Some(chunk_result) = stream.next().await {
690 let chunk = match chunk_result {
691 Ok(c) => c,
692 Err(e) => {
693 yield Err(CompletionError::from(e));
694 break;
695 }
696 };
697
698 let text = match String::from_utf8(chunk.to_vec()) {
699 Ok(t) => t,
700 Err(e) => {
701 yield Err(CompletionError::ResponseError(e.to_string()));
702 break;
703 }
704 };
705
706
707 for line in text.lines() {
708 let mut line = line.to_string();
709
710 if partial_data.is_some() {
712 line = format!("{}{}", partial_data.unwrap(), line);
713 partial_data = None;
714 }
715 else {
717 let Some(data) = line.strip_prefix("data:") else {
718 continue;
719 };
720
721 let data = data.trim_start();
722
723 if !line.ends_with("}") {
725 partial_data = Some(data.to_string());
726 } else {
727 line = data.to_string();
728 }
729 }
730
731 let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
732
733 let Ok(data) = data else {
734 let err = data.unwrap_err();
735 tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
736 continue;
737 };
738
739
740 if let Some(choice) = data.choices.first() {
741 let delta = &choice.delta;
742
743
744 if !delta.tool_calls.is_empty() {
745 for tool_call in &delta.tool_calls {
746 let function = tool_call.function.clone();
747 if function.name.is_some() && function.arguments.is_empty() {
751 let id = tool_call.id.clone().unwrap_or("".to_string());
752
753 calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
754 }
755 else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
759 let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
760 tracing::debug!("Partial tool call received but tool call was never started.");
761 continue;
762 };
763
764 let new_arguments = &function.arguments;
765 let arguments = format!("{arguments}{new_arguments}");
766
767 calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
768 }
769 else {
771 let id = tool_call.id.clone().unwrap_or("".to_string());
772 let name = function.name.expect("function name should be present for complete tool call");
773 let arguments = function.arguments;
774 let Ok(arguments) = serde_json::from_str(&arguments) else {
775 tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
776 continue;
777 };
778
779 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
780 }
781 }
782 }
783
784 if let Some(content) = &delta.reasoning_content {
785 yield Ok(crate::streaming::RawStreamingChoice::Reasoning { reasoning: content.to_string(), id: None})
786 }
787
788 if let Some(content) = &delta.content {
789 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
790 }
791
792 }
793
794
795 if let Some(usage) = data.usage {
796 final_usage = usage.clone();
797 }
798 }
799 }
800
801 for (_, (id, name, arguments)) in calls {
802 let Ok(arguments) = serde_json::from_str(&arguments) else {
803 continue;
804 };
805
806 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
807 }
808
809 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
810 usage: final_usage.clone()
811 }))
812 });
813
814 Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
815}
816
817pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
823pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
825
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the assistant text of the first choice, panicking if it is not an
    /// assistant message.
    fn first_assistant_text(choices: &[Choice]) -> &str {
        let Message::Assistant { content, .. } = &choices.first().unwrap().message else {
            panic!("Expected assistant message");
        };
        content
    }

    /// Deserializes `json` as a [`CompletionResponse`], panicking with the failing
    /// JSON path on error.
    fn parse_response(json: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(json);
        serde_path_to_error::deserialize(jd)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_eq!(first_assistant_text(&choices), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);
        assert_eq!(first_assistant_text(&response.choices), "Hello, world!");
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        assert_eq!(
            first_assistant_text(&response.choices),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };
        let expected_choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
        };

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
        assert_eq!(choice, expected_choice);
    }
}