Skip to main content

rig/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::deepseek;
6//!
7//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
8//!
9//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
10//! ```
11
12use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
37// ================================================================
38// Main DeepSeek Client
39// ================================================================
/// Root URL of the DeepSeek HTTP API; request paths such as `/chat/completions`
/// are issued relative to the provider's base URL.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
/// Marker type that identifies the DeepSeek provider. It is a unit struct and
/// carries no state of its own.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExt;
/// Builder-side marker for the DeepSeek provider. It is a unit struct and
/// carries no state of its own.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExtBuilder;
46
/// API key type used by the DeepSeek client (an alias for [`BearerAuth`]).
type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50    type Builder = DeepSeekExtBuilder;
51
52    const VERIFY_PATH: &'static str = "/user/balance";
53
54    fn build<H>(
55        _: &crate::client::ClientBuilder<
56            Self::Builder,
57            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58            H,
59        >,
60    ) -> http_client::Result<Self> {
61        Ok(Self)
62    }
63}
64
/// Capability table for the DeepSeek provider: only completions are marked
/// `Capable` (through [`CompletionModel`]); every other capability is `Nothing`.
impl<H> Capabilities<H> for DeepSeekExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // The two capabilities below only exist when the matching crate feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
75
// Empty impl: nothing from `DebugExt` is overridden for DeepSeek.
impl DebugExt for DeepSeekExt {}
77
/// Ties the builder marker to the provider it produces, the API key type it
/// accepts, and the base URL requests are sent to.
impl ProviderBuilder for DeepSeekExtBuilder {
    type Output = DeepSeekExt;
    type ApiKey = DeepSeekApiKey;

    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
}
84
/// DeepSeek client, generic over the HTTP backend `H` (defaults to `reqwest::Client`).
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
/// Builder for [`Client`], generic over the HTTP backend `H` (defaults to `reqwest::Client`).
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
87
88impl ProviderClient for Client {
89    type Input = DeepSeekApiKey;
90
91    // If you prefer the environment variable approach:
92    fn from_env() -> Self {
93        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
94        let mut client_builder = Self::builder();
95        client_builder.headers_mut().insert(
96            http::header::CONTENT_TYPE,
97            http::HeaderValue::from_static("application/json"),
98        );
99        let client_builder = client_builder.api_key(&api_key);
100        client_builder.build().unwrap()
101    }
102
103    fn from_val(input: Self::Input) -> Self {
104        Self::new(input).unwrap()
105    }
106}
107
/// Error payload deserialized from a response body; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
112
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde attempts the variants in declaration
/// order: `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
119
120impl From<ApiErrorResponse> for CompletionError {
121    fn from(err: ApiErrorResponse) -> Self {
122        CompletionError::ProviderError(err.message)
123    }
124}
125
/// The response shape from the DeepSeek API.
///
/// Only the fields declared here are modelled.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Generated choices; the conversion into Rig's response type reads only the first one.
    pub choices: Vec<Choice>,
    /// Token accounting reported alongside the choices.
    pub usage: Usage,
    // Further response fields can be added here if they become needed.
}
134
/// Token accounting block of a DeepSeek response.
///
/// The five counters are required when deserializing; the two `*_details`
/// members are optional and are omitted from serialized output when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    pub completion_tokens: u32,
    pub prompt_tokens: u32,
    // NOTE(review): presumably prompt tokens served from / missed by DeepSeek's
    // context cache — confirm against the API documentation.
    pub prompt_cache_hit_tokens: u32,
    pub prompt_cache_miss_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
147
148impl Usage {
149    fn new() -> Self {
150        Self {
151            completion_tokens: 0,
152            prompt_tokens: 0,
153            prompt_cache_hit_tokens: 0,
154            prompt_cache_miss_tokens: 0,
155            total_tokens: 0,
156            completion_tokens_details: None,
157            prompt_tokens_details: None,
158        }
159    }
160}
161
/// Optional breakdown of completion tokens; the field is omitted from
/// serialized output when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
167
/// Optional breakdown of prompt tokens; `cached_tokens` is what this module
/// reports as Rig's `cached_input_tokens`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
173
/// One entry of a response's `choices` array.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    pub index: usize,
    /// The generated message; the response conversion expects the assistant variant.
    pub message: Message,
    /// Kept as raw JSON because this module does not interpret it.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
181
/// A chat message in DeepSeek's wire format.
///
/// The variant is selected by the JSON `role` field: `system`, `user`,
/// `assistant`, or `tool` (the explicit rename on `ToolResult`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Missing field -> empty vec (`default`); deserialization goes through
        // `json_utils::null_or_vec`; an empty vec is omitted when serializing.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        /// Only exists on the `deepseek-reasoner` model at the time this field was added.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
