rig/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::deepseek;
6//!
7//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
8//!
9//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
10//! ```
11
12use futures::StreamExt;
13use std::collections::HashMap;
14
15use crate::client::{ClientBuilderError, CompletionClient, ProviderClient};
16use crate::json_utils::merge;
17use crate::message::Document;
18use crate::{
19    OneOrMany,
20    completion::{self, CompletionError, CompletionRequest},
21    impl_conversion_traits, json_utils, message,
22};
23use reqwest::Client as HttpClient;
24use serde::{Deserialize, Serialize};
25use serde_json::json;
26
27use super::openai::StreamingToolCall;
28
29// ================================================================
30// Main DeepSeek Client
31// ================================================================
/// Base URL that [`ClientBuilder::new`] uses unless `ClientBuilder::base_url` overrides it.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
33
/// Builder for a DeepSeek [`Client`].
///
/// The API key and base URL are borrowed for `'a`; `build` copies both into
/// owned strings, so the resulting [`Client`] does not borrow from the builder.
pub struct ClientBuilder<'a> {
    /// API key the built client sends as a bearer token.
    api_key: &'a str,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// Caller-supplied HTTP client; when `None`, `build` creates a default one.
    http_client: Option<reqwest::Client>,
}
39
40impl<'a> ClientBuilder<'a> {
41    pub fn new(api_key: &'a str) -> Self {
42        Self {
43            api_key,
44            base_url: DEEPSEEK_API_BASE_URL,
45            http_client: None,
46        }
47    }
48
49    pub fn base_url(mut self, base_url: &'a str) -> Self {
50        self.base_url = base_url;
51        self
52    }
53
54    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
55        self.http_client = Some(client);
56        self
57    }
58
59    pub fn build(self) -> Result<Client, ClientBuilderError> {
60        let http_client = if let Some(http_client) = self.http_client {
61            http_client
62        } else {
63            reqwest::Client::builder().build()?
64        };
65
66        Ok(Client {
67            base_url: self.base_url.to_string(),
68            api_key: self.api_key.to_string(),
69            http_client,
70        })
71    }
72}
73
/// DeepSeek API client.
///
/// Cloning is supported; `completion_model` clones the client into every model it creates.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to.
    pub base_url: String,
    /// API key sent as a bearer token; redacted in the `Debug` output.
    api_key: String,
    /// HTTP client used to issue requests.
    http_client: HttpClient,
}
80
81impl std::fmt::Debug for Client {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("Client")
84            .field("base_url", &self.base_url)
85            .field("http_client", &self.http_client)
86            .field("api_key", &"<REDACTED>")
87            .finish()
88    }
89}
90
91impl Client {
92    /// Create a new DeepSeek client builder.
93    ///
94    /// # Example
95    /// ```
96    /// use rig::providers::deepseek::{ClientBuilder, self};
97    ///
98    /// // Initialize the DeepSeek client
99    /// let deepseek = Client::builder("your-deepseek-api-key")
100    ///    .build()
101    /// ```
102    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
103        ClientBuilder::new(api_key)
104    }
105
106    /// Create a new DeepSeek client. For more control, use the `builder` method.
107    ///
108    /// # Panics
109    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
110    pub fn new(api_key: &str) -> Self {
111        Self::builder(api_key)
112            .build()
113            .expect("DeepSeek client should build")
114    }
115
116    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
117        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
118        self.http_client.post(url).bearer_auth(&self.api_key)
119    }
120}
121
122impl ProviderClient for Client {
123    // If you prefer the environment variable approach:
124    fn from_env() -> Self {
125        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
126        Self::new(&api_key)
127    }
128
129    fn from_val(input: crate::client::ProviderValue) -> Self {
130        let crate::client::ProviderValue::Simple(api_key) = input else {
131            panic!("Incorrect provider value type")
132        };
133        Self::new(&api_key)
134    }
135}
136
137impl CompletionClient for Client {
138    type CompletionModel = CompletionModel;
139
140    /// Creates a DeepSeek completion model with the given `model_name`.
141    fn completion_model(&self, model_name: &str) -> CompletionModel {
142        CompletionModel {
143            client: self.clone(),
144            model: model_name.to_string(),
145        }
146    }
147}
148
// NOTE(review): presumably generates the `AsEmbeddings`, `AsTranscription`,
// `AsImageGeneration` and `AsAudioGeneration` impls for `Client` to mark these
// capabilities as unsupported — confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
155
/// Error payload as deserialized from a response body; only `message` is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text, forwarded as `CompletionError::ProviderError`.
    message: String,
}
160
/// Body of a response that is either the expected payload `T` or an error object.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body matched the expected payload shape.
    Ok(T),
    /// The body matched the error shape instead.
    Err(ApiErrorResponse),
}
167
168impl From<ApiErrorResponse> for CompletionError {
169    fn from(err: ApiErrorResponse) -> Self {
170        CompletionError::ProviderError(err.message)
171    }
172}
173
/// The response shape from the DeepSeek API
///
/// Only the fields below are modelled; the tests in this file deserialize
/// responses carrying extra keys (`id`, `model`, ...) into this type.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Generated choices; the conversion into Rig's response reads only the first.
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
    // Other response fields are not modelled yet; add them here when needed.
}
182
/// Token accounting (`usage`) attached to a DeepSeek response.
///
/// The five counters have no serde default, so all of them must be present when
/// deserializing. The two `*_details` fields are optional and are left out of
/// serialized output when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    pub completion_tokens: u32,
    pub prompt_tokens: u32,
    pub prompt_cache_hit_tokens: u32,
    pub prompt_cache_miss_tokens: u32,
    pub total_tokens: u32,
    /// Optional breakdown of the completion tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    /// Optional breakdown of the prompt tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
