1use async_stream::stream;
13use futures::StreamExt;
14use reqwest_eventsource::{Event, RequestBuilderExt};
15use std::collections::HashMap;
16
17use crate::client::{
18 ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError,
19};
20use crate::completion::GetTokenUsage;
21use crate::json_utils::merge;
22use crate::message::Document;
23use crate::{
24 OneOrMany,
25 completion::{self, CompletionError, CompletionRequest},
26 impl_conversion_traits, json_utils, message,
27};
28use reqwest::Client as HttpClient;
29use serde::{Deserialize, Serialize};
30use serde_json::json;
31
32use super::openai::StreamingToolCall;
33
/// Default root URL for the DeepSeek API, used by [`ClientBuilder::new`] unless
/// overridden with [`ClientBuilder::base_url`].
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
38
39pub struct ClientBuilder<'a> {
40 api_key: &'a str,
41 base_url: &'a str,
42 http_client: Option<reqwest::Client>,
43}
44
45impl<'a> ClientBuilder<'a> {
46 pub fn new(api_key: &'a str) -> Self {
47 Self {
48 api_key,
49 base_url: DEEPSEEK_API_BASE_URL,
50 http_client: None,
51 }
52 }
53
54 pub fn base_url(mut self, base_url: &'a str) -> Self {
55 self.base_url = base_url;
56 self
57 }
58
59 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
60 self.http_client = Some(client);
61 self
62 }
63
64 pub fn build(self) -> Result<Client, ClientBuilderError> {
65 let http_client = if let Some(http_client) = self.http_client {
66 http_client
67 } else {
68 reqwest::Client::builder().build()?
69 };
70
71 Ok(Client {
72 base_url: self.base_url.to_string(),
73 api_key: self.api_key.to_string(),
74 http_client,
75 })
76 }
77}
78
/// DeepSeek API client: a base URL, an API key, and the `reqwest` HTTP client used
/// for every request. Construct it with [`Client::new`] or [`Client::builder`].
#[derive(Clone)]
pub struct Client {
    /// Root URL that request paths are appended to.
    pub base_url: String,
    /// Sent as a bearer token on every request; redacted from `Debug` output.
    api_key: String,
    http_client: HttpClient,
}
85
86impl std::fmt::Debug for Client {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct("Client")
89 .field("base_url", &self.base_url)
90 .field("http_client", &self.http_client)
91 .field("api_key", &"<REDACTED>")
92 .finish()
93 }
94}
95
96impl Client {
97 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
108 ClientBuilder::new(api_key)
109 }
110
111 pub fn new(api_key: &str) -> Self {
116 Self::builder(api_key)
117 .build()
118 .expect("DeepSeek client should build")
119 }
120
121 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
122 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
123 self.http_client.post(url).bearer_auth(&self.api_key)
124 }
125
126 pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
127 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
128 self.http_client.get(url).bearer_auth(&self.api_key)
129 }
130}
131
132impl ProviderClient for Client {
133 fn from_env() -> Self {
135 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
136 Self::new(&api_key)
137 }
138
139 fn from_val(input: crate::client::ProviderValue) -> Self {
140 let crate::client::ProviderValue::Simple(api_key) = input else {
141 panic!("Incorrect provider value type")
142 };
143 Self::new(&api_key)
144 }
145}
146
147impl CompletionClient for Client {
148 type CompletionModel = CompletionModel;
149
150 fn completion_model(&self, model_name: &str) -> CompletionModel {
152 CompletionModel {
153 client: self.clone(),
154 model: model_name.to_string(),
155 }
156 }
157}
158
159impl VerifyClient for Client {
160 #[cfg_attr(feature = "worker", worker::send)]
161 async fn verify(&self) -> Result<(), VerifyError> {
162 let response = self.get("/user/balance").send().await?;
163 match response.status() {
164 reqwest::StatusCode::OK => Ok(()),
165 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
166 reqwest::StatusCode::INTERNAL_SERVER_ERROR
167 | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
168 Err(VerifyError::ProviderError(response.text().await?))
169 }
170 _ => {
171 response.error_for_status()?;
172 Ok(())
173 }
174 }
175 }
176}
177
// DeepSeek is used here only for chat completions. This crate macro (defined elsewhere)
// presumably provides the listed capability traits for `Client` as "unsupported"
// conversions — NOTE(review): confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
184
/// Error payload accepted in place of a successful body: a JSON object with a
/// top-level `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order. `Ok(T)` is
/// attempted first, and `Err` is used only if `T` fails to deserialize, so keep `Ok` first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
196
197impl From<ApiErrorResponse> for CompletionError {
198 fn from(err: ApiErrorResponse) -> Self {
199 CompletionError::ProviderError(err.message)
200 }
201}
202
/// Body of a non-streaming chat-completion response.
///
/// Only the fields this provider consumes are modelled. Other top-level fields
/// (e.g. `id`, `model`, `system_fingerprint`) are ignored during deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Generated choices; conversion to rig's response uses only the first one.
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
211
212#[derive(Clone, Debug, Serialize, Deserialize, Default)]
213pub struct Usage {
214 pub completion_tokens: u32,
215 pub prompt_tokens: u32,
216 pub prompt_cache_hit_tokens: u32,
217 pub prompt_cache_miss_tokens: u32,
218 pub total_tokens: u32,
219 #[serde(skip_serializing_if = "Option::is_none")]
220 pub completion_tokens_details: Option<CompletionTokensDetails>,
221 #[serde(skip_serializing_if = "Option::is_none")]
222 pub prompt_tokens_details: Option<PromptTokensDetails>,
223}
224
225impl Usage {
226 fn new() -> Self {
227 Self {
228 completion_tokens: 0,
229 prompt_tokens: 0,
230 prompt_cache_hit_tokens: 0,
231 prompt_cache_miss_tokens: 0,
232 total_tokens: 0,
233 completion_tokens_details: None,
234 prompt_tokens_details: None,
235 }
236 }
237}
238
/// Optional breakdown of completion tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Value of the `reasoning_tokens` field when reported; omitted when serializing if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}

