rig/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::deepseek;
6//!
7//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
8//!
9//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
10//! ```
11
12use futures::StreamExt;
13use std::collections::HashMap;
14
15use crate::client::{
16    ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError,
17};
18use crate::completion::GetTokenUsage;
19use crate::json_utils::merge;
20use crate::message::Document;
21use crate::{
22    OneOrMany,
23    completion::{self, CompletionError, CompletionRequest},
24    impl_conversion_traits, json_utils, message,
25};
26use reqwest::Client as HttpClient;
27use serde::{Deserialize, Serialize};
28use serde_json::json;
29
30use super::openai::StreamingToolCall;
31
32// ================================================================
33// Main DeepSeek Client
34// ================================================================
35const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
36
37pub struct ClientBuilder<'a> {
38    api_key: &'a str,
39    base_url: &'a str,
40    http_client: Option<reqwest::Client>,
41}
42
43impl<'a> ClientBuilder<'a> {
44    pub fn new(api_key: &'a str) -> Self {
45        Self {
46            api_key,
47            base_url: DEEPSEEK_API_BASE_URL,
48            http_client: None,
49        }
50    }
51
52    pub fn base_url(mut self, base_url: &'a str) -> Self {
53        self.base_url = base_url;
54        self
55    }
56
57    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
58        self.http_client = Some(client);
59        self
60    }
61
62    pub fn build(self) -> Result<Client, ClientBuilderError> {
63        let http_client = if let Some(http_client) = self.http_client {
64            http_client
65        } else {
66            reqwest::Client::builder().build()?
67        };
68
69        Ok(Client {
70            base_url: self.base_url.to_string(),
71            api_key: self.api_key.to_string(),
72            http_client,
73        })
74    }
75}
76
77#[derive(Clone)]
78pub struct Client {
79    pub base_url: String,
80    api_key: String,
81    http_client: HttpClient,
82}
83
84impl std::fmt::Debug for Client {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        f.debug_struct("Client")
87            .field("base_url", &self.base_url)
88            .field("http_client", &self.http_client)
89            .field("api_key", &"<REDACTED>")
90            .finish()
91    }
92}
93
94impl Client {
95    /// Create a new DeepSeek client builder.
96    ///
97    /// # Example
98    /// ```
99    /// use rig::providers::deepseek::{ClientBuilder, self};
100    ///
101    /// // Initialize the DeepSeek client
102    /// let deepseek = Client::builder("your-deepseek-api-key")
103    ///    .build()
104    /// ```
105    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
106        ClientBuilder::new(api_key)
107    }
108
109    /// Create a new DeepSeek client. For more control, use the `builder` method.
110    ///
111    /// # Panics
112    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
113    pub fn new(api_key: &str) -> Self {
114        Self::builder(api_key)
115            .build()
116            .expect("DeepSeek client should build")
117    }
118
119    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
120        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
121        self.http_client.post(url).bearer_auth(&self.api_key)
122    }
123
124    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
125        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
126        self.http_client.get(url).bearer_auth(&self.api_key)
127    }
128}
129
130impl ProviderClient for Client {
131    // If you prefer the environment variable approach:
132    fn from_env() -> Self {
133        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
134        Self::new(&api_key)
135    }
136
137    fn from_val(input: crate::client::ProviderValue) -> Self {
138        let crate::client::ProviderValue::Simple(api_key) = input else {
139            panic!("Incorrect provider value type")
140        };
141        Self::new(&api_key)
142    }
143}
144
145impl CompletionClient for Client {
146    type CompletionModel = CompletionModel;
147
148    /// Creates a DeepSeek completion model with the given `model_name`.
149    fn completion_model(&self, model_name: &str) -> CompletionModel {
150        CompletionModel {
151            client: self.clone(),
152            model: model_name.to_string(),
153        }
154    }
155}
156
157impl VerifyClient for Client {
158    #[cfg_attr(feature = "worker", worker::send)]
159    async fn verify(&self) -> Result<(), VerifyError> {
160        let response = self.get("/user/balance").send().await?;
161        match response.status() {
162            reqwest::StatusCode::OK => Ok(()),
163            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
164            reqwest::StatusCode::INTERNAL_SERVER_ERROR
165            | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
166                Err(VerifyError::ProviderError(response.text().await?))
167            }
168            _ => {
169                response.error_for_status()?;
170                Ok(())
171            }
172        }
173    }
174}
175
176impl_conversion_traits!(
177    AsEmbeddings,
178    AsTranscription,
179    AsImageGeneration,
180    AsAudioGeneration for Client
181);
182
183#[derive(Debug, Deserialize)]
184struct ApiErrorResponse {
185    message: String,
186}
187
188#[derive(Debug, Deserialize)]
189#[serde(untagged)]
190enum ApiResponse<T> {
191    Ok(T),
192    Err(ApiErrorResponse),
193}
194
195impl From<ApiErrorResponse> for CompletionError {
196    fn from(err: ApiErrorResponse) -> Self {
197        CompletionError::ProviderError(err.message)
198    }
199}
200
201/// The response shape from the DeepSeek API
202#[derive(Clone, Debug, Serialize, Deserialize)]
203pub struct CompletionResponse {
204    // We'll match the JSON:
205    pub choices: Vec<Choice>,
206    pub usage: Usage,
207    // you may want other fields
208}
209
210#[derive(Clone, Debug, Serialize, Deserialize, Default)]
211pub struct Usage {
212    pub completion_tokens: u32,
213    pub prompt_tokens: u32,
214    pub prompt_cache_hit_tokens: u32,
215    pub prompt_cache_miss_tokens: u32,
216    pub total_tokens: u32,
217    #[serde(skip_serializing_if = "Option::is_none")]
218    pub completion_tokens_details: Option<CompletionTokensDetails>,
219    #[serde(skip_serializing_if = "Option::is_none")]
220    pub prompt_tokens_details: Option<PromptTokensDetails>,
221}
222
223impl Usage {
224    fn new() -> Self {
225        Self {
226            completion_tokens: 0,
227            prompt_tokens: 0,
228            prompt_cache_hit_tokens: 0,
229            prompt_cache_miss_tokens: 0,
230            total_tokens: 0,
231            completion_tokens_details: None,
232            prompt_tokens_details: None,
233        }
234    }
235}
236
237#[derive(Clone, Debug, Serialize, Deserialize, Default)]
238pub struct CompletionTokensDetails {
239    #[serde(skip_serializing_if = "Option::is_none")]
240    pub reasoning_tokens: Option<u32>,
241}
242
243#[derive(Clone, Debug, Serialize, Deserialize, Default)]
244pub struct PromptTokensDetails {
245    #[serde(skip_serializing_if = "Option::is_none")]
246    pub cached_tokens: Option<u32>,
247}
248
249#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
250pub struct Choice {
251    pub index: usize,
252    pub message: Message,
253    pub logprobs: Option<serde_json::Value>,
254    pub finish_reason: String,
255}
256
257#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
258#[serde(tag = "role", rename_all = "lowercase")]
259pub enum Message {
260    System {
261        content: String,
262        #[serde(skip_serializing_if = "Option::is_none")]
263        name: Option<String>,
264    },
265    User {
266        content: String,
267        #[serde(skip_serializing_if = "Option::is_none")]
268        name: Option<String>,
269    },
270    Assistant {
271        content: String,
272        #[serde(skip_serializing_if = "Option::is_none")]
273        name: Option<String>,
274        #[serde(
275            default,
276            deserialize_with = "json_utils::null_or_vec",
277            skip_serializing_if = "Vec::is_empty"
278        )]
279        tool_calls: Vec<ToolCall>,
280    },
281    #[serde(rename = "tool")]
282    ToolResult {
283        tool_call_id: String,
284        content: String,
285    },
286}
287
288impl Message {
289    pub fn system(content: &str) -> Self {
290        Message::System {
291            content: content.to_owned(),
292            name: None,
293        }
294    }
295}
296
297impl From<message::ToolResult> for Message {
298    fn from(tool_result: message::ToolResult) -> Self {
299        let content = match tool_result.content.first() {
300            message::ToolResultContent::Text(text) => text.text,
301            message::ToolResultContent::Image(_) => String::from("[Image]"),
302        };
303
304        Message::ToolResult {
305            tool_call_id: tool_result.id,
306            content,
307        }
308    }
309}
310
311impl From<message::ToolCall> for ToolCall {
312    fn from(tool_call: message::ToolCall) -> Self {
313        Self {
314            id: tool_call.id,
315            // TODO: update index when we have it
316            index: 0,
317            r#type: ToolType::Function,
318            function: Function {
319                name: tool_call.function.name,
320                arguments: tool_call.function.arguments,
321            },
322        }
323    }
324}
325
326impl TryFrom<message::Message> for Vec<Message> {
327    type Error = message::MessageError;
328
329    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
330        match message {
331            message::Message::User { content } => {
332                // extract tool results
333                let mut messages = vec![];
334
335                let tool_results = content
336                    .clone()
337                    .into_iter()
338                    .filter_map(|content| match content {
339                        message::UserContent::ToolResult(tool_result) => {
340                            Some(Message::from(tool_result))
341                        }
342                        _ => None,
343                    })
344                    .collect::<Vec<_>>();
345
346                messages.extend(tool_results);
347
348                // extract text results
349                let text_messages = content
350                    .into_iter()
351                    .filter_map(|content| match content {
352                        message::UserContent::Text(text) => Some(Message::User {
353                            content: text.text,
354                            name: None,
355                        }),
356                        message::UserContent::Document(Document { data, .. }) => {
357                            Some(Message::User {
358                                content: data,
359                                name: None,
360                            })
361                        }
362                        _ => None,
363                    })
364                    .collect::<Vec<_>>();
365                messages.extend(text_messages);
366
367                Ok(messages)
368            }
369            message::Message::Assistant { content, .. } => {
370                let mut messages: Vec<Message> = vec![];
371
372                // extract tool calls
373                let tool_calls = content
374                    .clone()
375                    .into_iter()
376                    .filter_map(|content| match content {
377                        message::AssistantContent::ToolCall(tool_call) => {
378                            Some(ToolCall::from(tool_call))
379                        }
380                        _ => None,
381                    })
382                    .collect::<Vec<_>>();
383
384                // if we have tool calls, we add a new Assistant message with them
385                if !tool_calls.is_empty() {
386                    messages.push(Message::Assistant {
387                        content: "".to_string(),
388                        name: None,
389                        tool_calls,
390                    });
391                }
392
393                // extract text
394                let text_content = content
395                    .into_iter()
396                    .filter_map(|content| match content {
397                        message::AssistantContent::Text(text) => Some(Message::Assistant {
398                            content: text.text,
399                            name: None,
400                            tool_calls: vec![],
401                        }),
402                        _ => None,
403                    })
404                    .collect::<Vec<_>>();
405
406                messages.extend(text_content);
407
408                Ok(messages)
409            }
410        }
411    }
412}
413
414#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
415pub struct ToolCall {
416    pub id: String,
417    pub index: usize,
418    #[serde(default)]
419    pub r#type: ToolType,
420    pub function: Function,
421}
422
423#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
424pub struct Function {
425    pub name: String,
426    #[serde(with = "json_utils::stringified_json")]
427    pub arguments: serde_json::Value,
428}
429
430#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
431#[serde(rename_all = "lowercase")]
432pub enum ToolType {
433    #[default]
434    Function,
435}
436
437#[derive(Clone, Debug, Deserialize, Serialize)]
438pub struct ToolDefinition {
439    pub r#type: String,
440    pub function: completion::ToolDefinition,
441}
442
443impl From<crate::completion::ToolDefinition> for ToolDefinition {
444    fn from(tool: crate::completion::ToolDefinition) -> Self {
445        Self {
446            r#type: "function".into(),
447            function: tool,
448        }
449    }
450}
451
452impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
453    type Error = CompletionError;
454
455    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
456        let choice = response.choices.first().ok_or_else(|| {
457            CompletionError::ResponseError("Response contained no choices".to_owned())
458        })?;
459        let content = match &choice.message {
460            Message::Assistant {
461                content,
462                tool_calls,
463                ..
464            } => {
465                let mut content = if content.trim().is_empty() {
466                    vec![]
467                } else {
468                    vec![completion::AssistantContent::text(content)]
469                };
470
471                content.extend(
472                    tool_calls
473                        .iter()
474                        .map(|call| {
475                            completion::AssistantContent::tool_call(
476                                &call.id,
477                                &call.function.name,
478                                call.function.arguments.clone(),
479                            )
480                        })
481                        .collect::<Vec<_>>(),
482                );
483                Ok(content)
484            }
485            _ => Err(CompletionError::ResponseError(
486                "Response did not contain a valid message or tool call".into(),
487            )),
488        }?;
489
490        let choice = OneOrMany::many(content).map_err(|_| {
491            CompletionError::ResponseError(
492                "Response contained no message or tool call (empty)".to_owned(),
493            )
494        })?;
495
496        let usage = completion::Usage {
497            input_tokens: response.usage.prompt_tokens as u64,
498            output_tokens: response.usage.completion_tokens as u64,
499            total_tokens: response.usage.total_tokens as u64,
500        };
501
502        Ok(completion::CompletionResponse {
503            choice,
504            usage,
505            raw_response: response,
506        })
507    }
508}
509
510/// The struct implementing the `CompletionModel` trait
511#[derive(Clone)]
512pub struct CompletionModel {
513    pub client: Client,
514    pub model: String,
515}
516
517impl CompletionModel {
518    fn create_completion_request(
519        &self,
520        completion_request: CompletionRequest,
521    ) -> Result<serde_json::Value, CompletionError> {
522        // Build up the order of messages (context, chat_history, prompt)
523        let mut partial_history = vec![];
524
525        if let Some(docs) = completion_request.normalized_documents() {
526            partial_history.push(docs);
527        }
528
529        partial_history.extend(completion_request.chat_history);
530
531        // Initialize full history with preamble (or empty if non-existent)
532        let mut full_history: Vec<Message> = completion_request
533            .preamble
534            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
535
536        // Convert and extend the rest of the history
537        full_history.extend(
538            partial_history
539                .into_iter()
540                .map(message::Message::try_into)
541                .collect::<Result<Vec<Vec<Message>>, _>>()?
542                .into_iter()
543                .flatten()
544                .collect::<Vec<_>>(),
545        );
546
547        let request = if completion_request.tools.is_empty() {
548            json!({
549                "model": self.model,
550                "messages": full_history,
551                "temperature": completion_request.temperature,
552            })
553        } else {
554            json!({
555                "model": self.model,
556                "messages": full_history,
557                "temperature": completion_request.temperature,
558                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
559                "tool_choice": "auto",
560            })
561        };
562
563        let request = if let Some(params) = completion_request.additional_params {
564            json_utils::merge(request, params)
565        } else {
566            request
567        };
568
569        Ok(request)
570    }
571}
572
573impl completion::CompletionModel for CompletionModel {
574    type Response = CompletionResponse;
575    type StreamingResponse = StreamingCompletionResponse;
576
577    #[cfg_attr(feature = "worker", worker::send)]
578    async fn completion(
579        &self,
580        completion_request: CompletionRequest,
581    ) -> Result<
582        completion::CompletionResponse<CompletionResponse>,
583        crate::completion::CompletionError,
584    > {
585        let request = self.create_completion_request(completion_request)?;
586
587        tracing::debug!("DeepSeek completion request: {request:?}");
588
589        let response = self
590            .client
591            .post("/chat/completions")
592            .json(&request)
593            .send()
594            .await?;
595
596        if response.status().is_success() {
597            let t = response.text().await?;
598            tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
599
600            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
601                ApiResponse::Ok(response) => response.try_into(),
602                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
603            }
604        } else {
605            Err(CompletionError::ProviderError(response.text().await?))
606        }
607    }
608
609    #[cfg_attr(feature = "worker", worker::send)]
610    async fn stream(
611        &self,
612        completion_request: CompletionRequest,
613    ) -> Result<
614        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
615        CompletionError,
616    > {
617        let mut request = self.create_completion_request(completion_request)?;
618
619        request = merge(
620            request,
621            json!({"stream": true, "stream_options": {"include_usage": true}}),
622        );
623
624        let builder = self.client.post("/chat/completions").json(&request);
625        send_compatible_streaming_request(builder).await
626    }
627}
628
629#[derive(Deserialize, Debug)]
630pub struct StreamingDelta {
631    #[serde(default)]
632    content: Option<String>,
633    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
634    tool_calls: Vec<StreamingToolCall>,
635    reasoning_content: Option<String>,
636}
637
638#[derive(Deserialize, Debug)]
639struct StreamingChoice {
640    delta: StreamingDelta,
641}
642
643#[derive(Deserialize, Debug)]
644struct StreamingCompletionChunk {
645    choices: Vec<StreamingChoice>,
646    usage: Option<Usage>,
647}
648
649#[derive(Clone, Deserialize, Serialize, Debug)]
650pub struct StreamingCompletionResponse {
651    pub usage: Usage,
652}
653
654impl GetTokenUsage for StreamingCompletionResponse {
655    fn token_usage(&self) -> Option<crate::completion::Usage> {
656        let mut usage = crate::completion::Usage::new();
657        usage.input_tokens = self.usage.prompt_tokens as u64;
658        usage.output_tokens = self.usage.completion_tokens as u64;
659        usage.total_tokens = self.usage.total_tokens as u64;
660
661        Some(usage)
662    }
663}
664
665pub async fn send_compatible_streaming_request(
666    request_builder: reqwest::RequestBuilder,
667) -> Result<
668    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
669    CompletionError,
670> {
671    let response = request_builder.send().await?;
672
673    if !response.status().is_success() {
674        return Err(CompletionError::ProviderError(format!(
675            "{}: {}",
676            response.status(),
677            response.text().await?
678        )));
679    }
680
681    // Handle OpenAI Compatible SSE chunks
682    let inner = Box::pin(async_stream::stream! {
683        let mut stream = response.bytes_stream();
684
685        let mut final_usage = Usage::new();
686        let mut partial_data = None;
687        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
688
689        while let Some(chunk_result) = stream.next().await {
690            let chunk = match chunk_result {
691                Ok(c) => c,
692                Err(e) => {
693                    yield Err(CompletionError::from(e));
694                    break;
695                }
696            };
697
698            let text = match String::from_utf8(chunk.to_vec()) {
699                Ok(t) => t,
700                Err(e) => {
701                    yield Err(CompletionError::ResponseError(e.to_string()));
702                    break;
703                }
704            };
705
706
707            for line in text.lines() {
708                let mut line = line.to_string();
709
710                // If there was a remaining part, concat with current line
711                if partial_data.is_some() {
712                    line = format!("{}{}", partial_data.unwrap(), line);
713                    partial_data = None;
714                }
715                // Otherwise full data line
716                else {
717                    let Some(data) = line.strip_prefix("data:") else {
718                        continue;
719                    };
720
721                    let data = data.trim_start();
722
723                    // Partial data, split somewhere in the middle
724                    if !line.ends_with("}") {
725                        partial_data = Some(data.to_string());
726                    } else {
727                        line = data.to_string();
728                    }
729                }
730
731                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
732
733                let Ok(data) = data else {
734                    let err = data.unwrap_err();
735                    tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
736                    continue;
737                };
738
739
740                if let Some(choice) = data.choices.first() {
741                    let delta = &choice.delta;
742
743
744                            if !delta.tool_calls.is_empty() {
745                                for tool_call in &delta.tool_calls {
746                                    let function = tool_call.function.clone();
747                                    // Start of tool call
748                                    // name: Some(String)
749                                    // arguments: None
750                                    if function.name.is_some() && function.arguments.is_empty() {
751                                        let id = tool_call.id.clone().unwrap_or("".to_string());
752
753                                        calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
754                                    }
755                                    // Part of tool call
756                                    // name: None or Empty String
757                                    // arguments: Some(String)
758                                    else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
759                                        let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
760                                            tracing::debug!("Partial tool call received but tool call was never started.");
761                                            continue;
762                                        };
763
764                                        let new_arguments = &function.arguments;
765                                        let arguments = format!("{arguments}{new_arguments}");
766
767                                        calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
768                                    }
769                                    // Entire tool call
770                                    else {
771                                        let id = tool_call.id.clone().unwrap_or("".to_string());
772                                        let name = function.name.expect("function name should be present for complete tool call");
773                                        let arguments = function.arguments;
774                                        let Ok(arguments) = serde_json::from_str(&arguments) else {
775                                            tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
776                                            continue;
777                                        };
778
779                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
780                                    }
781                                }
782                            }
783
784                            if let Some(content) = &delta.reasoning_content {
785                                yield Ok(crate::streaming::RawStreamingChoice::Reasoning { reasoning: content.to_string(), id: None})
786                            }
787
788                            if let Some(content) = &delta.content {
789                                yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
790                            }
791
792                }
793
794
795                if let Some(usage) = data.usage {
796                    final_usage = usage.clone();
797                }
798            }
799        }
800
801        for (_, (id, name, arguments)) in calls {
802            let Ok(arguments) = serde_json::from_str(&arguments) else {
803                continue;
804            };
805
806            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
807        }
808
809        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
810            usage: final_usage.clone()
811        }))
812    });
813
814    Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
815}
816
817// ================================================================
818// DeepSeek Completion API
819// ================================================================
820
/// Model name for DeepSeek's `deepseek-chat` completion model (pass to `completion_model`).
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name for DeepSeek's `deepseek-reasoner` completion model (pass to `completion_model`).
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
825
826// Tests
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of an assistant message; panics on any other role.
    fn assistant_content(message: &Message) -> &str {
        let Message::Assistant { content, .. } = message else {
            panic!("Expected assistant message");
        };
        content
    }

    /// Deserializes a `CompletionResponse`, reporting the JSON path of any failure.
    fn parse_response(data: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(data);
        match serde_path_to_error::deserialize(&mut deserializer) {
            Ok(response) => response,
            Err(err) => panic!("Deserialization error at {}: {}", err.path(), err),
        }
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();

        assert_eq!(choices.len(), 1);
        assert_eq!(assistant_content(&choices[0].message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let response = parse_response(
            r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#,
        );

        assert_eq!(
            assistant_content(&response.choices[0].message),
            "Hello, world!"
        );
    }

    #[test]
    fn test_deserialize_example_response() {
        let response = parse_response(
            r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#,
        );

        assert_eq!(
            assistant_content(&response.choices[0].message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
        };
        let expected_choice = Choice {
            index: 0,
            finish_reason: "tool_calls".to_string(),
            logprobs: None,
            message: Message::Assistant {
                content: String::new(),
                name: None,
                tool_calls: vec![expected_call],
            },
        };

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
        assert_eq!(choice, expected_choice);
    }
}