1use futures::StreamExt;
13use std::collections::HashMap;
14
15use crate::client::{ClientBuilderError, CompletionClient, ProviderClient};
16use crate::json_utils::merge;
17use crate::message::Document;
18use crate::{
19 OneOrMany,
20 completion::{self, CompletionError, CompletionRequest},
21 impl_conversion_traits, json_utils, message,
22};
23use reqwest::Client as HttpClient;
24use serde::{Deserialize, Serialize};
25use serde_json::json;
26
27use super::openai::StreamingToolCall;
28
29const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
33
34pub struct ClientBuilder<'a> {
35 api_key: &'a str,
36 base_url: &'a str,
37 http_client: Option<reqwest::Client>,
38}
39
40impl<'a> ClientBuilder<'a> {
41 pub fn new(api_key: &'a str) -> Self {
42 Self {
43 api_key,
44 base_url: DEEPSEEK_API_BASE_URL,
45 http_client: None,
46 }
47 }
48
49 pub fn base_url(mut self, base_url: &'a str) -> Self {
50 self.base_url = base_url;
51 self
52 }
53
54 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
55 self.http_client = Some(client);
56 self
57 }
58
59 pub fn build(self) -> Result<Client, ClientBuilderError> {
60 let http_client = if let Some(http_client) = self.http_client {
61 http_client
62 } else {
63 reqwest::Client::builder().build()?
64 };
65
66 Ok(Client {
67 base_url: self.base_url.to_string(),
68 api_key: self.api_key.to_string(),
69 http_client,
70 })
71 }
72}
73
74#[derive(Clone)]
75pub struct Client {
76 pub base_url: String,
77 api_key: String,
78 http_client: HttpClient,
79}
80
81impl std::fmt::Debug for Client {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("Client")
84 .field("base_url", &self.base_url)
85 .field("http_client", &self.http_client)
86 .field("api_key", &"<REDACTED>")
87 .finish()
88 }
89}
90
91impl Client {
92 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
103 ClientBuilder::new(api_key)
104 }
105
106 pub fn new(api_key: &str) -> Self {
111 Self::builder(api_key)
112 .build()
113 .expect("DeepSeek client should build")
114 }
115
116 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
117 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
118 self.http_client.post(url).bearer_auth(&self.api_key)
119 }
120}
121
122impl ProviderClient for Client {
123 fn from_env() -> Self {
125 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
126 Self::new(&api_key)
127 }
128
129 fn from_val(input: crate::client::ProviderValue) -> Self {
130 let crate::client::ProviderValue::Simple(api_key) = input else {
131 panic!("Incorrect provider value type")
132 };
133 Self::new(&api_key)
134 }
135}
136
137impl CompletionClient for Client {
138 type CompletionModel = CompletionModel;
139
140 fn completion_model(&self, model_name: &str) -> CompletionModel {
142 CompletionModel {
143 client: self.clone(),
144 model: model_name.to_string(),
145 }
146 }
147}
148
149impl_conversion_traits!(
150 AsEmbeddings,
151 AsTranscription,
152 AsImageGeneration,
153 AsAudioGeneration for Client
154);
155
156#[derive(Debug, Deserialize)]
157struct ApiErrorResponse {
158 message: String,
159}
160
161#[derive(Debug, Deserialize)]
162#[serde(untagged)]
163enum ApiResponse<T> {
164 Ok(T),
165 Err(ApiErrorResponse),
166}
167
168impl From<ApiErrorResponse> for CompletionError {
169 fn from(err: ApiErrorResponse) -> Self {
170 CompletionError::ProviderError(err.message)
171 }
172}
173
174#[derive(Clone, Debug, Serialize, Deserialize)]
176pub struct CompletionResponse {
177 pub choices: Vec<Choice>,
179 pub usage: Usage,
180 }
182
183#[derive(Clone, Debug, Serialize, Deserialize, Default)]
184pub struct Usage {
185 pub completion_tokens: u32,
186 pub prompt_tokens: u32,
187 pub prompt_cache_hit_tokens: u32,
188 pub prompt_cache_miss_tokens: u32,
189 pub total_tokens: u32,
190 #[serde(skip_serializing_if = "Option::is_none")]
191 pub completion_tokens_details: Option<CompletionTokensDetails>,
192 #[serde(skip_serializing_if = "Option::is_none")]
193 pub prompt_tokens_details: Option<PromptTokensDetails>,
194}
195
196impl Usage {
197 fn new() -> Self {
198 Self {
199 completion_tokens: 0,
200 prompt_tokens: 0,
201 prompt_cache_hit_tokens: 0,
202 prompt_cache_miss_tokens: 0,
203 total_tokens: 0,
204 completion_tokens_details: None,
205 prompt_tokens_details: None,
206 }
207 }
208}
209
210#[derive(Clone, Debug, Serialize, Deserialize, Default)]
211pub struct CompletionTokensDetails {
212 #[serde(skip_serializing_if = "Option::is_none")]
213 pub reasoning_tokens: Option<u32>,
214}
215
216#[derive(Clone, Debug, Serialize, Deserialize, Default)]
217pub struct PromptTokensDetails {
218 #[serde(skip_serializing_if = "Option::is_none")]
219 pub cached_tokens: Option<u32>,
220}
221
222#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
223pub struct Choice {
224 pub index: usize,
225 pub message: Message,
226 pub logprobs: Option<serde_json::Value>,
227 pub finish_reason: String,
228}
229
230#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
231#[serde(tag = "role", rename_all = "lowercase")]
232pub enum Message {
233 System {
234 content: String,
235 #[serde(skip_serializing_if = "Option::is_none")]
236 name: Option<String>,
237 },
238 User {
239 content: String,
240 #[serde(skip_serializing_if = "Option::is_none")]
241 name: Option<String>,
242 },
243 Assistant {
244 content: String,
245 #[serde(skip_serializing_if = "Option::is_none")]
246 name: Option<String>,
247 #[serde(
248 default,
249 deserialize_with = "json_utils::null_or_vec",
250 skip_serializing_if = "Vec::is_empty"
251 )]
252 tool_calls: Vec<ToolCall>,
253 },
254 #[serde(rename = "tool")]
255 ToolResult {
256 tool_call_id: String,
257 content: String,
258 },
259}
260
261impl Message {
262 pub fn system(content: &str) -> Self {
263 Message::System {
264 content: content.to_owned(),
265 name: None,
266 }
267 }
268}
269
270impl From<message::ToolResult> for Message {
271 fn from(tool_result: message::ToolResult) -> Self {
272 let content = match tool_result.content.first() {
273 message::ToolResultContent::Text(text) => text.text,
274 message::ToolResultContent::Image(_) => String::from("[Image]"),
275 };
276
277 Message::ToolResult {
278 tool_call_id: tool_result.id,
279 content,
280 }
281 }
282}
283
284impl From<message::ToolCall> for ToolCall {
285 fn from(tool_call: message::ToolCall) -> Self {
286 Self {
287 id: tool_call.id,
288 index: 0,
290 r#type: ToolType::Function,
291 function: Function {
292 name: tool_call.function.name,
293 arguments: tool_call.function.arguments,
294 },
295 }
296 }
297}
298
299impl TryFrom<message::Message> for Vec<Message> {
300 type Error = message::MessageError;
301
302 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
303 match message {
304 message::Message::User { content } => {
305 let mut messages = vec![];
307
308 let tool_results = content
309 .clone()
310 .into_iter()
311 .filter_map(|content| match content {
312 message::UserContent::ToolResult(tool_result) => {
313 Some(Message::from(tool_result))
314 }
315 _ => None,
316 })
317 .collect::<Vec<_>>();
318
319 messages.extend(tool_results);
320
321 let text_messages = content
323 .into_iter()
324 .filter_map(|content| match content {
325 message::UserContent::Text(text) => Some(Message::User {
326 content: text.text,
327 name: None,
328 }),
329 message::UserContent::Document(Document { data, .. }) => {
330 Some(Message::User {
331 content: data,
332 name: None,
333 })
334 }
335 _ => None,
336 })
337 .collect::<Vec<_>>();
338 messages.extend(text_messages);
339
340 Ok(messages)
341 }
342 message::Message::Assistant { content, .. } => {
343 let mut messages: Vec<Message> = vec![];
344
345 let tool_calls = content
347 .clone()
348 .into_iter()
349 .filter_map(|content| match content {
350 message::AssistantContent::ToolCall(tool_call) => {
351 Some(ToolCall::from(tool_call))
352 }
353 _ => None,
354 })
355 .collect::<Vec<_>>();
356
357 if !tool_calls.is_empty() {
359 messages.push(Message::Assistant {
360 content: "".to_string(),
361 name: None,
362 tool_calls,
363 });
364 }
365
366 let text_content = content
368 .into_iter()
369 .filter_map(|content| match content {
370 message::AssistantContent::Text(text) => Some(Message::Assistant {
371 content: text.text,
372 name: None,
373 tool_calls: vec![],
374 }),
375 _ => None,
376 })
377 .collect::<Vec<_>>();
378
379 messages.extend(text_content);
380
381 Ok(messages)
382 }
383 }
384 }
385}
386
387#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
388pub struct ToolCall {
389 pub id: String,
390 pub index: usize,
391 #[serde(default)]
392 pub r#type: ToolType,
393 pub function: Function,
394}
395
396#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
397pub struct Function {
398 pub name: String,
399 #[serde(with = "json_utils::stringified_json")]
400 pub arguments: serde_json::Value,
401}
402
403#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
404#[serde(rename_all = "lowercase")]
405pub enum ToolType {
406 #[default]
407 Function,
408}
409
410#[derive(Clone, Debug, Deserialize, Serialize)]
411pub struct ToolDefinition {
412 pub r#type: String,
413 pub function: completion::ToolDefinition,
414}
415
416impl From<crate::completion::ToolDefinition> for ToolDefinition {
417 fn from(tool: crate::completion::ToolDefinition) -> Self {
418 Self {
419 r#type: "function".into(),
420 function: tool,
421 }
422 }
423}
424
425impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
426 type Error = CompletionError;
427
428 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
429 let choice = response.choices.first().ok_or_else(|| {
430 CompletionError::ResponseError("Response contained no choices".to_owned())
431 })?;
432 let content = match &choice.message {
433 Message::Assistant {
434 content,
435 tool_calls,
436 ..
437 } => {
438 let mut content = if content.trim().is_empty() {
439 vec![]
440 } else {
441 vec![completion::AssistantContent::text(content)]
442 };
443
444 content.extend(
445 tool_calls
446 .iter()
447 .map(|call| {
448 completion::AssistantContent::tool_call(
449 &call.id,
450 &call.function.name,
451 call.function.arguments.clone(),
452 )
453 })
454 .collect::<Vec<_>>(),
455 );
456 Ok(content)
457 }
458 _ => Err(CompletionError::ResponseError(
459 "Response did not contain a valid message or tool call".into(),
460 )),
461 }?;
462
463 let choice = OneOrMany::many(content).map_err(|_| {
464 CompletionError::ResponseError(
465 "Response contained no message or tool call (empty)".to_owned(),
466 )
467 })?;
468
469 let usage = completion::Usage {
470 input_tokens: response.usage.prompt_tokens as u64,
471 output_tokens: response.usage.completion_tokens as u64,
472 total_tokens: response.usage.total_tokens as u64,
473 };
474
475 Ok(completion::CompletionResponse {
476 choice,
477 usage,
478 raw_response: response,
479 })
480 }
481}
482
483#[derive(Clone)]
485pub struct CompletionModel {
486 pub client: Client,
487 pub model: String,
488}
489
490impl CompletionModel {
491 fn create_completion_request(
492 &self,
493 completion_request: CompletionRequest,
494 ) -> Result<serde_json::Value, CompletionError> {
495 let mut partial_history = vec![];
497
498 if let Some(docs) = completion_request.normalized_documents() {
499 partial_history.push(docs);
500 }
501
502 partial_history.extend(completion_request.chat_history);
503
504 let mut full_history: Vec<Message> = completion_request
506 .preamble
507 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
508
509 full_history.extend(
511 partial_history
512 .into_iter()
513 .map(message::Message::try_into)
514 .collect::<Result<Vec<Vec<Message>>, _>>()?
515 .into_iter()
516 .flatten()
517 .collect::<Vec<_>>(),
518 );
519
520 let request = if completion_request.tools.is_empty() {
521 json!({
522 "model": self.model,
523 "messages": full_history,
524 "temperature": completion_request.temperature,
525 })
526 } else {
527 json!({
528 "model": self.model,
529 "messages": full_history,
530 "temperature": completion_request.temperature,
531 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
532 "tool_choice": "auto",
533 })
534 };
535
536 let request = if let Some(params) = completion_request.additional_params {
537 json_utils::merge(request, params)
538 } else {
539 request
540 };
541
542 Ok(request)
543 }
544}
545
546impl completion::CompletionModel for CompletionModel {
547 type Response = CompletionResponse;
548 type StreamingResponse = StreamingCompletionResponse;
549
550 #[cfg_attr(feature = "worker", worker::send)]
551 async fn completion(
552 &self,
553 completion_request: CompletionRequest,
554 ) -> Result<
555 completion::CompletionResponse<CompletionResponse>,
556 crate::completion::CompletionError,
557 > {
558 let request = self.create_completion_request(completion_request)?;
559
560 tracing::debug!("DeepSeek completion request: {request:?}");
561
562 let response = self
563 .client
564 .post("/chat/completions")
565 .json(&request)
566 .send()
567 .await?;
568
569 if response.status().is_success() {
570 let t = response.text().await?;
571 tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
572
573 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
574 ApiResponse::Ok(response) => response.try_into(),
575 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
576 }
577 } else {
578 Err(CompletionError::ProviderError(response.text().await?))
579 }
580 }
581
582 #[cfg_attr(feature = "worker", worker::send)]
583 async fn stream(
584 &self,
585 completion_request: CompletionRequest,
586 ) -> Result<
587 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
588 CompletionError,
589 > {
590 let mut request = self.create_completion_request(completion_request)?;
591
592 request = merge(
593 request,
594 json!({"stream": true, "stream_options": {"include_usage": true}}),
595 );
596
597 let builder = self.client.post("/v1/chat/completions").json(&request);
598 send_compatible_streaming_request(builder).await
599 }
600}
601
602#[derive(Deserialize, Debug)]
603pub struct StreamingDelta {
604 #[serde(default)]
605 content: Option<String>,
606 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
607 tool_calls: Vec<StreamingToolCall>,
608 reasoning_content: Option<String>,
609}
610
611#[derive(Deserialize, Debug)]
612struct StreamingChoice {
613 delta: StreamingDelta,
614}
615
616#[derive(Deserialize, Debug)]
617struct StreamingCompletionChunk {
618 choices: Vec<StreamingChoice>,
619 usage: Option<Usage>,
620}
621
622#[derive(Clone, Deserialize, Serialize, Debug)]
623pub struct StreamingCompletionResponse {
624 pub usage: Usage,
625}
626
627pub async fn send_compatible_streaming_request(
628 request_builder: reqwest::RequestBuilder,
629) -> Result<
630 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
631 CompletionError,
632> {
633 let response = request_builder.send().await?;
634
635 if !response.status().is_success() {
636 return Err(CompletionError::ProviderError(format!(
637 "{}: {}",
638 response.status(),
639 response.text().await?
640 )));
641 }
642
643 let inner = Box::pin(async_stream::stream! {
645 let mut stream = response.bytes_stream();
646
647 let mut final_usage = Usage::new();
648 let mut partial_data = None;
649 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
650
651 while let Some(chunk_result) = stream.next().await {
652 let chunk = match chunk_result {
653 Ok(c) => c,
654 Err(e) => {
655 yield Err(CompletionError::from(e));
656 break;
657 }
658 };
659
660 let text = match String::from_utf8(chunk.to_vec()) {
661 Ok(t) => t,
662 Err(e) => {
663 yield Err(CompletionError::ResponseError(e.to_string()));
664 break;
665 }
666 };
667
668
669 for line in text.lines() {
670 let mut line = line.to_string();
671
672 if partial_data.is_some() {
674 line = format!("{}{}", partial_data.unwrap(), line);
675 partial_data = None;
676 }
677 else {
679 let Some(data) = line.strip_prefix("data:") else {
680 continue;
681 };
682
683 let data = data.trim_start();
684
685 if !line.ends_with("}") {
687 partial_data = Some(data.to_string());
688 } else {
689 line = data.to_string();
690 }
691 }
692
693 let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
694
695 let Ok(data) = data else {
696 let err = data.unwrap_err();
697 tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
698 continue;
699 };
700
701
702 if let Some(choice) = data.choices.first() {
703 let delta = &choice.delta;
704
705
706 if !delta.tool_calls.is_empty() {
707 for tool_call in &delta.tool_calls {
708 let function = tool_call.function.clone();
709 if function.name.is_some() && function.arguments.is_empty() {
713 let id = tool_call.id.clone().unwrap_or("".to_string());
714
715 calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
716 }
717 else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
721 let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
722 tracing::debug!("Partial tool call received but tool call was never started.");
723 continue;
724 };
725
726 let new_arguments = &function.arguments;
727 let arguments = format!("{arguments}{new_arguments}");
728
729 calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
730 }
731 else {
733 let id = tool_call.id.clone().unwrap_or("".to_string());
734 let name = function.name.expect("function name should be present for complete tool call");
735 let arguments = function.arguments;
736 let Ok(arguments) = serde_json::from_str(&arguments) else {
737 tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
738 continue;
739 };
740
741 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
742 }
743 }
744 }
745
746 if let Some(content) = &delta.reasoning_content {
747 yield Ok(crate::streaming::RawStreamingChoice::Reasoning { reasoning: content.to_string()})
748 }
749
750 if let Some(content) = &delta.content {
751 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
752 }
753
754 }
755
756
757 if let Some(usage) = data.usage {
758 final_usage = usage.clone();
759 }
760 }
761 }
762
763 for (_, (id, name, arguments)) in calls {
764 let Ok(arguments) = serde_json::from_str(&arguments) else {
765 continue;
766 };
767
768 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
769 }
770
771 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
772 usage: final_usage.clone()
773 }))
774 });
775
776 Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
777}
778
/// Model id of DeepSeek's general chat model.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model id of DeepSeek's reasoning model (streams `reasoning_content` fragments).
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
787
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of an assistant message, failing the test for any other role.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content,
            _ => panic!("Expected assistant message"),
        }
    }

    /// Deserializes a full response, reporting the JSON path of any failure.
    fn parse_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        serde_path_to_error::deserialize(jd)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();

        assert_eq!(choices.len(), 1);
        assert_eq!(assistant_text(&choices.first().unwrap().message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);

        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);

        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };
        let expected_choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
        };

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        assert_eq!(choice, expected_choice);
    }
}