// rig/providers/deepseek.rs

//! DeepSeek API client and Rig integration.
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::deepseek;
//!
//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```
11
12use async_stream::stream;
13use futures::StreamExt;
14use reqwest_eventsource::{Event, RequestBuilderExt};
15use std::collections::HashMap;
16
17use crate::client::{
18    ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError,
19};
20use crate::completion::GetTokenUsage;
21use crate::json_utils::merge;
22use crate::message::Document;
23use crate::{
24    OneOrMany,
25    completion::{self, CompletionError, CompletionRequest},
26    impl_conversion_traits, json_utils, message,
27};
28use reqwest::Client as HttpClient;
29use serde::{Deserialize, Serialize};
30use serde_json::json;
31
32use super::openai::StreamingToolCall;
33
34// ================================================================
35// Main DeepSeek Client
36// ================================================================
/// Default base URL of the DeepSeek HTTP API; can be overridden via [`ClientBuilder::base_url`].
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
38
/// Builder for a DeepSeek [`Client`].
///
/// Starts from the public DeepSeek base URL and no custom HTTP client; both can be
/// overridden before calling [`ClientBuilder::build`].
pub struct ClientBuilder<'a> {
    /// API key sent as a bearer token on every request.
    api_key: &'a str,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// Caller-supplied HTTP client; `None` means `build` creates a default one.
    http_client: Option<reqwest::Client>,
}
44
45impl<'a> ClientBuilder<'a> {
46    pub fn new(api_key: &'a str) -> Self {
47        Self {
48            api_key,
49            base_url: DEEPSEEK_API_BASE_URL,
50            http_client: None,
51        }
52    }
53
54    pub fn base_url(mut self, base_url: &'a str) -> Self {
55        self.base_url = base_url;
56        self
57    }
58
59    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
60        self.http_client = Some(client);
61        self
62    }
63
64    pub fn build(self) -> Result<Client, ClientBuilderError> {
65        let http_client = if let Some(http_client) = self.http_client {
66            http_client
67        } else {
68            reqwest::Client::builder().build()?
69        };
70
71        Ok(Client {
72            base_url: self.base_url.to_string(),
73            api_key: self.api_key.to_string(),
74            http_client,
75        })
76    }
77}
78
/// DeepSeek API client: base URL, API key and the HTTP client used for every request.
///
/// Construct it with [`Client::new`] or [`Client::builder`].
#[derive(Clone)]
pub struct Client {
    /// Base URL that API paths are appended to.
    pub base_url: String,
    /// Secret API key, sent as a bearer token; redacted in the `Debug` output.
    api_key: String,
    /// Underlying `reqwest` client.
    http_client: HttpClient,
}
85
86impl std::fmt::Debug for Client {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct("Client")
89            .field("base_url", &self.base_url)
90            .field("http_client", &self.http_client)
91            .field("api_key", &"<REDACTED>")
92            .finish()
93    }
94}
95
96impl Client {
97    /// Create a new DeepSeek client builder.
98    ///
99    /// # Example
100    /// ```
101    /// use rig::providers::deepseek::{ClientBuilder, self};
102    ///
103    /// // Initialize the DeepSeek client
104    /// let deepseek = Client::builder("your-deepseek-api-key")
105    ///    .build()
106    /// ```
107    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
108        ClientBuilder::new(api_key)
109    }
110
111    /// Create a new DeepSeek client. For more control, use the `builder` method.
112    ///
113    /// # Panics
114    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
115    pub fn new(api_key: &str) -> Self {
116        Self::builder(api_key)
117            .build()
118            .expect("DeepSeek client should build")
119    }
120
121    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
122        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
123        self.http_client.post(url).bearer_auth(&self.api_key)
124    }
125
126    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
127        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
128        self.http_client.get(url).bearer_auth(&self.api_key)
129    }
130}
131
132impl ProviderClient for Client {
133    // If you prefer the environment variable approach:
134    fn from_env() -> Self {
135        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
136        Self::new(&api_key)
137    }
138
139    fn from_val(input: crate::client::ProviderValue) -> Self {
140        let crate::client::ProviderValue::Simple(api_key) = input else {
141            panic!("Incorrect provider value type")
142        };
143        Self::new(&api_key)
144    }
145}
146
147impl CompletionClient for Client {
148    type CompletionModel = CompletionModel;
149
150    /// Creates a DeepSeek completion model with the given `model_name`.
151    fn completion_model(&self, model_name: &str) -> CompletionModel {
152        CompletionModel {
153            client: self.clone(),
154            model: model_name.to_string(),
155        }
156    }
157}
158
159impl VerifyClient for Client {
160    #[cfg_attr(feature = "worker", worker::send)]
161    async fn verify(&self) -> Result<(), VerifyError> {
162        let response = self.get("/user/balance").send().await?;
163        match response.status() {
164            reqwest::StatusCode::OK => Ok(()),
165            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
166            reqwest::StatusCode::INTERNAL_SERVER_ERROR
167            | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
168                Err(VerifyError::ProviderError(response.text().await?))
169            }
170            _ => {
171                response.error_for_status()?;
172                Ok(())
173            }
174        }
175    }
176}
177
// This provider only implements chat completions. The shared macro supplies the embedding,
// transcription, image- and audio-generation conversion traits for `Client`. Presumably they
// are "unsupported" stubs; confirm against `impl_conversion_traits!`.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
184
/// Error payload the API may return instead of a successful response body.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error message.
    message: String,
}
189
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// `untagged`: serde tries `Ok(T)` first and only falls back to `Err` when the JSON does
/// not match `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
196
197impl From<ApiErrorResponse> for CompletionError {
198    fn from(err: ApiErrorResponse) -> Self {
199        CompletionError::ProviderError(err.message)
200    }
201}
202
/// The response shape from the DeepSeek chat-completions API.
///
/// Only the fields this module reads are modelled. Other top-level fields (e.g. `id`,
/// `model`, `system_fingerprint`) are ignored when deserializing.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Generated alternatives; the conversion into a rig response uses the first one.
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
211
212#[derive(Clone, Debug, Serialize, Deserialize, Default)]
213pub struct Usage {
214    pub completion_tokens: u32,
215    pub prompt_tokens: u32,
216    pub prompt_cache_hit_tokens: u32,
217    pub prompt_cache_miss_tokens: u32,
218    pub total_tokens: u32,
219    #[serde(skip_serializing_if = "Option::is_none")]
220    pub completion_tokens_details: Option<CompletionTokensDetails>,
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub prompt_tokens_details: Option<PromptTokensDetails>,
223}
224
225impl Usage {
226    fn new() -> Self {
227        Self {
228            completion_tokens: 0,
229            prompt_tokens: 0,
230            prompt_cache_hit_tokens: 0,
231            prompt_cache_miss_tokens: 0,
232            total_tokens: 0,
233            completion_tokens_details: None,
234            prompt_tokens_details: None,
235        }
236    }
237}
238
/// Optional breakdown of completion tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Tokens spent on reasoning, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
244
/// Optional breakdown of prompt tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// Prompt tokens served from cache, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
250
/// One completion alternative inside a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Index of this choice as reported by the API.
    pub index: usize,
    /// The generated message.
    pub message: Message,
    /// Log-probability data kept as raw JSON; may be `null`.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped, e.g. `"stop"` or `"tool_calls"`.
    pub finish_reason: String,
}
258
/// A chat message in DeepSeek's wire format, discriminated by the JSON `role` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "system"`: instructions/preamble.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "user"`: user input (text or document contents).
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "assistant"`: model output, optionally with tool calls.
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Missing/`null` tool_calls deserialize to an empty list and are omitted when empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// `role: "tool"`: the result of a tool call, linked by `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
289
290impl Message {
291    pub fn system(content: &str) -> Self {
292        Message::System {
293            content: content.to_owned(),
294            name: None,
295        }
296    }
297}
298
299impl From<message::ToolResult> for Message {
300    fn from(tool_result: message::ToolResult) -> Self {
301        let content = match tool_result.content.first() {
302            message::ToolResultContent::Text(text) => text.text,
303            message::ToolResultContent::Image(_) => String::from("[Image]"),
304        };
305
306        Message::ToolResult {
307            tool_call_id: tool_result.id,
308            content,
309        }
310    }
311}
312
313impl From<message::ToolCall> for ToolCall {
314    fn from(tool_call: message::ToolCall) -> Self {
315        Self {
316            id: tool_call.id,
317            // TODO: update index when we have it
318            index: 0,
319            r#type: ToolType::Function,
320            function: Function {
321                name: tool_call.function.name,
322                arguments: tool_call.function.arguments,
323            },
324        }
325    }
326}
327
328impl TryFrom<message::Message> for Vec<Message> {
329    type Error = message::MessageError;
330
331    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
332        match message {
333            message::Message::User { content } => {
334                // extract tool results
335                let mut messages = vec![];
336
337                let tool_results = content
338                    .clone()
339                    .into_iter()
340                    .filter_map(|content| match content {
341                        message::UserContent::ToolResult(tool_result) => {
342                            Some(Message::from(tool_result))
343                        }
344                        _ => None,
345                    })
346                    .collect::<Vec<_>>();
347
348                messages.extend(tool_results);
349
350                // extract text results
351                let text_messages = content
352                    .into_iter()
353                    .filter_map(|content| match content {
354                        message::UserContent::Text(text) => Some(Message::User {
355                            content: text.text,
356                            name: None,
357                        }),
358                        message::UserContent::Document(Document { data, .. }) => {
359                            Some(Message::User {
360                                content: data,
361                                name: None,
362                            })
363                        }
364                        _ => None,
365                    })
366                    .collect::<Vec<_>>();
367                messages.extend(text_messages);
368
369                Ok(messages)
370            }
371            message::Message::Assistant { content, .. } => {
372                let mut messages: Vec<Message> = vec![];
373
374                // extract tool calls
375                let tool_calls = content
376                    .clone()
377                    .into_iter()
378                    .filter_map(|content| match content {
379                        message::AssistantContent::ToolCall(tool_call) => {
380                            Some(ToolCall::from(tool_call))
381                        }
382                        _ => None,
383                    })
384                    .collect::<Vec<_>>();
385
386                // if we have tool calls, we add a new Assistant message with them
387                if !tool_calls.is_empty() {
388                    messages.push(Message::Assistant {
389                        content: "".to_string(),
390                        name: None,
391                        tool_calls,
392                    });
393                }
394
395                // extract text
396                let text_content = content
397                    .into_iter()
398                    .filter_map(|content| match content {
399                        message::AssistantContent::Text(text) => Some(Message::Assistant {
400                            content: text.text,
401                            name: None,
402                            tool_calls: vec![],
403                        }),
404                        _ => None,
405                    })
406                    .collect::<Vec<_>>();
407
408                messages.extend(text_content);
409
410                Ok(messages)
411            }
412        }
413    }
414}
415
/// A tool call requested by the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier that the matching `tool` message refers to via `tool_call_id`.
    pub id: String,
    /// Index of the call as reported by the API.
    pub index: usize,
    /// Tool kind; defaults to `function` when absent.
    #[serde(default)]
    pub r#type: ToolType,
    /// Function name and arguments.
    pub function: Function,
}
424
/// The function invoked by a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the tool/function.
    pub name: String,
    /// Arguments as JSON. On the wire they are a JSON-encoded string
    /// (e.g. `"{\"x\":2}"`), hence `stringified_json`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