195
196impl Usage {
197    fn new() -> Self {
198        Self {
199            completion_tokens: 0,
200            prompt_tokens: 0,
201            prompt_cache_hit_tokens: 0,
202            prompt_cache_miss_tokens: 0,
203            total_tokens: 0,
204            completion_tokens_details: None,
205            prompt_tokens_details: None,
206        }
207    }
208}
209
/// Optional breakdown attached to [`Usage::completion_tokens_details`].
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Reasoning-token count as reported by the API; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
215
/// Optional breakdown attached to [`Usage::prompt_tokens_details`].
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// Cached-token count as reported by the API; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
221
/// One entry of the `choices` array in a completion response.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Position of this choice within the response.
    pub index: usize,
    /// The generated message.
    pub message: Message,
    /// Log-probability data, kept as raw JSON; `null` deserializes to `None`.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped (the tests in this file use `"stop"` and `"tool_calls"`).
    pub finish_reason: String,
}
229
/// Chat message in the DeepSeek wire format.
///
/// The variant is selected by the `role` field, using the lowercase variant name
/// (`system`, `user`, `assistant`) except for [`Message::ToolResult`], whose role
/// is `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Instruction message; the request builder derives it from the preamble.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Message authored by the user.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Message authored by the model, optionally carrying tool calls.
    Assistant {
        // NOTE(review): `content` is a plain `String`, so a response with
        // `"content": null` would fail to deserialize — confirm whether the API
        // can send that.
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // A missing `tool_calls` key falls back to an empty vector (`default`);
        // presumably `null_or_vec` maps an explicit `null` to empty as well.
        // The field is omitted from serialized output when empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool invocation, linked to its call through `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
260
261impl Message {
262    pub fn system(content: &str) -> Self {
263        Message::System {
264            content: content.to_owned(),
265            name: None,
266        }
267    }
268}
269
270impl From<message::ToolResult> for Message {
271    fn from(tool_result: message::ToolResult) -> Self {
272        let content = match tool_result.content.first() {
273            message::ToolResultContent::Text(text) => text.text,
274            message::ToolResultContent::Image(_) => String::from("[Image]"),
275        };
276
277        Message::ToolResult {
278            tool_call_id: tool_result.id,
279            content,
280        }
281    }
282}
283
284impl From<message::ToolCall> for ToolCall {
285    fn from(tool_call: message::ToolCall) -> Self {
286        Self {
287            id: tool_call.id,
288            // TODO: update index when we have it
289            index: 0,
290            r#type: ToolType::Function,
291            function: Function {
292                name: tool_call.function.name,
293                arguments: tool_call.function.arguments,
294            },
295        }
296    }
297}
298
299impl TryFrom<message::Message> for Vec<Message> {
300    type Error = message::MessageError;
301
302    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
303        match message {
304            message::Message::User { content } => {
305                // extract tool results
306                let mut messages = vec![];
307
308                let tool_results = content
309                    .clone()
310                    .into_iter()
311                    .filter_map(|content| match content {
312                        message::UserContent::ToolResult(tool_result) => {
313                            Some(Message::from(tool_result))
314                        }
315                        _ => None,
316                    })
317                    .collect::<Vec<_>>();
318
319                messages.extend(tool_results);
320
321                // extract text results
322                let text_messages = content
323                    .into_iter()
324                    .filter_map(|content| match content {
325                        message::UserContent::Text(text) => Some(Message::User {
326                            content: text.text,
327                            name: None,
328                        }),
329                        message::UserContent::Document(Document { data, .. }) => {
330                            Some(Message::User {
331                                content: data,
332                                name: None,
333                            })
334                        }
335                        _ => None,
336                    })
337                    .collect::<Vec<_>>();
338                messages.extend(text_messages);
339
340                Ok(messages)
341            }
342            message::Message::Assistant { content, .. } => {
343                let mut messages: Vec<Message> = vec![];
344
345                // extract tool calls
346                let tool_calls = content
347                    .clone()
348                    .into_iter()
349                    .filter_map(|content| match content {
350                        message::AssistantContent::ToolCall(tool_call) => {
351                            Some(ToolCall::from(tool_call))
352                        }
353                        _ => None,
354                    })
355                    .collect::<Vec<_>>();
356
357                // if we have tool calls, we add a new Assistant message with them
358                if !tool_calls.is_empty() {
359                    messages.push(Message::Assistant {
360                        content: "".to_string(),
361                        name: None,
362                        tool_calls,
363                    });
364                }
365
366                // extract text
367                let text_content = content
368                    .into_iter()
369                    .filter_map(|content| match content {
370                        message::AssistantContent::Text(text) => Some(Message::Assistant {
371                            content: text.text,
372                            name: None,
373                            tool_calls: vec![],
374                        }),
375                        _ => None,
376                    })
377                    .collect::<Vec<_>>();
378
379                messages.extend(text_content);
380
381                Ok(messages)
382            }
383        }
384    }
385}
386
/// Tool call attached to an assistant message, in the DeepSeek wire format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of the call, echoed back in the matching `tool` message.
    pub id: String,
    /// Position of the call; required when deserializing (no serde default).
    pub index: usize,
    /// Kind of tool; falls back to [`ToolType::Function`] when the key is absent.
    #[serde(default)]
    pub r#type: ToolType,
    /// Function name and arguments of the call.
    pub function: Function,
}
395
/// Function invocation carried by a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    /// Call arguments. On the wire this is a JSON-encoded string (see the tool-call
    /// test in this file); `json_utils::stringified_json` converts it to and from a
    /// JSON value.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