/// Optional breakdown of prompt tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// Value of the `cached_tokens` field when reported; omitted when serializing if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
250
/// One element of the `choices` array in a completion response.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Position of this choice in `choices`.
    pub index: usize,
    /// The generated message; only `Message::Assistant` is accepted downstream.
    pub message: Message,
    /// Log-probability data kept as untyped JSON; `null` deserializes to `None`.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped, e.g. `"stop"` or `"tool_calls"`.
    pub finish_reason: String,
}
258
/// A chat message in DeepSeek's wire format, discriminated by the `role` field
/// (`system`, `user`, `assistant`, `tool`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "system"` — e.g. the request preamble.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "user"`.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "assistant"` — text and/or requested tool calls.
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Missing -> empty (`default`); `json_utils::null_or_vec` presumably maps a JSON
        // `null` to an empty vec as well. An empty list is omitted when serializing.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// `role: "tool"` — a tool's output, linked to its call by `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
289
290impl Message {
291 pub fn system(content: &str) -> Self {
292 Message::System {
293 content: content.to_owned(),
294 name: None,
295 }
296 }
297}
298
299impl From<message::ToolResult> for Message {
300 fn from(tool_result: message::ToolResult) -> Self {
301 let content = match tool_result.content.first() {
302 message::ToolResultContent::Text(text) => text.text,
303 message::ToolResultContent::Image(_) => String::from("[Image]"),
304 };
305
306 Message::ToolResult {
307 tool_call_id: tool_result.id,
308 content,
309 }
310 }
311}
312
313impl From<message::ToolCall> for ToolCall {
314 fn from(tool_call: message::ToolCall) -> Self {
315 Self {
316 id: tool_call.id,
317 index: 0,
319 r#type: ToolType::Function,
320 function: Function {
321 name: tool_call.function.name,
322 arguments: tool_call.function.arguments,
323 },
324 }
325 }
326}
327
328impl TryFrom<message::Message> for Vec<Message> {
329 type Error = message::MessageError;
330
331 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
332 match message {
333 message::Message::User { content } => {
334 let mut messages = vec![];
336
337 let tool_results = content
338 .clone()
339 .into_iter()
340 .filter_map(|content| match content {
341 message::UserContent::ToolResult(tool_result) => {
342 Some(Message::from(tool_result))
343 }
344 _ => None,
345 })
346 .collect::<Vec<_>>();
347
348 messages.extend(tool_results);
349
350 let text_messages = content
352 .into_iter()
353 .filter_map(|content| match content {
354 message::UserContent::Text(text) => Some(Message::User {
355 content: text.text,
356 name: None,
357 }),
358 message::UserContent::Document(Document { data, .. }) => {
359 Some(Message::User {
360 content: data,
361 name: None,
362 })
363 }
364 _ => None,
365 })
366 .collect::<Vec<_>>();
367 messages.extend(text_messages);
368
369 Ok(messages)
370 }
371 message::Message::Assistant { content, .. } => {
372 let mut messages: Vec<Message> = vec![];
373
374 let tool_calls = content
376 .clone()
377 .into_iter()
378 .filter_map(|content| match content {
379 message::AssistantContent::ToolCall(tool_call) => {
380 Some(ToolCall::from(tool_call))
381 }
382 _ => None,
383 })
384 .collect::<Vec<_>>();
385
386 if !tool_calls.is_empty() {
388 messages.push(Message::Assistant {
389 content: "".to_string(),
390 name: None,
391 tool_calls,
392 });
393 }
394
395 let text_content = content
397 .into_iter()
398 .filter_map(|content| match content {
399 message::AssistantContent::Text(text) => Some(Message::Assistant {
400 content: text.text,
401 name: None,
402 tool_calls: vec![],
403 }),
404 _ => None,
405 })
406 .collect::<Vec<_>>();
407
408 messages.extend(text_content);
409
410 Ok(messages)
411 }
412 }
413 }
414}
415
/// A tool call as it appears in an assistant message's `tool_calls` array.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Call id; a `tool` message refers back to it via `tool_call_id`.
    pub id: String,
    /// Position of the call in the `tool_calls` array.
    pub index: usize,
    /// Call type; defaults to `function` when the field is missing.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}