215
216impl Message {
217    pub fn system(content: &str) -> Self {
218        Message::System {
219            content: content.to_owned(),
220            name: None,
221        }
222    }
223}
224
225impl From<message::ToolResult> for Message {
226    fn from(tool_result: message::ToolResult) -> Self {
227        let content = match tool_result.content.first() {
228            message::ToolResultContent::Text(text) => text.text,
229            message::ToolResultContent::Image(_) => String::from("[Image]"),
230        };
231
232        Message::ToolResult {
233            tool_call_id: tool_result.id,
234            content,
235        }
236    }
237}
238
239impl From<message::ToolCall> for ToolCall {
240    fn from(tool_call: message::ToolCall) -> Self {
241        Self {
242            id: tool_call.id,
243            // TODO: update index when we have it
244            index: 0,
245            r#type: ToolType::Function,
246            function: Function {
247                name: tool_call.function.name,
248                arguments: tool_call.function.arguments,
249            },
250        }
251    }
252}
253
254impl TryFrom<message::Message> for Vec<Message> {
255    type Error = message::MessageError;
256
257    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
258        match message {
259            message::Message::User { content } => {
260                // extract tool results
261                let mut messages = vec![];
262
263                let tool_results = content
264                    .clone()
265                    .into_iter()
266                    .filter_map(|content| match content {
267                        message::UserContent::ToolResult(tool_result) => {
268                            Some(Message::from(tool_result))
269                        }
270                        _ => None,
271                    })
272                    .collect::<Vec<_>>();
273
274                messages.extend(tool_results);
275
276                // extract text results
277                let text_messages = content
278                    .into_iter()
279                    .filter_map(|content| match content {
280                        message::UserContent::Text(text) => Some(Message::User {
281                            content: text.text,
282                            name: None,
283                        }),
284                        message::UserContent::Document(Document {
285                            data:
286                                DocumentSourceKind::Base64(content)
287                                | DocumentSourceKind::String(content),
288                            ..
289                        }) => Some(Message::User {
290                            content,
291                            name: None,
292                        }),
293                        _ => None,
294                    })
295                    .collect::<Vec<_>>();
296                messages.extend(text_messages);
297
298                Ok(messages)
299            }
300            message::Message::Assistant { content, .. } => {
301                let mut messages: Vec<Message> = vec![];
302                let mut text_content = String::new();
303                let mut reasoning_content = String::new();
304
305                content.iter().for_each(|content| match content {
306                    message::AssistantContent::Text(text) => {
307                        text_content.push_str(text.text());
308                    }
309                    message::AssistantContent::Reasoning(reasoning) => {
310                        reasoning_content.push_str(&reasoning.display_text());
311                    }
312                    _ => {}
313                });
314
315                messages.push(Message::Assistant {
316                    content: text_content,
317                    name: None,
318                    tool_calls: vec![],
319                    reasoning_content: if reasoning_content.is_empty() {
320                        None
321                    } else {
322                        Some(reasoning_content)
323                    },
324                });
325
326                // extract tool calls
327                let tool_calls = content
328                    .clone()
329                    .into_iter()
330                    .filter_map(|content| match content {
331                        message::AssistantContent::ToolCall(tool_call) => {
332                            Some(ToolCall::from(tool_call))
333                        }
334                        _ => None,
335                    })
336                    .collect::<Vec<_>>();
337
338                // if we have tool calls, we add a new Assistant message with them
339                if !tool_calls.is_empty() {
340                    messages.push(Message::Assistant {
341                        content: "".to_string(),
342                        name: None,
343                        tool_calls,
344                        reasoning_content: None,
345                    });
346                }
347
348                Ok(messages)
349            }
350        }
351    }
352}
353
/// A tool invocation attached to an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    pub index: usize,
    /// Falls back to `ToolType::Function` (the enum's default) when absent from the input.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
362
/// Name and arguments of the function a [`ToolCall`] refers to.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // (De)serialized through `json_utils::stringified_json`; presumably the wire
    // value is a JSON-encoded string — confirm in `json_utils`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