402
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool; this is the only variant and the default.
    #[default]
    Function,
}
409
/// Entry of the `tools` array sent with a completion request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; the `From` conversion in this file always sets `"function"`.
    pub r#type: String,
    /// The Rig tool definition being wrapped.
    pub function: completion::ToolDefinition,
}
415
416impl From<crate::completion::ToolDefinition> for ToolDefinition {
417    fn from(tool: crate::completion::ToolDefinition) -> Self {
418        Self {
419            r#type: "function".into(),
420            function: tool,
421        }
422    }
423}
424
425impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
426    type Error = CompletionError;
427
428    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
429        let choice = response.choices.first().ok_or_else(|| {
430            CompletionError::ResponseError("Response contained no choices".to_owned())
431        })?;
432        let content = match &choice.message {
433            Message::Assistant {
434                content,
435                tool_calls,
436                ..
437            } => {
438                let mut content = if content.trim().is_empty() {
439                    vec![]
440                } else {
441                    vec![completion::AssistantContent::text(content)]
442                };
443
444                content.extend(
445                    tool_calls
446                        .iter()
447                        .map(|call| {
448                            completion::AssistantContent::tool_call(
449                                &call.id,
450                                &call.function.name,
451                                call.function.arguments.clone(),
452                            )
453                        })
454                        .collect::<Vec<_>>(),
455                );
456                Ok(content)
457            }
458            _ => Err(CompletionError::ResponseError(
459                "Response did not contain a valid message or tool call".into(),
460            )),
461        }?;
462
463        let choice = OneOrMany::many(content).map_err(|_| {
464            CompletionError::ResponseError(
465                "Response contained no message or tool call (empty)".to_owned(),
466            )
467        })?;
468
469        let usage = completion::Usage {
470            input_tokens: response.usage.prompt_tokens as u64,
471            output_tokens: response.usage.completion_tokens as u64,
472            total_tokens: response.usage.total_tokens as u64,
473        };
474
475        Ok(completion::CompletionResponse {
476            choice,
477            usage,
478            raw_response: response,
479        })
480    }
481}
482
/// The struct implementing the `CompletionModel` trait
///
/// Pairs a [`Client`] with the model name sent as `model` in every request.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    pub client: Client,
    /// Model identifier, e.g. [`DEEPSEEK_CHAT`].
    pub model: String,
}
489
490impl CompletionModel {
491    fn create_completion_request(
492        &self,
493        completion_request: CompletionRequest,
494    ) -> Result<serde_json::Value, CompletionError> {
495        // Build up the order of messages (context, chat_history, prompt)
496        let mut partial_history = vec![];
497
498        if let Some(docs) = completion_request.normalized_documents() {
499            partial_history.push(docs);
500        }
501
502        partial_history.extend(completion_request.chat_history);
503
504        // Initialize full history with preamble (or empty if non-existent)
505        let mut full_history: Vec<Message> = completion_request
506            .preamble
507            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
508
509        // Convert and extend the rest of the history
510        full_history.extend(
511            partial_history
512                .into_iter()
513                .map(message::Message::try_into)
514                .collect::<Result<Vec<Vec<Message>>, _>>()?
515                .into_iter()
516                .flatten()
517                .collect::<Vec<_>>(),
518        );
519
520        let request = if completion_request.tools.is_empty() {
521            json!({
522                "model": self.model,
523                "messages": full_history,
524                "temperature": completion_request.temperature,
525            })
526        } else {
527            json!({
528                "model": self.model,
529                "messages": full_history,
530                "temperature": completion_request.temperature,
531                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
532                "tool_choice": "auto",
533            })
534        };
535
536        let request = if let Some(params) = completion_request.additional_params {
537            json_utils::merge(request, params)
538        } else {
539            request
540        };
541
542        Ok(request)
543    }
544}
545
546impl completion::CompletionModel for CompletionModel {
547    type Response = CompletionResponse;
548    type StreamingResponse = StreamingCompletionResponse;
549
550    #[cfg_attr(feature = "worker", worker::send)]
551    async fn completion(
552        &self,
553        completion_request: CompletionRequest,
554    ) -> Result<
555        completion::CompletionResponse<CompletionResponse>,
556        crate::completion::CompletionError,
557    > {
558        let request = self.create_completion_request(completion_request)?;
559
560        tracing::debug!("DeepSeek completion request: {request:?}");
561
562        let response = self
563            .client
564            .post("/chat/completions")
565            .json(&request)
566            .send()
567            .await?;
568
569        if response.status().is_success() {
570            let t = response.text().await?;
571            tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
572
573            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
574                ApiResponse::Ok(response) => response.try_into(),
575                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
576            }
577        } else {
578            Err(CompletionError::ProviderError(response.text().await?))
579        }
580    }
581
582    #[cfg_attr(feature = "worker", worker::send)]
583    async fn stream(
584        &self,
585        completion_request: CompletionRequest,
586    ) -> Result<
587        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
588        CompletionError,
589    > {
590        let mut request = self.create_completion_request(completion_request)?;
591
592        request = merge(
593            request,
594            json!({"stream": true, "stream_options": {"include_usage": true}}),
595        );
596
597        let builder = self.client.post("/v1/chat/completions").json(&request);
598        send_compatible_streaming_request(builder).await
599    }
600}
601
/// Incremental update carried by one streamed choice.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    /// Text fragment of the answer, if this chunk carries one.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments; an absent key yields an empty vector (`default`), and
    /// presumably `null_or_vec` does the same for an explicit `null`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    /// Fragment of the model's reasoning, forwarded as a `Reasoning` stream item.
    reasoning_content: Option<String>,
}
610
/// One entry of the `choices` array in a streamed chunk; only `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    /// The incremental update for this choice.
    delta: StreamingDelta,
}
615
/// JSON payload of a single `data:` line in the streamed response.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    /// Streamed choices; the stream handler reads only the first.
    choices: Vec<StreamingChoice>,
    /// Usage block, when the chunk carries one; the last one seen is reported.
    usage: Option<Usage>,
}
621
/// Final item of a streamed completion, emitted after all content and tool calls.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Last usage block received on the stream, or all zeros if none arrived.
    pub usage: Usage,
}
626
627pub async fn send_compatible_streaming_request(
628    request_builder: reqwest::RequestBuilder,
629) -> Result<
630    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
631    CompletionError,
632> {
633    let response = request_builder.send().await?;
634
635    if !response.status().is_success() {
636        return Err(CompletionError::ProviderError(format!(
637            "{}: {}",
638            response.status(),
639            response.text().await?
640        )));
641    }
642
643    // Handle OpenAI Compatible SSE chunks
644    let inner = Box::pin(async_stream::stream! {
645        let mut stream = response.bytes_stream();
646
647        let mut final_usage = Usage::new();
648        let mut partial_data = None;
649        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
650
651        while let Some(chunk_result) = stream.next().await {
652            let chunk = match chunk_result {
653                Ok(c) => c,
654                Err(e) => {
655                    yield Err(CompletionError::from(e));
656                    break;
657                }
658            };
659
660            let text = match String::from_utf8(chunk.to_vec()) {
661                Ok(t) => t,
662                Err(e) => {
663                    yield Err(CompletionError::ResponseError(e.to_string()));
664                    break;
665                }
666            };
667
668
669            for line in text.lines() {
670                let mut line = line.to_string();
671
672                // If there was a remaining part, concat with current line
673                if partial_data.is_some() {
674                    line = format!("{}{}", partial_data.unwrap(), line);
675                    partial_data = None;
676                }
677                // Otherwise full data line
678                else {
679                    let Some(data) = line.strip_prefix("data:") else {
680                        continue;
681                    };
682
683                    let data = data.trim_start();
684
685                    // Partial data, split somewhere in the middle
686                    if !line.ends_with("}") {
687                        partial_data = Some(data.to_string());
688                    } else {
689                        line = data.to_string();
690                    }
691                }
692
693                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
694
695                let Ok(data) = data else {
696                    let err = data.unwrap_err();
697                    tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
698                    continue;
699                };
700
701
702                if let Some(choice) = data.choices.first() {
703                    let delta = &choice.delta;
704
705
706                            if !delta.tool_calls.is_empty() {
707                                for tool_call in &delta.tool_calls {
708                                    let function = tool_call.function.clone();
709                                    // Start of tool call
710                                    // name: Some(String)
711                                    // arguments: None
712                                    if function.name.is_some() && function.arguments.is_empty() {
713                                        let id = tool_call.id.clone().unwrap_or("".to_string());
714
715                                        calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
716                                    }
717                                    // Part of tool call
718                                    // name: None or Empty String
719                                    // arguments: Some(String)
720                                    else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
721                                        let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
722                                            tracing::debug!("Partial tool call received but tool call was never started.");
723                                            continue;
724                                        };
725
726                                        let new_arguments = &function.arguments;
727                                        let arguments = format!("{arguments}{new_arguments}");
728
729                                        calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
730                                    }
731                                    // Entire tool call
732                                    else {
733                                        let id = tool_call.id.clone().unwrap_or("".to_string());
734                                        let name = function.name.expect("function name should be present for complete tool call");
735                                        let arguments = function.arguments;
736                                        let Ok(arguments) = serde_json::from_str(&arguments) else {
737                                            tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
738                                            continue;
739                                        };
740
741                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
742                                    }
743                                }
744                            }
745
746                            if let Some(content) = &delta.reasoning_content {
747                                yield Ok(crate::streaming::RawStreamingChoice::Reasoning { reasoning: content.to_string()})
748                            }
749
750                            if let Some(content) = &delta.content {
751                                yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
752                            }
753
754                }
755
756
757                if let Some(usage) = data.usage {
758                    final_usage = usage.clone();
759                }
760            }
761        }
762
763        for (_, (id, name, arguments)) in calls {
764            let Ok(arguments) = serde_json::from_str(&arguments) else {
765                continue;
766            };
767
768            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
769        }
770
771        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
772            usage: final_usage.clone()
773        }))
774    });
775
776    Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
777}
778
779// ================================================================
780// DeepSeek Completion API
781// ================================================================
782
/// `deepseek-chat` completion model
///
/// Pass this to `Client::completion_model` to select the model.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// `deepseek-reasoner` completion model
///
/// Pass this to `Client::completion_model` to select the model.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
787
788// Tests
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare array of choices deserializes, and the assistant text is preserved.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);

        let Message::Assistant { content, .. } = &choices.first().unwrap().message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content, "Hello, world!");
    }

    /// A minimal response (choices plus the required usage counters) deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let response: CompletionResponse = serde_path_to_error::deserialize(jd)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err));

        let Message::Assistant { content, .. } = &response.choices.first().unwrap().message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content, "Hello, world!");
    }

    /// A full example response, including keys this module does not model, deserializes.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let response: CompletionResponse = serde_path_to_error::deserialize(jd)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err));

        let Message::Assistant { content, .. } = &response.choices.first().unwrap().message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content,
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    /// A choice carrying a tool call deserializes into the expected typed value,
    /// with the stringified arguments parsed into JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
        };
        let expected_choice = Choice {
            index: 0,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
            logprobs: None,
            finish_reason: "tool_calls".to_string(),
        };

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        assert_eq!(choice, expected_choice);
    }
}
937        assert_eq!(choice, expected_choice);
938    }
939}