1use futures::StreamExt;
13use std::collections::HashMap;
14
15use crate::client::{ClientBuilderError, CompletionClient, ProviderClient};
16use crate::completion::GetTokenUsage;
17use crate::json_utils::merge;
18use crate::message::Document;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 impl_conversion_traits, json_utils, message,
23};
24use reqwest::Client as HttpClient;
25use serde::{Deserialize, Serialize};
26use serde_json::json;
27
28use super::openai::StreamingToolCall;
29
/// Default base URL of the DeepSeek HTTP API, used by `ClientBuilder::new`
/// unless overridden with `ClientBuilder::base_url`.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
34
35pub struct ClientBuilder<'a> {
36 api_key: &'a str,
37 base_url: &'a str,
38 http_client: Option<reqwest::Client>,
39}
40
41impl<'a> ClientBuilder<'a> {
42 pub fn new(api_key: &'a str) -> Self {
43 Self {
44 api_key,
45 base_url: DEEPSEEK_API_BASE_URL,
46 http_client: None,
47 }
48 }
49
50 pub fn base_url(mut self, base_url: &'a str) -> Self {
51 self.base_url = base_url;
52 self
53 }
54
55 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
56 self.http_client = Some(client);
57 self
58 }
59
60 pub fn build(self) -> Result<Client, ClientBuilderError> {
61 let http_client = if let Some(http_client) = self.http_client {
62 http_client
63 } else {
64 reqwest::Client::builder().build()?
65 };
66
67 Ok(Client {
68 base_url: self.base_url.to_string(),
69 api_key: self.api_key.to_string(),
70 http_client,
71 })
72 }
73}
74
75#[derive(Clone)]
76pub struct Client {
77 pub base_url: String,
78 api_key: String,
79 http_client: HttpClient,
80}
81
82impl std::fmt::Debug for Client {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 f.debug_struct("Client")
85 .field("base_url", &self.base_url)
86 .field("http_client", &self.http_client)
87 .field("api_key", &"<REDACTED>")
88 .finish()
89 }
90}
91
92impl Client {
93 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
104 ClientBuilder::new(api_key)
105 }
106
107 pub fn new(api_key: &str) -> Self {
112 Self::builder(api_key)
113 .build()
114 .expect("DeepSeek client should build")
115 }
116
117 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
118 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
119 self.http_client.post(url).bearer_auth(&self.api_key)
120 }
121}
122
123impl ProviderClient for Client {
124 fn from_env() -> Self {
126 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
127 Self::new(&api_key)
128 }
129
130 fn from_val(input: crate::client::ProviderValue) -> Self {
131 let crate::client::ProviderValue::Simple(api_key) = input else {
132 panic!("Incorrect provider value type")
133 };
134 Self::new(&api_key)
135 }
136}
137
138impl CompletionClient for Client {
139 type CompletionModel = CompletionModel;
140
141 fn completion_model(&self, model_name: &str) -> CompletionModel {
143 CompletionModel {
144 client: self.clone(),
145 model: model_name.to_string(),
146 }
147 }
148}
149
// This module only implements `CompletionClient` for `Client`. The macro (defined
// elsewhere in the crate) presumably supplies the remaining capability-conversion traits
// listed here as "unsupported" implementations. TODO: confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
156
157#[derive(Debug, Deserialize)]
158struct ApiErrorResponse {
159 message: String,
160}
161
162#[derive(Debug, Deserialize)]
163#[serde(untagged)]
164enum ApiResponse<T> {
165 Ok(T),
166 Err(ApiErrorResponse),
167}
168
169impl From<ApiErrorResponse> for CompletionError {
170 fn from(err: ApiErrorResponse) -> Self {
171 CompletionError::ProviderError(err.message)
172 }
173}
174
175#[derive(Clone, Debug, Serialize, Deserialize)]
177pub struct CompletionResponse {
178 pub choices: Vec<Choice>,
180 pub usage: Usage,
181 }
183
184#[derive(Clone, Debug, Serialize, Deserialize, Default)]
185pub struct Usage {
186 pub completion_tokens: u32,
187 pub prompt_tokens: u32,
188 pub prompt_cache_hit_tokens: u32,
189 pub prompt_cache_miss_tokens: u32,
190 pub total_tokens: u32,
191 #[serde(skip_serializing_if = "Option::is_none")]
192 pub completion_tokens_details: Option<CompletionTokensDetails>,
193 #[serde(skip_serializing_if = "Option::is_none")]
194 pub prompt_tokens_details: Option<PromptTokensDetails>,
195}
196
197impl Usage {
198 fn new() -> Self {
199 Self {
200 completion_tokens: 0,
201 prompt_tokens: 0,
202 prompt_cache_hit_tokens: 0,
203 prompt_cache_miss_tokens: 0,
204 total_tokens: 0,
205 completion_tokens_details: None,
206 prompt_tokens_details: None,
207 }
208 }
209}
210
211#[derive(Clone, Debug, Serialize, Deserialize, Default)]
212pub struct CompletionTokensDetails {
213 #[serde(skip_serializing_if = "Option::is_none")]
214 pub reasoning_tokens: Option<u32>,
215}
216
217#[derive(Clone, Debug, Serialize, Deserialize, Default)]
218pub struct PromptTokensDetails {
219 #[serde(skip_serializing_if = "Option::is_none")]
220 pub cached_tokens: Option<u32>,
221}
222
223#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
224pub struct Choice {
225 pub index: usize,
226 pub message: Message,
227 pub logprobs: Option<serde_json::Value>,
228 pub finish_reason: String,
229}
230
231#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
232#[serde(tag = "role", rename_all = "lowercase")]
233pub enum Message {
234 System {
235 content: String,
236 #[serde(skip_serializing_if = "Option::is_none")]
237 name: Option<String>,
238 },
239 User {
240 content: String,
241 #[serde(skip_serializing_if = "Option::is_none")]
242 name: Option<String>,
243 },
244 Assistant {
245 content: String,
246 #[serde(skip_serializing_if = "Option::is_none")]
247 name: Option<String>,
248 #[serde(
249 default,
250 deserialize_with = "json_utils::null_or_vec",
251 skip_serializing_if = "Vec::is_empty"
252 )]
253 tool_calls: Vec<ToolCall>,
254 },
255 #[serde(rename = "tool")]
256 ToolResult {
257 tool_call_id: String,
258 content: String,
259 },
260}
261
262impl Message {
263 pub fn system(content: &str) -> Self {
264 Message::System {
265 content: content.to_owned(),
266 name: None,
267 }
268 }
269}
270
271impl From<message::ToolResult> for Message {
272 fn from(tool_result: message::ToolResult) -> Self {
273 let content = match tool_result.content.first() {
274 message::ToolResultContent::Text(text) => text.text,
275 message::ToolResultContent::Image(_) => String::from("[Image]"),
276 };
277
278 Message::ToolResult {
279 tool_call_id: tool_result.id,
280 content,
281 }
282 }
283}
284
285impl From<message::ToolCall> for ToolCall {
286 fn from(tool_call: message::ToolCall) -> Self {
287 Self {
288 id: tool_call.id,
289 index: 0,
291 r#type: ToolType::Function,
292 function: Function {
293 name: tool_call.function.name,
294 arguments: tool_call.function.arguments,
295 },
296 }
297 }
298}
299
300impl TryFrom<message::Message> for Vec<Message> {
301 type Error = message::MessageError;
302
303 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
304 match message {
305 message::Message::User { content } => {
306 let mut messages = vec![];
308
309 let tool_results = content
310 .clone()
311 .into_iter()
312 .filter_map(|content| match content {
313 message::UserContent::ToolResult(tool_result) => {
314 Some(Message::from(tool_result))
315 }
316 _ => None,
317 })
318 .collect::<Vec<_>>();
319
320 messages.extend(tool_results);
321
322 let text_messages = content
324 .into_iter()
325 .filter_map(|content| match content {
326 message::UserContent::Text(text) => Some(Message::User {
327 content: text.text,
328 name: None,
329 }),
330 message::UserContent::Document(Document { data, .. }) => {
331 Some(Message::User {
332 content: data,
333 name: None,
334 })
335 }
336 _ => None,
337 })
338 .collect::<Vec<_>>();
339 messages.extend(text_messages);
340
341 Ok(messages)
342 }
343 message::Message::Assistant { content, .. } => {
344 let mut messages: Vec<Message> = vec![];
345
346 let tool_calls = content
348 .clone()
349 .into_iter()
350 .filter_map(|content| match content {
351 message::AssistantContent::ToolCall(tool_call) => {
352 Some(ToolCall::from(tool_call))
353 }
354 _ => None,
355 })
356 .collect::<Vec<_>>();
357
358 if !tool_calls.is_empty() {
360 messages.push(Message::Assistant {
361 content: "".to_string(),
362 name: None,
363 tool_calls,
364 });
365 }
366
367 let text_content = content
369 .into_iter()
370 .filter_map(|content| match content {
371 message::AssistantContent::Text(text) => Some(Message::Assistant {
372 content: text.text,
373 name: None,
374 tool_calls: vec![],
375 }),
376 _ => None,
377 })
378 .collect::<Vec<_>>();
379
380 messages.extend(text_content);
381
382 Ok(messages)
383 }
384 }
385 }
386}
387
388#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
389pub struct ToolCall {
390 pub id: String,
391 pub index: usize,
392 #[serde(default)]
393 pub r#type: ToolType,
394 pub function: Function,
395}
396
397#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
398pub struct Function {
399 pub name: String,
400 #[serde(with = "json_utils::stringified_json")]
401 pub arguments: serde_json::Value,
402}
403
404#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
405#[serde(rename_all = "lowercase")]
406pub enum ToolType {
407 #[default]
408 Function,
409}
410
411#[derive(Clone, Debug, Deserialize, Serialize)]
412pub struct ToolDefinition {
413 pub r#type: String,
414 pub function: completion::ToolDefinition,
415}
416
417impl From<crate::completion::ToolDefinition> for ToolDefinition {
418 fn from(tool: crate::completion::ToolDefinition) -> Self {
419 Self {
420 r#type: "function".into(),
421 function: tool,
422 }
423 }
424}
425
426impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
427 type Error = CompletionError;
428
429 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
430 let choice = response.choices.first().ok_or_else(|| {
431 CompletionError::ResponseError("Response contained no choices".to_owned())
432 })?;
433 let content = match &choice.message {
434 Message::Assistant {
435 content,
436 tool_calls,
437 ..
438 } => {
439 let mut content = if content.trim().is_empty() {
440 vec![]
441 } else {
442 vec![completion::AssistantContent::text(content)]
443 };
444
445 content.extend(
446 tool_calls
447 .iter()
448 .map(|call| {
449 completion::AssistantContent::tool_call(
450 &call.id,
451 &call.function.name,
452 call.function.arguments.clone(),
453 )
454 })
455 .collect::<Vec<_>>(),
456 );
457 Ok(content)
458 }
459 _ => Err(CompletionError::ResponseError(
460 "Response did not contain a valid message or tool call".into(),
461 )),
462 }?;
463
464 let choice = OneOrMany::many(content).map_err(|_| {
465 CompletionError::ResponseError(
466 "Response contained no message or tool call (empty)".to_owned(),
467 )
468 })?;
469
470 let usage = completion::Usage {
471 input_tokens: response.usage.prompt_tokens as u64,
472 output_tokens: response.usage.completion_tokens as u64,
473 total_tokens: response.usage.total_tokens as u64,
474 };
475
476 Ok(completion::CompletionResponse {
477 choice,
478 usage,
479 raw_response: response,
480 })
481 }
482}
483
484#[derive(Clone)]
486pub struct CompletionModel {
487 pub client: Client,
488 pub model: String,
489}
490
491impl CompletionModel {
492 fn create_completion_request(
493 &self,
494 completion_request: CompletionRequest,
495 ) -> Result<serde_json::Value, CompletionError> {
496 let mut partial_history = vec![];
498
499 if let Some(docs) = completion_request.normalized_documents() {
500 partial_history.push(docs);
501 }
502
503 partial_history.extend(completion_request.chat_history);
504
505 let mut full_history: Vec<Message> = completion_request
507 .preamble
508 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
509
510 full_history.extend(
512 partial_history
513 .into_iter()
514 .map(message::Message::try_into)
515 .collect::<Result<Vec<Vec<Message>>, _>>()?
516 .into_iter()
517 .flatten()
518 .collect::<Vec<_>>(),
519 );
520
521 let request = if completion_request.tools.is_empty() {
522 json!({
523 "model": self.model,
524 "messages": full_history,
525 "temperature": completion_request.temperature,
526 })
527 } else {
528 json!({
529 "model": self.model,
530 "messages": full_history,
531 "temperature": completion_request.temperature,
532 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
533 "tool_choice": "auto",
534 })
535 };
536
537 let request = if let Some(params) = completion_request.additional_params {
538 json_utils::merge(request, params)
539 } else {
540 request
541 };
542
543 Ok(request)
544 }
545}
546
547impl completion::CompletionModel for CompletionModel {
548 type Response = CompletionResponse;
549 type StreamingResponse = StreamingCompletionResponse;
550
551 #[cfg_attr(feature = "worker", worker::send)]
552 async fn completion(
553 &self,
554 completion_request: CompletionRequest,
555 ) -> Result<
556 completion::CompletionResponse<CompletionResponse>,
557 crate::completion::CompletionError,
558 > {
559 let request = self.create_completion_request(completion_request)?;
560
561 tracing::debug!("DeepSeek completion request: {request:?}");
562
563 let response = self
564 .client
565 .post("/chat/completions")
566 .json(&request)
567 .send()
568 .await?;
569
570 if response.status().is_success() {
571 let t = response.text().await?;
572 tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
573
574 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
575 ApiResponse::Ok(response) => response.try_into(),
576 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
577 }
578 } else {
579 Err(CompletionError::ProviderError(response.text().await?))
580 }
581 }
582
583 #[cfg_attr(feature = "worker", worker::send)]
584 async fn stream(
585 &self,
586 completion_request: CompletionRequest,
587 ) -> Result<
588 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
589 CompletionError,
590 > {
591 let mut request = self.create_completion_request(completion_request)?;
592
593 request = merge(
594 request,
595 json!({"stream": true, "stream_options": {"include_usage": true}}),
596 );
597
598 let builder = self.client.post("/chat/completions").json(&request);
599 send_compatible_streaming_request(builder).await
600 }
601}
602
603#[derive(Deserialize, Debug)]
604pub struct StreamingDelta {
605 #[serde(default)]
606 content: Option<String>,
607 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
608 tool_calls: Vec<StreamingToolCall>,
609 reasoning_content: Option<String>,
610}
611
612#[derive(Deserialize, Debug)]
613struct StreamingChoice {
614 delta: StreamingDelta,
615}
616
617#[derive(Deserialize, Debug)]
618struct StreamingCompletionChunk {
619 choices: Vec<StreamingChoice>,
620 usage: Option<Usage>,
621}
622
623#[derive(Clone, Deserialize, Serialize, Debug)]
624pub struct StreamingCompletionResponse {
625 pub usage: Usage,
626}
627
628impl GetTokenUsage for StreamingCompletionResponse {
629 fn token_usage(&self) -> Option<crate::completion::Usage> {
630 let mut usage = crate::completion::Usage::new();
631 usage.input_tokens = self.usage.prompt_tokens as u64;
632 usage.output_tokens = self.usage.completion_tokens as u64;
633 usage.total_tokens = self.usage.total_tokens as u64;
634
635 Some(usage)
636 }
637}
638
639pub async fn send_compatible_streaming_request(
640 request_builder: reqwest::RequestBuilder,
641) -> Result<
642 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
643 CompletionError,
644> {
645 let response = request_builder.send().await?;
646
647 if !response.status().is_success() {
648 return Err(CompletionError::ProviderError(format!(
649 "{}: {}",
650 response.status(),
651 response.text().await?
652 )));
653 }
654
655 let inner = Box::pin(async_stream::stream! {
657 let mut stream = response.bytes_stream();
658
659 let mut final_usage = Usage::new();
660 let mut partial_data = None;
661 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
662
663 while let Some(chunk_result) = stream.next().await {
664 let chunk = match chunk_result {
665 Ok(c) => c,
666 Err(e) => {
667 yield Err(CompletionError::from(e));
668 break;
669 }
670 };
671
672 let text = match String::from_utf8(chunk.to_vec()) {
673 Ok(t) => t,
674 Err(e) => {
675 yield Err(CompletionError::ResponseError(e.to_string()));
676 break;
677 }
678 };
679
680
681 for line in text.lines() {
682 let mut line = line.to_string();
683
684 if partial_data.is_some() {
686 line = format!("{}{}", partial_data.unwrap(), line);
687 partial_data = None;
688 }
689 else {
691 let Some(data) = line.strip_prefix("data:") else {
692 continue;
693 };
694
695 let data = data.trim_start();
696
697 if !line.ends_with("}") {
699 partial_data = Some(data.to_string());
700 } else {
701 line = data.to_string();
702 }
703 }
704
705 let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
706
707 let Ok(data) = data else {
708 let err = data.unwrap_err();
709 tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
710 continue;
711 };
712
713
714 if let Some(choice) = data.choices.first() {
715 let delta = &choice.delta;
716
717
718 if !delta.tool_calls.is_empty() {
719 for tool_call in &delta.tool_calls {
720 let function = tool_call.function.clone();
721 if function.name.is_some() && function.arguments.is_empty() {
725 let id = tool_call.id.clone().unwrap_or("".to_string());
726
727 calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
728 }
729 else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
733 let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
734 tracing::debug!("Partial tool call received but tool call was never started.");
735 continue;
736 };
737
738 let new_arguments = &function.arguments;
739 let arguments = format!("{arguments}{new_arguments}");
740
741 calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
742 }
743 else {
745 let id = tool_call.id.clone().unwrap_or("".to_string());
746 let name = function.name.expect("function name should be present for complete tool call");
747 let arguments = function.arguments;
748 let Ok(arguments) = serde_json::from_str(&arguments) else {
749 tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
750 continue;
751 };
752
753 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
754 }
755 }
756 }
757
758 if let Some(content) = &delta.reasoning_content {
759 yield Ok(crate::streaming::RawStreamingChoice::Reasoning { reasoning: content.to_string(), id: None})
760 }
761
762 if let Some(content) = &delta.content {
763 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
764 }
765
766 }
767
768
769 if let Some(usage) = data.usage {
770 final_usage = usage.clone();
771 }
772 }
773 }
774
775 for (_, (id, name, arguments)) in calls {
776 let Ok(arguments) = serde_json::from_str(&arguments) else {
777 continue;
778 };
779
780 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
781 }
782
783 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
784 usage: final_usage.clone()
785 }))
786 });
787
788 Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
789}
790
/// Model identifier `deepseek-chat`, for use with `CompletionClient::completion_model`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier `deepseek-reasoner`. Its streamed `reasoning_content` deltas are
/// surfaced as `Reasoning` items by `send_compatible_streaming_request`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
799
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of an assistant message, failing the test for any other role.
    fn assistant_content(message: &Message) -> &str {
        let Message::Assistant { content, .. } = message else {
            panic!("Expected assistant message");
        };
        content
    }

    /// Deserializes a [`CompletionResponse`], reporting the JSON path of any failure.
    fn parse_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        serde_path_to_error::deserialize(jd)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_eq!(assistant_content(&choices[0].message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(assistant_content(&first.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(
            assistant_content(&first.message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let parsed: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };
        let expected = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
        };

        assert_eq!(parsed, expected);
    }
}
951}