369
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase
/// (`"function"`). `Function` is the only variant and the default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
376
/// Wire-format wrapper that pairs a Rig tool definition with a `type` string.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Set to `"function"` by the `From<completion::ToolDefinition>` conversion.
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
382
383impl From<crate::completion::ToolDefinition> for ToolDefinition {
384    fn from(tool: crate::completion::ToolDefinition) -> Self {
385        Self {
386            r#type: "function".into(),
387            function: tool,
388        }
389    }
390}
391
392impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
393    type Error = CompletionError;
394
395    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
396        let choice = response.choices.first().ok_or_else(|| {
397            CompletionError::ResponseError("Response contained no choices".to_owned())
398        })?;
399        let content = match &choice.message {
400            Message::Assistant {
401                content,
402                tool_calls,
403                reasoning_content,
404                ..
405            } => {
406                let mut content = if content.trim().is_empty() {
407                    vec![]
408                } else {
409                    vec![completion::AssistantContent::text(content)]
410                };
411
412                content.extend(
413                    tool_calls
414                        .iter()
415                        .map(|call| {
416                            completion::AssistantContent::tool_call(
417                                &call.id,
418                                &call.function.name,
419                                call.function.arguments.clone(),
420                            )
421                        })
422                        .collect::<Vec<_>>(),
423                );
424
425                if let Some(reasoning_content) = reasoning_content {
426                    content.push(completion::AssistantContent::reasoning(reasoning_content));
427                }
428
429                Ok(content)
430            }
431            _ => Err(CompletionError::ResponseError(
432                "Response did not contain a valid message or tool call".into(),
433            )),
434        }?;
435
436        let choice = OneOrMany::many(content).map_err(|_| {
437            CompletionError::ResponseError(
438                "Response contained no message or tool call (empty)".to_owned(),
439            )
440        })?;
441
442        let usage = completion::Usage {
443            input_tokens: response.usage.prompt_tokens as u64,
444            output_tokens: response.usage.completion_tokens as u64,
445            total_tokens: response.usage.total_tokens as u64,
446            cached_input_tokens: response
447                .usage
448                .prompt_tokens_details
449                .as_ref()
450                .and_then(|d| d.cached_tokens)
451                .map(|c| c as u64)
452                .unwrap_or(0),
453        };
454
455        Ok(completion::CompletionResponse {
456            choice,
457            usage,
458            raw_response: response,
459            message_id: None,
460        })
461    }
462}
463
/// JSON body sent to the `/chat/completions` endpoint.
///
/// Optional members are omitted from the serialized body when unset (or, for
/// `tools`, when empty).
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct DeepseekCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    // Reuses the OpenRouter provider's `ToolChoice` type.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    // Flattened: the keys of this JSON value are serialized at the top level
    // of the request body rather than under an `additional_params` key.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