431
/// Kind of tool; serialized lowercase. Only `function` exists.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
438
/// A tool advertised to the model in the request's `tools` array.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; always `"function"` when built via `From<completion::ToolDefinition>`.
    pub r#type: String,
    /// The rig tool definition (name, description, parameters).
    pub function: completion::ToolDefinition,
}
444
445impl From<crate::completion::ToolDefinition> for ToolDefinition {
446    fn from(tool: crate::completion::ToolDefinition) -> Self {
447        Self {
448            r#type: "function".into(),
449            function: tool,
450        }
451    }
452}
453
454impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
455    type Error = CompletionError;
456
457    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
458        let choice = response.choices.first().ok_or_else(|| {
459            CompletionError::ResponseError("Response contained no choices".to_owned())
460        })?;
461        let content = match &choice.message {
462            Message::Assistant {
463                content,
464                tool_calls,
465                ..
466            } => {
467                let mut content = if content.trim().is_empty() {
468                    vec![]
469                } else {
470                    vec![completion::AssistantContent::text(content)]
471                };
472
473                content.extend(
474                    tool_calls
475                        .iter()
476                        .map(|call| {
477                            completion::AssistantContent::tool_call(
478                                &call.id,
479                                &call.function.name,
480                                call.function.arguments.clone(),
481                            )
482                        })
483                        .collect::<Vec<_>>(),
484                );
485                Ok(content)
486            }
487            _ => Err(CompletionError::ResponseError(
488                "Response did not contain a valid message or tool call".into(),
489            )),
490        }?;
491
492        let choice = OneOrMany::many(content).map_err(|_| {
493            CompletionError::ResponseError(
494                "Response contained no message or tool call (empty)".to_owned(),
495            )
496        })?;
497
498        let usage = completion::Usage {
499            input_tokens: response.usage.prompt_tokens as u64,
500            output_tokens: response.usage.completion_tokens as u64,
501            total_tokens: response.usage.total_tokens as u64,
502        };
503
504        Ok(completion::CompletionResponse {
505            choice,
506            usage,
507            raw_response: response,
508        })
509    }
510}
511
/// The struct implementing the `CompletionModel` trait.
///
/// Pairs a [`Client`] with the model name sent in each request (e.g. [`DEEPSEEK_CHAT`]).
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    pub client: Client,
    /// Model identifier placed in the request's `model` field.
    pub model: String,
}
518
519impl CompletionModel {
520    fn create_completion_request(
521        &self,
522        completion_request: CompletionRequest,
523    ) -> Result<serde_json::Value, CompletionError> {
524        // Build up the order of messages (context, chat_history, prompt)
525        let mut partial_history = vec![];
526
527        if let Some(docs) = completion_request.normalized_documents() {
528            partial_history.push(docs);
529        }
530
531        partial_history.extend(completion_request.chat_history);
532
533        // Initialize full history with preamble (or empty if non-existent)
534        let mut full_history: Vec<Message> = completion_request
535            .preamble
536            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
537
538        // Convert and extend the rest of the history
539        full_history.extend(
540            partial_history
541                .into_iter()
542                .map(message::Message::try_into)
543                .collect::<Result<Vec<Vec<Message>>, _>>()?
544                .into_iter()
545                .flatten()
546                .collect::<Vec<_>>(),
547        );
548
549        let request = if completion_request.tools.is_empty() {
550            json!({
551                "model": self.model,
552                "messages": full_history,
553                "temperature": completion_request.temperature,
554            })
555        } else {
556            json!({
557                "model": self.model,
558                "messages": full_history,
559                "temperature": completion_request.temperature,
560                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
561                "tool_choice": "auto",
562            })
563        };
564
565        let request = if let Some(params) = completion_request.additional_params {
566            json_utils::merge(request, params)
567        } else {
568            request
569        };
570
571        Ok(request)
572    }
573}
574
575impl completion::CompletionModel for CompletionModel {
576    type Response = CompletionResponse;
577    type StreamingResponse = StreamingCompletionResponse;
578
579    #[cfg_attr(feature = "worker", worker::send)]
580    async fn completion(
581        &self,
582        completion_request: CompletionRequest,
583    ) -> Result<
584        completion::CompletionResponse<CompletionResponse>,
585        crate::completion::CompletionError,
586    > {
587        let request = self.create_completion_request(completion_request)?;
588
589        tracing::debug!("DeepSeek completion request: {request:?}");
590
591        let response = self
592            .client
593            .post("/chat/completions")
594            .json(&request)
595            .send()
596            .await?;
597
598        if response.status().is_success() {
599            let t = response.text().await?;
600            tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
601
602            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
603                ApiResponse::Ok(response) => response.try_into(),
604                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
605            }
606        } else {
607            Err(CompletionError::ProviderError(response.text().await?))
608        }
609    }
610
611    #[cfg_attr(feature = "worker", worker::send)]
612    async fn stream(
613        &self,
614        completion_request: CompletionRequest,
615    ) -> Result<
616        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
617        CompletionError,
618    > {
619        let mut request = self.create_completion_request(completion_request)?;
620
621        request = merge(
622            request,
623            json!({"stream": true, "stream_options": {"include_usage": true}}),
624        );
625
626        let builder = self.client.post("/chat/completions").json(&request);
627        send_compatible_streaming_request(builder).await
628    }
629}
630
/// The `delta` object of one streamed choice.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    /// Incremental answer text, if any.
    #[serde(default)]
    content: Option<String>,
    /// Incremental tool-call fragments; `null_or_vec` presumably maps a JSON `null` to an
    /// empty list (TODO confirm in `json_utils`).
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    /// DeepSeek-specific incremental reasoning text, forwarded as reasoning chunks.
    reasoning_content: Option<String>,
}
639
/// One entry of a streamed chunk's `choices` array.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
644
/// A single SSE `data:` payload of a streaming completion.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    /// Streamed choices; only the first one is consumed.
    choices: Vec<StreamingChoice>,
    /// Usage totals; only present on chunks that carry them (requested via `include_usage`).
    usage: Option<Usage>,
}
650
/// Final item of a streaming completion: the last usage totals seen in the stream.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
655
656impl GetTokenUsage for StreamingCompletionResponse {
657    fn token_usage(&self) -> Option<crate::completion::Usage> {
658        let mut usage = crate::completion::Usage::new();
659        usage.input_tokens = self.usage.prompt_tokens as u64;
660        usage.output_tokens = self.usage.completion_tokens as u64;
661        usage.total_tokens = self.usage.total_tokens as u64;
662
663        Some(usage)
664    }
665}
666
667pub async fn send_compatible_streaming_request(
668    request_builder: reqwest::RequestBuilder,
669) -> Result<
670    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
671    CompletionError,
672> {
673    let mut event_source = request_builder
674        .eventsource()
675        .expect("Cloning request must succeed");
676
677    let stream = Box::pin(stream! {
678        let mut final_usage = Usage::new();
679        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
680
681        while let Some(event_result) = event_source.next().await {
682            match event_result {
683                Ok(Event::Open) => {
684                    tracing::trace!("SSE connection opened");
685                    continue;
686                }
687                Ok(Event::Message(message)) => {
688                    if message.data.trim().is_empty() || message.data == "[DONE]" {
689                        continue;
690                    }
691
692                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
693                    let Ok(data) = parsed else {
694                        let err = parsed.unwrap_err();
695                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
696                        continue;
697                    };
698
699                    if let Some(choice) = data.choices.first() {
700                        let delta = &choice.delta;
701
702                        if !delta.tool_calls.is_empty() {
703                            for tool_call in &delta.tool_calls {
704                                let function = &tool_call.function;
705
706                                // Start of tool call
707                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
708                                    && function.arguments.is_empty()
709                                {
710                                    let id = tool_call.id.clone().unwrap_or_default();
711                                    let name = function.name.clone().unwrap();
712                                    calls.insert(tool_call.index, (id, name, String::new()));
713                                }
714                                // Continuation of tool call
715                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
716                                    && !function.arguments.is_empty()
717                                {
718                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
719                                        let combined = format!("{}{}", existing_args, function.arguments);
720                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
721                                    } else {
722                                        tracing::debug!("Partial tool call received but tool call was never started.");
723                                    }
724                                }
725                                // Complete tool call
726                                else {
727                                    let id = tool_call.id.clone().unwrap_or_default();
728                                    let name = function.name.clone().unwrap_or_default();
729                                    let arguments_str = function.arguments.clone();
730
731                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
732                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
733                                        continue;
734                                    };
735
736                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
737                                        id,
738                                        name,
739                                        arguments: arguments_json,
740                                        call_id: None,
741                                    });
742                                }
743                            }
744                        }
745
746                        // DeepSeek-specific reasoning stream
747                        if let Some(content) = &delta.reasoning_content {
748                            yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
749                                reasoning: content.to_string(),
750                                id: None,
751                            });
752                        }
753
754                        if let Some(content) = &delta.content {
755                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
756                        }
757                    }
758
759                    if let Some(usage) = data.usage {
760                        final_usage = usage.clone();
761                    }
762                }
763                Err(reqwest_eventsource::Error::StreamEnded) => {
764                    break;
765                }
766                Err(err) => {
767                    tracing::error!(?err, "SSE error");
768                    yield Err(CompletionError::ResponseError(err.to_string()));
769                    break;
770                }
771            }
772        }
773
774        // Flush accumulated tool calls
775        for (_, (id, name, arguments)) in calls {
776            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
777                continue;
778            };
779            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
780                id,
781                name,
782                arguments: arguments_json,
783                call_id: None,
784            });
785        }
786
787        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
788            StreamingCompletionResponse { usage: final_usage.clone() }
789        ));
790    });
791
792    Ok(crate::streaming::StreamingCompletionResponse::stream(
793        stream,
794    ))
795}
796
797// ================================================================
798// DeepSeek Completion API
799// ================================================================
800
/// `deepseek-chat` completion model.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// `deepseek-reasoner` completion model. When streaming, any `reasoning_content` deltas
/// it sends are forwarded as reasoning chunks.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
805
// Tests: deserialization of DeepSeek response payloads into this module's wire types.
#[cfg(test)]
mod tests {

    use super::*;

    /// A bare `choices` array with a plain-text assistant message deserializes into `Choice`s.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    /// A minimal full response (choices + usage with cache counters) deserializes;
    /// `serde_path_to_error` pinpoints the failing field if it does not.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    /// A realistic API response with extra top-level fields (`id`, `object`, `model`,
    /// `system_fingerprint`) and `prompt_tokens_details` still deserializes, and non-ASCII
    /// content round-trips unchanged.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    /// A tool-call choice deserializes with its stringified `arguments` decoded into JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}