rig/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::deepseek;
6//!
7//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
8//!
9//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
10//! ```
11
12use futures::StreamExt;
13use std::collections::HashMap;
14
15use crate::client::{ClientBuilderError, CompletionClient, ProviderClient};
16use crate::completion::GetTokenUsage;
17use crate::json_utils::merge;
18use crate::message::Document;
19use crate::{
20    OneOrMany,
21    completion::{self, CompletionError, CompletionRequest},
22    impl_conversion_traits, json_utils, message,
23};
24use reqwest::Client as HttpClient;
25use serde::{Deserialize, Serialize};
26use serde_json::json;
27
28use super::openai::StreamingToolCall;
29
30// ================================================================
31// Main DeepSeek Client
32// ================================================================
/// Base URL that [`ClientBuilder::new`] uses unless it is overridden with
/// [`ClientBuilder::base_url`].
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
34
/// Builder for a DeepSeek [`Client`].
///
/// The API key and base URL are borrowed for `'a` and only copied into owned
/// strings when [`ClientBuilder::build`] is called.
pub struct ClientBuilder<'a> {
    /// API key that the built client sends as a bearer token.
    api_key: &'a str,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// Caller-supplied HTTP client; when `None`, `build` creates a default one.
    http_client: Option<reqwest::Client>,
}
40
41impl<'a> ClientBuilder<'a> {
42    pub fn new(api_key: &'a str) -> Self {
43        Self {
44            api_key,
45            base_url: DEEPSEEK_API_BASE_URL,
46            http_client: None,
47        }
48    }
49
50    pub fn base_url(mut self, base_url: &'a str) -> Self {
51        self.base_url = base_url;
52        self
53    }
54
55    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
56        self.http_client = Some(client);
57        self
58    }
59
60    pub fn build(self) -> Result<Client, ClientBuilderError> {
61        let http_client = if let Some(http_client) = self.http_client {
62            http_client
63        } else {
64            reqwest::Client::builder().build()?
65        };
66
67        Ok(Client {
68            base_url: self.base_url.to_string(),
69            api_key: self.api_key.to_string(),
70            http_client,
71        })
72    }
73}
74
/// DeepSeek API client holding the endpoint, credentials and HTTP client
/// used for every request.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to.
    pub base_url: String,
    /// API key sent as a bearer token; kept private and redacted in `Debug` output.
    api_key: String,
    /// Underlying HTTP client used to issue requests.
    http_client: HttpClient,
}
81
82impl std::fmt::Debug for Client {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        f.debug_struct("Client")
85            .field("base_url", &self.base_url)
86            .field("http_client", &self.http_client)
87            .field("api_key", &"<REDACTED>")
88            .finish()
89    }
90}
91
92impl Client {
93    /// Create a new DeepSeek client builder.
94    ///
95    /// # Example
96    /// ```
97    /// use rig::providers::deepseek::{ClientBuilder, self};
98    ///
99    /// // Initialize the DeepSeek client
100    /// let deepseek = Client::builder("your-deepseek-api-key")
101    ///    .build()
102    /// ```
103    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
104        ClientBuilder::new(api_key)
105    }
106
107    /// Create a new DeepSeek client. For more control, use the `builder` method.
108    ///
109    /// # Panics
110    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
111    pub fn new(api_key: &str) -> Self {
112        Self::builder(api_key)
113            .build()
114            .expect("DeepSeek client should build")
115    }
116
117    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
118        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
119        self.http_client.post(url).bearer_auth(&self.api_key)
120    }
121}
122
123impl ProviderClient for Client {
124    // If you prefer the environment variable approach:
125    fn from_env() -> Self {
126        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
127        Self::new(&api_key)
128    }
129
130    fn from_val(input: crate::client::ProviderValue) -> Self {
131        let crate::client::ProviderValue::Simple(api_key) = input else {
132            panic!("Incorrect provider value type")
133        };
134        Self::new(&api_key)
135    }
136}
137
138impl CompletionClient for Client {
139    type CompletionModel = CompletionModel;
140
141    /// Creates a DeepSeek completion model with the given `model_name`.
142    fn completion_model(&self, model_name: &str) -> CompletionModel {
143        CompletionModel {
144            client: self.clone(),
145            model: model_name.to_string(),
146        }
147    }
148}
149
// NOTE(review): presumably generates the `AsEmbeddings`, `AsTranscription`,
// `AsImageGeneration` and `AsAudioGeneration` conversions for `Client`, i.e. the
// capabilities this file does not implement itself. Confirm against the macro
// definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
156
/// Error payload deserialized from a response body that is not a valid
/// success payload; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error description, surfaced to callers as a provider error.
    message: String,
}
161
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// Being `untagged`, serde tries the variants in declaration order: the body
/// is first parsed as `T` and only falls back to the error shape if that fails.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body matched the expected success type.
    Ok(T),
    /// The body matched the error shape instead.
    Err(ApiErrorResponse),
}
168
169impl From<ApiErrorResponse> for CompletionError {
170    fn from(err: ApiErrorResponse) -> Self {
171        CompletionError::ProviderError(err.message)
172    }
173}
174
/// The response shape from the DeepSeek API.
///
/// Only `choices` and `usage` are modelled; any other fields in the response
/// body are ignored on deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Candidate completions; the conversion into Rig's response uses the first one.
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
    // Further response fields can be added here when they are needed.
}
183
/// Token accounting from the `usage` object of a DeepSeek response.
///
/// The five counters must be present when deserializing. The two `*_details`
/// objects are optional and are left out of serialized output when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    /// Reported as Rig's `output_tokens`.
    pub completion_tokens: u32,
    /// Reported as Rig's `input_tokens`.
    pub prompt_tokens: u32,
    /// Cache-hit counter as reported by the API.
    pub prompt_cache_hit_tokens: u32,
    /// Cache-miss counter as reported by the API.
    pub prompt_cache_miss_tokens: u32,
    /// Reported as Rig's `total_tokens`.
    pub total_tokens: u32,
    /// Optional breakdown of the completion tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    /// Optional breakdown of the prompt tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