477
478impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
479    type Error = CompletionError;
480
481    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
482        if req.output_schema.is_some() {
483            tracing::warn!("Structured outputs currently not supported for DeepSeek");
484        }
485        let model = req.model.clone().unwrap_or_else(|| model.to_string());
486        let mut full_history: Vec<Message> = match &req.preamble {
487            Some(preamble) => vec![Message::system(preamble)],
488            None => vec![],
489        };
490
491        if let Some(docs) = req.normalized_documents() {
492            let docs: Vec<Message> = docs.try_into()?;
493            full_history.extend(docs);
494        }
495
496        let chat_history: Vec<Message> = req
497            .chat_history
498            .clone()
499            .into_iter()
500            .map(|message| message.try_into())
501            .collect::<Result<Vec<Vec<Message>>, _>>()?
502            .into_iter()
503            .flatten()
504            .collect();
505
506        let mut last_reasoning_content = None;
507        let mut last_assistant_idx = None;
508
509        for message in chat_history {
510            if let Message::Assistant {
511                reasoning_content, ..
512            } = &message
513            {
514                if let Some(content) = reasoning_content {
515                    last_reasoning_content = Some(content.clone());
516                } else {
517                    last_assistant_idx = Some(full_history.len());
518                    full_history.push(message);
519                }
520            } else {
521                full_history.push(message);
522            }
523        }
524
525        // Merge last reasoning content into the last assistant message.
526        // Note that we only need to preserve the last reasoning content.
527        if let (Some(idx), Some(reasoning)) = (last_assistant_idx, last_reasoning_content)
528            && let Message::Assistant {
529                ref mut reasoning_content,
530                ..
531            } = full_history[idx]
532        {
533            *reasoning_content = Some(reasoning);
534        }
535
536        let tool_choice = req
537            .tool_choice
538            .clone()
539            .map(crate::providers::openrouter::ToolChoice::try_from)
540            .transpose()?;
541
542        Ok(Self {
543            model: model.to_string(),
544            messages: full_history,
545            temperature: req.temperature,
546            tools: req
547                .tools
548                .clone()
549                .into_iter()
550                .map(ToolDefinition::from)
551                .collect::<Vec<_>>(),
552            tool_choice,
553            additional_params: req.additional_params,
554        })
555    }
556}
557
/// The struct implementing the `CompletionModel` trait for DeepSeek: a client
/// handle paired with the model name to request.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub client: Client<T>,
    /// Default model name; a model set on an individual request overrides it.
    pub model: String,
}
564
/// Completion and streaming entry points for DeepSeek. Both POST a JSON body
/// to `/chat/completions`.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    /// Creates a model handle from a clone of `client` and the given model name.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self {
            client: client.clone(),
            model: model.into().to_string(),
        }
    }

    /// Sends a non-streaming completion request and converts the reply.
    ///
    /// # Errors
    /// Returns an error when the request cannot be built or serialized, when
    /// the HTTP call fails, when the body cannot be parsed, when the body is an
    /// error payload, or when the status is not a success (the raw body text is
    /// then returned as a `ProviderError`).
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<
        completion::CompletionResponse<CompletionResponse>,
        crate::completion::CompletionError,
    > {
        // Open a dedicated span only when there is no active one; otherwise
        // reuse the caller's span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "deepseek",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request =
            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printing the request is only done when TRACE is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "DeepSeek completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A success status may still carry an error payload, hence `ApiResponse`.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
                        span.record(
                            "gen_ai.usage.output_tokens",
                            response.usage.completion_tokens,
                        );
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "DeepSeek completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    /// Sends a streaming completion request.
    ///
    /// The request's additional parameters are merged with `stream: true` and
    /// `stream_options.include_usage: true` before sending; the response is
    /// consumed by [`send_compatible_streaming_request`].
    ///
    /// # Errors
    /// Returns an error when the request cannot be built or serialized, or
    /// when the streaming helper fails.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        // Copied up front: `completion_request` is consumed by the conversion below.
        let preamble = completion_request.preamble.clone();
        let mut request =
            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "DeepSeek streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Same span policy as `completion`: create one only if none is active.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "deepseek",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}
715
/// Incremental payload of one streamed choice. Any of the three members may be
/// absent in a given chunk.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    // Missing field -> empty vec (`default`); deserialization goes through
    // `json_utils::null_or_vec`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    reasoning_content: Option<String>,
}
724
/// One entry of a streamed chunk's `choices` array; only `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
729
/// One server-sent event payload of a streaming completion.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    /// Optional; the stream consumer keeps the most recent value it sees.
    usage: Option<Usage>,
}
735
/// Final item of a streaming completion: the token usage gathered from the stream.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
740
741impl GetTokenUsage for StreamingCompletionResponse {
742    fn token_usage(&self) -> Option<crate::completion::Usage> {
743        let mut usage = crate::completion::Usage::new();
744        usage.input_tokens = self.usage.prompt_tokens as u64;
745        usage.output_tokens = self.usage.completion_tokens as u64;
746        usage.total_tokens = self.usage.total_tokens as u64;
747        usage.cached_input_tokens = self
748            .usage
749            .prompt_tokens_details
750            .as_ref()
751            .and_then(|d| d.cached_tokens)
752            .map(|c| c as u64)
753            .unwrap_or(0);
754
755        Some(usage)
756    }
757}
758
759pub async fn send_compatible_streaming_request<T>(
760    http_client: T,
761    req: Request<Vec<u8>>,
762) -> Result<
763    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
764    CompletionError,
765>
766where
767    T: HttpClientExt + Clone + 'static,
768{
769    let mut event_source = GenericEventSource::new(http_client, req);
770
771    let stream = stream! {
772        let mut final_usage = Usage::new();
773        let mut text_response = String::new();
774        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
775
776        while let Some(event_result) = event_source.next().await {
777            match event_result {
778                Ok(Event::Open) => {
779                    tracing::trace!("SSE connection opened");
780                    continue;
781                }
782                Ok(Event::Message(message)) => {
783                    if message.data.trim().is_empty() || message.data == "[DONE]" {
784                        continue;
785                    }
786
787                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
788                    let Ok(data) = parsed else {
789                        let err = parsed.unwrap_err();
790                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
791                        continue;
792                    };
793
794                    if let Some(choice) = data.choices.first() {
795                        let delta = &choice.delta;
796
797                        if !delta.tool_calls.is_empty() {
798                            for tool_call in &delta.tool_calls {
799                                let function = &tool_call.function;
800
801                                // Start of tool call
802                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
803                                    && empty_or_none(&function.arguments)
804                                {
805                                    let id = tool_call.id.clone().unwrap_or_default();
806                                    let name = function.name.clone().unwrap();
807                                    calls.insert(tool_call.index, (id, name, String::new()));
808                                }
809                                // Continuation of tool call
810                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
811                                    && let Some(arguments) = &function.arguments
812                                    && !arguments.is_empty()
813                                {
814                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
815                                        let combined = format!("{}{}", existing_args, arguments);
816                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
817                                    } else {
818                                        tracing::debug!("Partial tool call received but tool call was never started.");
819                                    }
820                                }
821                                // Complete tool call
822                                else {
823                                    let id = tool_call.id.clone().unwrap_or_default();
824                                    let name = function.name.clone().unwrap_or_default();
825                                    let arguments_str = function.arguments.clone().unwrap_or_default();
826
827                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
828                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
829                                        continue;
830                                    };
831
832                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
833                                        crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
834                                    ));
835                                }
836                            }
837                        }
838
839                        // DeepSeek-specific reasoning stream
840                        if let Some(content) = &delta.reasoning_content {
841                            yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
842                                id: None,
843                                reasoning: content.to_string()
844                            });
845                        }
846
847                        if let Some(content) = &delta.content {
848                            text_response += content;
849                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
850                        }
851                    }
852
853                    if let Some(usage) = data.usage {
854                        final_usage = usage.clone();
855                    }
856                }
857                Err(crate::http_client::Error::StreamEnded) => {
858                    break;
859                }
860                Err(err) => {
861                    tracing::error!(?err, "SSE error");
862                    yield Err(CompletionError::ResponseError(err.to_string()));
863                    break;
864                }
865            }
866        }
867
868        event_source.close();
869
870        let mut tool_calls = Vec::new();
871        // Flush accumulated tool calls
872        for (index, (id, name, arguments)) in calls {
873            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
874                continue;
875            };
876
877            tool_calls.push(ToolCall {
878                id: id.clone(),
879                index,
880                r#type: ToolType::Function,
881                function: Function {
882                    name: name.clone(),
883                    arguments: arguments_json.clone()
884                }
885            });
886            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
887                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
888            ));
889        }
890
891        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
892            StreamingCompletionResponse { usage: final_usage.clone() }
893        ));
894    };
895
896    Ok(crate::streaming::StreamingCompletionResponse::stream(
897        Box::pin(stream),
898    ))
899}
900
901// ================================================================
902// DeepSeek Completion API
903// ================================================================
/// Model identifier string `"deepseek-chat"`, usable as the model name passed to
/// `Client::completion_model` (as shown in the module-level example).
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier string `"deepseek-reasoner"`.
///
/// NOTE(review): presumably the model that produces `reasoning_content` output —
/// confirm against the DeepSeek model documentation.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
906
907// Tests
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `message` is the assistant variant and that its text equals `expected`.
    ///
    /// Panics with "Expected assistant message" for any other message variant.
    fn assert_assistant_content(message: &Message, expected: &str) {
        let Message::Assistant { content, .. } = message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content, expected);
    }

    /// Deserializes a whole completion response from `json`.
    ///
    /// Goes through `serde_path_to_error` so that a failure panics with the JSON path
    /// of the offending field rather than a bare serde error.
    fn parse_completion_response(json: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(json);
        serde_path_to_error::deserialize(&mut deserializer).unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err)
        })
    }

    /// A bare JSON array of choices deserializes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let payload = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let parsed: Vec<Choice> = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.len(), 1);

        let first_choice = parsed.first().unwrap();
        assert_assistant_content(&first_choice.message, "Hello, world!");
    }

    /// A minimal response carrying only `choices` and `usage` deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let payload = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_completion_response(payload);
        let first_choice = response.choices.first().unwrap();
        assert_assistant_content(&first_choice.message, "Hello, world!");
    }

    /// A complete example response (extra metadata fields, nested usage details,
    /// non-ASCII text) deserializes and keeps the message content intact.
    #[test]
    fn test_deserialize_example_response() {
        let payload = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_completion_response(payload);
        let first_choice = response.choices.first().unwrap();
        assert_assistant_content(
            &first_choice.message,
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄",
        );
    }

    /// A choice whose assistant message carries a tool call deserializes into the
    /// expected typed structure, with the stringified arguments parsed into JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let parsed: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        // Build the expected value bottom-up: arguments, then the call, then the choice.
        let subtract_args = serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap();
        let subtract_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: subtract_args,
            },
        };
        let expected = Choice {
            index: 0,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![subtract_call],
                reasoning_content: None,
            },
            logprobs: None,
            finish_reason: "tool_calls".to_string(),
        };

        assert_eq!(parsed, expected);
    }

    /// The client can be constructed both directly and through the builder.
    #[test]
    fn test_client_initialization() {
        use crate::providers::deepseek::Client;

        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");

        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}