/// Function name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// Arguments as JSON. On the wire they are a JSON-encoded string
    /// (e.g. `"{\"x\":2,\"y\":5}"`), handled by `json_utils::stringified_json`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}

/// Tool call type; `function` (lowercase on the wire) is the only variant and the default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
438
/// An entry of the request's `tools` array: a `type` tag (always `"function"` when built
/// via `From`) plus rig's tool definition under `function`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
444
445impl From<crate::completion::ToolDefinition> for ToolDefinition {
446 fn from(tool: crate::completion::ToolDefinition) -> Self {
447 Self {
448 r#type: "function".into(),
449 function: tool,
450 }
451 }
452}
453
454impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
455 type Error = CompletionError;
456
457 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
458 let choice = response.choices.first().ok_or_else(|| {
459 CompletionError::ResponseError("Response contained no choices".to_owned())
460 })?;
461 let content = match &choice.message {
462 Message::Assistant {
463 content,
464 tool_calls,
465 ..
466 } => {
467 let mut content = if content.trim().is_empty() {
468 vec![]
469 } else {
470 vec![completion::AssistantContent::text(content)]
471 };
472
473 content.extend(
474 tool_calls
475 .iter()
476 .map(|call| {
477 completion::AssistantContent::tool_call(
478 &call.id,
479 &call.function.name,
480 call.function.arguments.clone(),
481 )
482 })
483 .collect::<Vec<_>>(),
484 );
485 Ok(content)
486 }
487 _ => Err(CompletionError::ResponseError(
488 "Response did not contain a valid message or tool call".into(),
489 )),
490 }?;
491
492 let choice = OneOrMany::many(content).map_err(|_| {
493 CompletionError::ResponseError(
494 "Response contained no message or tool call (empty)".to_owned(),
495 )
496 })?;
497
498 let usage = completion::Usage {
499 input_tokens: response.usage.prompt_tokens as u64,
500 output_tokens: response.usage.completion_tokens as u64,
501 total_tokens: response.usage.total_tokens as u64,
502 };
503
504 Ok(completion::CompletionResponse {
505 choice,
506 usage,
507 raw_response: response,
508 })
509 }
510}
511
/// A DeepSeek chat model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    pub client: Client,
    /// Model id sent as `model`, e.g. [`DEEPSEEK_CHAT`] or [`DEEPSEEK_REASONER`].
    pub model: String,
}
518
519impl CompletionModel {
520 fn create_completion_request(
521 &self,
522 completion_request: CompletionRequest,
523 ) -> Result<serde_json::Value, CompletionError> {
524 let mut partial_history = vec![];
526
527 if let Some(docs) = completion_request.normalized_documents() {
528 partial_history.push(docs);
529 }
530
531 partial_history.extend(completion_request.chat_history);
532
533 let mut full_history: Vec<Message> = completion_request
535 .preamble
536 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
537
538 full_history.extend(
540 partial_history
541 .into_iter()
542 .map(message::Message::try_into)
543 .collect::<Result<Vec<Vec<Message>>, _>>()?
544 .into_iter()
545 .flatten()
546 .collect::<Vec<_>>(),
547 );
548
549 let request = if completion_request.tools.is_empty() {
550 json!({
551 "model": self.model,
552 "messages": full_history,
553 "temperature": completion_request.temperature,
554 })
555 } else {
556 json!({
557 "model": self.model,
558 "messages": full_history,
559 "temperature": completion_request.temperature,
560 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
561 "tool_choice": "auto",
562 })
563 };
564
565 let request = if let Some(params) = completion_request.additional_params {
566 json_utils::merge(request, params)
567 } else {
568 request
569 };
570
571 Ok(request)
572 }
573}
574
575impl completion::CompletionModel for CompletionModel {
576 type Response = CompletionResponse;
577 type StreamingResponse = StreamingCompletionResponse;
578
579 #[cfg_attr(feature = "worker", worker::send)]
580 async fn completion(
581 &self,
582 completion_request: CompletionRequest,
583 ) -> Result<
584 completion::CompletionResponse<CompletionResponse>,
585 crate::completion::CompletionError,
586 > {
587 let request = self.create_completion_request(completion_request)?;
588
589 tracing::debug!("DeepSeek completion request: {request:?}");
590
591 let response = self
592 .client
593 .post("/chat/completions")
594 .json(&request)
595 .send()
596 .await?;
597
598 if response.status().is_success() {
599 let t = response.text().await?;
600 tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
601
602 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
603 ApiResponse::Ok(response) => response.try_into(),
604 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
605 }
606 } else {
607 Err(CompletionError::ProviderError(response.text().await?))
608 }
609 }
610
611 #[cfg_attr(feature = "worker", worker::send)]
612 async fn stream(
613 &self,
614 completion_request: CompletionRequest,
615 ) -> Result<
616 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
617 CompletionError,
618 > {
619 let mut request = self.create_completion_request(completion_request)?;
620
621 request = merge(
622 request,
623 json!({"stream": true, "stream_options": {"include_usage": true}}),
624 );
625
626 let builder = self.client.post("/chat/completions").json(&request);
627 send_compatible_streaming_request(builder).await
628 }
629}
630
/// The `delta` object of a streamed chunk's choice.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    /// Incremental assistant text, if present in this chunk.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments. A missing field gives an empty vec (`default`);
    /// `json_utils::null_or_vec` presumably does the same for `null`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    /// Incremental reasoning text, if present in this chunk.
    reasoning_content: Option<String>,
}