196
197impl Usage {
198    fn new() -> Self {
199        Self {
200            completion_tokens: 0,
201            prompt_tokens: 0,
202            prompt_cache_hit_tokens: 0,
203            prompt_cache_miss_tokens: 0,
204            total_tokens: 0,
205            completion_tokens_details: None,
206            prompt_tokens_details: None,
207        }
208    }
209}
210
/// Optional breakdown of the completion tokens in [`Usage`].
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Reasoning-token count, when the API reports one; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
216
/// Optional breakdown of the prompt tokens in [`Usage`].
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// Cached-token count, when the API reports one; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
222
/// One entry of the `choices` array in a DeepSeek completion response.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Position of this choice as reported by the API.
    pub index: usize,
    /// The message produced for this choice.
    pub message: Message,
    /// Log-probability data, kept as raw JSON; may be `null`.
    pub logprobs: Option<serde_json::Value>,
    /// Finish reason string as reported by the API (e.g. `"stop"`, `"tool_calls"`).
    pub finish_reason: String,
}
230
/// A chat message in DeepSeek's wire format.
///
/// The enum is internally tagged by the `role` field. Variant names are
/// lowercased (`system`, `user`, `assistant`), and `ToolResult` is explicitly
/// renamed to `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A `system` message.
    System {
        content: String,
        // Left out of the serialized message when unset.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// A `user` message.
    User {
        content: String,
        // Left out of the serialized message when unset.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// An `assistant` message, optionally carrying tool calls.
    Assistant {
        content: String,
        // Left out of the serialized message when unset.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // A missing field defaults to an empty list and an empty list is not
        // serialized. NOTE(review): `null_or_vec` presumably also maps an
        // explicit `null` to an empty list; confirm in `json_utils`.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// The result of a tool call, sent with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
261
262impl Message {
263    pub fn system(content: &str) -> Self {
264        Message::System {
265            content: content.to_owned(),
266            name: None,
267        }
268    }
269}
270
271impl From<message::ToolResult> for Message {
272    fn from(tool_result: message::ToolResult) -> Self {
273        let content = match tool_result.content.first() {
274            message::ToolResultContent::Text(text) => text.text,
275            message::ToolResultContent::Image(_) => String::from("[Image]"),
276        };
277
278        Message::ToolResult {
279            tool_call_id: tool_result.id,
280            content,
281        }
282    }
283}
284
285impl From<message::ToolCall> for ToolCall {
286    fn from(tool_call: message::ToolCall) -> Self {
287        Self {
288            id: tool_call.id,
289            // TODO: update index when we have it
290            index: 0,
291            r#type: ToolType::Function,
292            function: Function {
293                name: tool_call.function.name,
294                arguments: tool_call.function.arguments,
295            },
296        }
297    }
298}
299
300impl TryFrom<message::Message> for Vec<Message> {
301    type Error = message::MessageError;
302
303    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
304        match message {
305            message::Message::User { content } => {
306                // extract tool results
307                let mut messages = vec![];
308
309                let tool_results = content
310                    .clone()
311                    .into_iter()
312                    .filter_map(|content| match content {
313                        message::UserContent::ToolResult(tool_result) => {
314                            Some(Message::from(tool_result))
315                        }
316                        _ => None,
317                    })
318                    .collect::<Vec<_>>();
319
320                messages.extend(tool_results);
321
322                // extract text results
323                let text_messages = content
324                    .into_iter()
325                    .filter_map(|content| match content {
326                        message::UserContent::Text(text) => Some(Message::User {
327                            content: text.text,
328                            name: None,
329                        }),
330                        message::UserContent::Document(Document { data, .. }) => {
331                            Some(Message::User {
332                                content: data,
333                                name: None,
334                            })
335                        }
336                        _ => None,
337                    })
338                    .collect::<Vec<_>>();
339                messages.extend(text_messages);
340
341                Ok(messages)
342            }
343            message::Message::Assistant { content, .. } => {
344                let mut messages: Vec<Message> = vec![];
345
346                // extract tool calls
347                let tool_calls = content
348                    .clone()
349                    .into_iter()
350                    .filter_map(|content| match content {
351                        message::AssistantContent::ToolCall(tool_call) => {
352                            Some(ToolCall::from(tool_call))
353                        }
354                        _ => None,
355                    })
356                    .collect::<Vec<_>>();
357
358                // if we have tool calls, we add a new Assistant message with them
359                if !tool_calls.is_empty() {
360                    messages.push(Message::Assistant {
361                        content: "".to_string(),
362                        name: None,
363                        tool_calls,
364                    });
365                }
366
367                // extract text
368                let text_content = content
369                    .into_iter()
370                    .filter_map(|content| match content {
371                        message::AssistantContent::Text(text) => Some(Message::Assistant {
372                            content: text.text,
373                            name: None,
374                            tool_calls: vec![],
375                        }),
376                        _ => None,
377                    })
378                    .collect::<Vec<_>>();
379
380                messages.extend(text_content);
381
382                Ok(messages)
383            }
384        }
385    }
386}
387
/// A tool call attached to an assistant message, in DeepSeek's wire format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of the call; echoed back as `tool_call_id` in the tool result.
    pub id: String,
    /// Position of the call as reported by the API.
    pub index: usize,
    /// Kind of tool; falls back to [`ToolType::Function`] when absent.
    #[serde(default)]
    pub r#type: ToolType,
    /// The function being invoked and its arguments.
    pub function: Function,
}
396
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as a JSON value. On the wire this is a JSON-encoded string
    /// (e.g. `"{\"x\":2,\"y\":5}"`), handled by `json_utils::stringified_json`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
403
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase
/// (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool; the only variant and the default.
    #[default]
    Function,
}
410
/// A tool definition as sent in the `tools` array of a request: a `type`
/// string wrapped around Rig's own tool definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; the `From` conversion in this file always sets `"function"`.
    pub r#type: String,
    /// The Rig tool definition, serialized as the `function` object.
    pub function: completion::ToolDefinition,
}
416
417impl From<crate::completion::ToolDefinition> for ToolDefinition {
418    fn from(tool: crate::completion::ToolDefinition) -> Self {
419        Self {
420            r#type: "function".into(),
421            function: tool,
422        }
423    }
424}
425
426impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
427    type Error = CompletionError;
428
429    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
430        let choice = response.choices.first().ok_or_else(|| {
431            CompletionError::ResponseError("Response contained no choices".to_owned())
432        })?;
433        let content = match &choice.message {
434            Message::Assistant {
435                content,
436                tool_calls,
437                ..
438            } => {
439                let mut content = if content.trim().is_empty() {
440                    vec![]
441                } else {
442                    vec![completion::AssistantContent::text(content)]
443                };
444
445                content.extend(
446                    tool_calls
447                        .iter()
448                        .map(|call| {
449                            completion::AssistantContent::tool_call(
450                                &call.id,
451                                &call.function.name,
452                                call.function.arguments.clone(),
453                            )
454                        })
455                        .collect::<Vec<_>>(),
456                );
457                Ok(content)
458            }
459            _ => Err(CompletionError::ResponseError(
460                "Response did not contain a valid message or tool call".into(),
461            )),
462        }?;
463
464        let choice = OneOrMany::many(content).map_err(|_| {
465            CompletionError::ResponseError(
466                "Response contained no message or tool call (empty)".to_owned(),
467            )
468        })?;
469
470        let usage = completion::Usage {
471            input_tokens: response.usage.prompt_tokens as u64,
472            output_tokens: response.usage.completion_tokens as u64,
473            total_tokens: response.usage.total_tokens as u64,
474        };
475
476        Ok(completion::CompletionResponse {
477            choice,
478            usage,
479            raw_response: response,
480        })
481    }
482}
483
/// The struct implementing the `CompletionModel` trait for DeepSeek.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    pub client: Client,
    /// Model name sent as the `model` field of every request.
    pub model: String,
}
490
491impl CompletionModel {
492    fn create_completion_request(
493        &self,
494        completion_request: CompletionRequest,
495    ) -> Result<serde_json::Value, CompletionError> {
496        // Build up the order of messages (context, chat_history, prompt)
497        let mut partial_history = vec![];
498
499        if let Some(docs) = completion_request.normalized_documents() {
500            partial_history.push(docs);
501        }
502
503        partial_history.extend(completion_request.chat_history);
504
505        // Initialize full history with preamble (or empty if non-existent)
506        let mut full_history: Vec<Message> = completion_request
507            .preamble
508            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
509
510        // Convert and extend the rest of the history
511        full_history.extend(
512            partial_history
513                .into_iter()
514                .map(message::Message::try_into)
515                .collect::<Result<Vec<Vec<Message>>, _>>()?
516                .into_iter()
517                .flatten()
518                .collect::<Vec<_>>(),
519        );
520
521        let request = if completion_request.tools.is_empty() {
522            json!({
523                "model": self.model,
524                "messages": full_history,
525                "temperature": completion_request.temperature,
526            })
527        } else {
528            json!({
529                "model": self.model,
530                "messages": full_history,
531                "temperature": completion_request.temperature,
532                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
533                "tool_choice": "auto",
534            })
535        };
536
537        let request = if let Some(params) = completion_request.additional_params {
538            json_utils::merge(request, params)
539        } else {
540            request
541        };
542
543        Ok(request)
544    }
545}
546
547impl completion::CompletionModel for CompletionModel {
548    type Response = CompletionResponse;
549    type StreamingResponse = StreamingCompletionResponse;
550
551    #[cfg_attr(feature = "worker", worker::send)]
552    async fn completion(
553        &self,
554        completion_request: CompletionRequest,
555    ) -> Result<
556        completion::CompletionResponse<CompletionResponse>,
557        crate::completion::CompletionError,
558    > {
559        let request = self.create_completion_request(completion_request)?;
560
561        tracing::debug!("DeepSeek completion request: {request:?}");
562
563        let response = self
564            .client
565            .post("/chat/completions")
566            .json(&request)
567            .send()
568            .await?;
569
570        if response.status().is_success() {
571            let t = response.text().await?;
572            tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
573
574            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
575                ApiResponse::Ok(response) => response.try_into(),
576                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
577            }
578        } else {
579            Err(CompletionError::ProviderError(response.text().await?))
580        }
581    }
582
583    #[cfg_attr(feature = "worker", worker::send)]
584    async fn stream(
585        &self,
586        completion_request: CompletionRequest,
587    ) -> Result<
588        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
589        CompletionError,
590    > {
591        let mut request = self.create_completion_request(completion_request)?;
592
593        request = merge(
594            request,
595            json!({"stream": true, "stream_options": {"include_usage": true}}),
596        );
597
598        let builder = self.client.post("/chat/completions").json(&request);
599        send_compatible_streaming_request(builder).await
600    }
601}
602
/// Incremental payload (`delta`) of one streamed choice.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    /// Text fragment, if this chunk carries one.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments; a missing field becomes an empty list.
    /// NOTE(review): `null_or_vec` presumably maps `null` to an empty list too; confirm.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    /// Reasoning text fragment, if this chunk carries one.
    reasoning_content: Option<String>,
}
611
/// One entry of the `choices` array in a streamed chunk.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    /// The incremental content for this choice.
    delta: StreamingDelta,
}
616
/// One decoded `data:` event of a streaming completion.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    /// Streamed choices; only the first one is consumed.
    choices: Vec<StreamingChoice>,
    /// Token usage, when the chunk carries it.
    usage: Option<Usage>,
}
622
/// Final summary emitted at the end of a streamed completion.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Last usage block seen in the stream, or all zeros if none was received.
    pub usage: Usage,
}
627
628impl GetTokenUsage for StreamingCompletionResponse {
629    fn token_usage(&self) -> Option<crate::completion::Usage> {
630        let mut usage = crate::completion::Usage::new();
631        usage.input_tokens = self.usage.prompt_tokens as u64;
632        usage.output_tokens = self.usage.completion_tokens as u64;
633        usage.total_tokens = self.usage.total_tokens as u64;
634
635        Some(usage)
636    }
637}
638
639pub async fn send_compatible_streaming_request(
640    request_builder: reqwest::RequestBuilder,
641) -> Result<
642    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
643    CompletionError,
644> {
645    let response = request_builder.send().await?;
646
647    if !response.status().is_success() {
648        return Err(CompletionError::ProviderError(format!(
649            "{}: {}",
650            response.status(),
651            response.text().await?
652        )));
653    }
654
655    // Handle OpenAI Compatible SSE chunks
656    let inner = Box::pin(async_stream::stream! {
657        let mut stream = response.bytes_stream();
658
659        let mut final_usage = Usage::new();
660        let mut partial_data = None;
661        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
662
663        while let Some(chunk_result) = stream.next().await {
664            let chunk = match chunk_result {
665                Ok(c) => c,
666                Err(e) => {
667                    yield Err(CompletionError::from(e));
668                    break;
669                }
670            };
671
672            let text = match String::from_utf8(chunk.to_vec()) {
673                Ok(t) => t,
674                Err(e) => {
675                    yield Err(CompletionError::ResponseError(e.to_string()));
676                    break;
677                }
678            };
679
680
681            for line in text.lines() {
682                let mut line = line.to_string();
683
684                // If there was a remaining part, concat with current line
685                if partial_data.is_some() {
686                    line = format!("{}{}", partial_data.unwrap(), line);
687                    partial_data = None;
688                }
689                // Otherwise full data line
690                else {
691                    let Some(data) = line.strip_prefix("data:") else {
692                        continue;
693                    };
694
695                    let data = data.trim_start();
696
697                    // Partial data, split somewhere in the middle
698                    if !line.ends_with("}") {
699                        partial_data = Some(data.to_string());
700                    } else {
701                        line = data.to_string();
702                    }
703                }
704
705                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
706
707                let Ok(data) = data else {
708                    let err = data.unwrap_err();
709                    tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
710                    continue;
711                };
712
713
714                if let Some(choice) = data.choices.first() {
715                    let delta = &choice.delta;
716
717
718                            if !delta.tool_calls.is_empty() {
719                                for tool_call in &delta.tool_calls {
720                                    let function = tool_call.function.clone();
721                                    // Start of tool call
722                                    // name: Some(String)
723                                    // arguments: None
724                                    if function.name.is_some() && function.arguments.is_empty() {
725                                        let id = tool_call.id.clone().unwrap_or("".to_string());
726
727                                        calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
728                                    }
729                                    // Part of tool call
730                                    // name: None or Empty String
731                                    // arguments: Some(String)
732                                    else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
733                                        let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
734                                            tracing::debug!("Partial tool call received but tool call was never started.");
735                                            continue;
736                                        };
737
738                                        let new_arguments = &function.arguments;
739                                        let arguments = format!("{arguments}{new_arguments}");
740
741                                        calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
742                                    }
743                                    // Entire tool call
744                                    else {
745                                        let id = tool_call.id.clone().unwrap_or("".to_string());
746                                        let name = function.name.expect("function name should be present for complete tool call");
747                                        let arguments = function.arguments;
748                                        let Ok(arguments) = serde_json::from_str(&arguments) else {
749                                            tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
750                                            continue;
751                                        };
752
753                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
754                                    }
755                                }
756                            }
757
758                            if let Some(content) = &delta.reasoning_content {
759                                yield Ok(crate::streaming::RawStreamingChoice::Reasoning { reasoning: content.to_string(), id: None})
760                            }
761
762                            if let Some(content) = &delta.content {
763                                yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
764                            }
765
766                }
767
768
769                if let Some(usage) = data.usage {
770                    final_usage = usage.clone();
771                }
772            }
773        }
774
775        for (_, (id, name, arguments)) in calls {
776            let Ok(arguments) = serde_json::from_str(&arguments) else {
777                continue;
778            };
779
780            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
781        }
782
783        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
784            usage: final_usage.clone()
785        }))
786    });
787
788    Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
789}
790
791// ================================================================
792// DeepSeek Completion API
793// ================================================================
794
/// `deepseek-chat` completion model name, for use with `Client::completion_model`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// `deepseek-reasoner` completion model name, for use with `Client::completion_model`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
799
800// Tests
#[cfg(test)]
mod tests {

    use super::*;

    /// Returns the text of an assistant message, failing the test for any
    /// other role.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content.as_str(),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Deserializes a full completion response, reporting the JSON path of
    /// the offending field on failure.
    fn parse_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        let parsed: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match parsed {
            Ok(response) => response,
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    /// A bare `choices` array deserializes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);

        let first = choices.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    /// A minimal response with only the required usage counters deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    /// A realistic response with extra top-level fields and usage details
    /// deserializes, and non-ASCII content survives intact.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(
            assistant_text(&first.message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    /// A tool-call choice deserializes with its stringified arguments decoded
    /// into a JSON value.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_tool_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };

        let expected_choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_tool_call],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}