/// One element of a chunk's `choices` array; only `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One SSE `data:` payload of a streaming completion.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    /// May be empty (e.g. on a usage-only chunk).
    choices: Vec<StreamingChoice>,
    /// Presumably only sent on the final chunk when `stream_options.include_usage`
    /// is requested — TODO confirm.
    usage: Option<Usage>,
}

/// Final item of a streaming response, carrying the token usage reported by the stream.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
655
656impl GetTokenUsage for StreamingCompletionResponse {
657 fn token_usage(&self) -> Option<crate::completion::Usage> {
658 let mut usage = crate::completion::Usage::new();
659 usage.input_tokens = self.usage.prompt_tokens as u64;
660 usage.output_tokens = self.usage.completion_tokens as u64;
661 usage.total_tokens = self.usage.total_tokens as u64;
662
663 Some(usage)
664 }
665}
666
667pub async fn send_compatible_streaming_request(
668 request_builder: reqwest::RequestBuilder,
669) -> Result<
670 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
671 CompletionError,
672> {
673 let mut event_source = request_builder
674 .eventsource()
675 .expect("Cloning request must succeed");
676
677 let stream = Box::pin(stream! {
678 let mut final_usage = Usage::new();
679 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
680
681 while let Some(event_result) = event_source.next().await {
682 match event_result {
683 Ok(Event::Open) => {
684 tracing::trace!("SSE connection opened");
685 continue;
686 }
687 Ok(Event::Message(message)) => {
688 if message.data.trim().is_empty() || message.data == "[DONE]" {
689 continue;
690 }
691
692 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
693 let Ok(data) = parsed else {
694 let err = parsed.unwrap_err();
695 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
696 continue;
697 };
698
699 if let Some(choice) = data.choices.first() {
700 let delta = &choice.delta;
701
702 if !delta.tool_calls.is_empty() {
703 for tool_call in &delta.tool_calls {
704 let function = &tool_call.function;
705
706 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
708 && function.arguments.is_empty()
709 {
710 let id = tool_call.id.clone().unwrap_or_default();
711 let name = function.name.clone().unwrap();
712 calls.insert(tool_call.index, (id, name, String::new()));
713 }
714 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
716 && !function.arguments.is_empty()
717 {
718 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
719 let combined = format!("{}{}", existing_args, function.arguments);
720 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
721 } else {
722 tracing::debug!("Partial tool call received but tool call was never started.");
723 }
724 }
725 else {
727 let id = tool_call.id.clone().unwrap_or_default();
728 let name = function.name.clone().unwrap_or_default();
729 let arguments_str = function.arguments.clone();
730
731 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
732 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
733 continue;
734 };
735
736 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
737 id,
738 name,
739 arguments: arguments_json,
740 call_id: None,
741 });
742 }
743 }
744 }
745
746 if let Some(content) = &delta.reasoning_content {
748 yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
749 reasoning: content.to_string(),
750 id: None,
751 });
752 }
753
754 if let Some(content) = &delta.content {
755 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
756 }
757 }
758
759 if let Some(usage) = data.usage {
760 final_usage = usage.clone();
761 }
762 }
763 Err(reqwest_eventsource::Error::StreamEnded) => {
764 break;
765 }
766 Err(err) => {
767 tracing::error!(?err, "SSE error");
768 yield Err(CompletionError::ResponseError(err.to_string()));
769 break;
770 }
771 }
772 }
773
774 for (_, (id, name, arguments)) in calls {
776 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
777 continue;
778 };
779 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
780 id,
781 name,
782 arguments: arguments_json,
783 call_id: None,
784 });
785 }
786
787 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
788 StreamingCompletionResponse { usage: final_usage.clone() }
789 ));
790 });
791
792 Ok(crate::streaming::StreamingCompletionResponse::stream(
793 stream,
794 ))
795}
796
/// Model id `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model id `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
805
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `data` as a `CompletionResponse`, panicking with the JSON path of the
    /// offending field on failure.
    fn parse_completion_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        serde_path_to_error::deserialize(jd)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    /// Asserts that `message` is an assistant message whose content equals `expected`.
    fn assert_assistant_content(message: &Message, expected: &str) {
        let Message::Assistant { content, .. } = message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content, expected);
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_assistant_content(&choices.first().unwrap().message, "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_completion_response(data);
        assert_assistant_content(&response.choices.first().unwrap().message, "Hello, world!");
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_completion_response(data);
        assert_assistant_content(
            &response.choices.first().unwrap().message,
            "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄",
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
        {
            "finish_reason": "tool_calls",
            "index": 0,
            "logprobs": null,
            "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                    {
                        "function": {
                            "arguments": "{\"x\":2,\"y\":5}",
                            "name": "subtract"
                        },
                        "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                        "index": 0,
                        "type": "function"
                    }
                ]
            }
        }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };
        let expected_choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}