Skip to main content

rig/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::deepseek;
6//!
7//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
8//!
9//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
10//! ```
11
12use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
37// ================================================================
38// Main DeepSeek Client
39// ================================================================
40const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
/// Stateless provider extension type for DeepSeek.
///
/// This module implements [`Provider`], [`Capabilities`] and [`DebugExt`] for it.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExt;
/// Stateless builder type for [`DeepSeekExt`]; see its [`ProviderBuilder`] impl in this module.
#[derive(Debug, Default, Clone, Copy)]
pub struct DeepSeekExtBuilder;
46
/// API key type used by the DeepSeek provider; an alias for [`BearerAuth`].
type DeepSeekApiKey = BearerAuth;
48
/// Registers [`DeepSeekExt`] as a provider whose builder is [`DeepSeekExtBuilder`].
impl Provider for DeepSeekExt {
    type Builder = DeepSeekExtBuilder;
    // NOTE(review): presumably the endpoint hit to verify credentials — confirm against `Provider`.
    const VERIFY_PATH: &'static str = "/user/balance";
}
53
/// Capability table for DeepSeek: only completions are marked [`Capable`]
/// (via [`CompletionModel`]); every other capability is [`Nothing`].
impl<H> Capabilities<H> for DeepSeekExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // The two associated types below only exist when the matching crate feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
64
// Empty impl: nothing is overridden, so only the trait's own defaults apply.
impl DebugExt for DeepSeekExt {}
66
/// Builder wiring for DeepSeek: fixes the extension type, the API key type
/// and the base URL (`DEEPSEEK_API_BASE_URL`).
impl ProviderBuilder for DeepSeekExtBuilder {
    type Extension<H>
        = DeepSeekExt
    where
        H: HttpClientExt;
    type ApiKey = DeepSeekApiKey;

    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;

    /// Produces the extension value. The builder argument is ignored because
    /// `DeepSeekExt` holds no state, so this always returns `Ok(DeepSeekExt)`.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(DeepSeekExt)
    }
}
85
/// DeepSeek client: the generic [`client::Client`] specialised with [`DeepSeekExt`].
/// The HTTP backend `H` defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
/// Builder for the DeepSeek [`Client`]: the generic [`client::ClientBuilder`] specialised with
/// [`DeepSeekExtBuilder`] and a `String` key slot. The HTTP backend `H` defaults to `reqwest::Client`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
88
89impl ProviderClient for Client {
90    type Input = DeepSeekApiKey;
91
92    // If you prefer the environment variable approach:
93    fn from_env() -> Self {
94        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
95        let mut client_builder = Self::builder();
96        client_builder.headers_mut().insert(
97            http::header::CONTENT_TYPE,
98            http::HeaderValue::from_static("application/json"),
99        );
100        let client_builder = client_builder.api_key(&api_key);
101        client_builder.build().unwrap()
102    }
103
104    fn from_val(input: Self::Input) -> Self {
105        Self::new(input).unwrap()
106    }
107}
108
/// Error payload shape: a JSON object carrying a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text; surfaced to callers as `CompletionError::ProviderError`.
    message: String,
}
113
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
120
121impl From<ApiErrorResponse> for CompletionError {
122    fn from(err: ApiErrorResponse) -> Self {
123        CompletionError::ProviderError(err.message)
124    }
125}
126
/// The response shape from the DeepSeek API.
///
/// Only `choices` and `usage` are modelled here.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Returned choices; the conversion into the generic completion response reads only the first.
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
    // Further response fields can be added here when needed.
}
135
/// Token usage counters as reported in a DeepSeek response.
///
/// The five `u32` counters carry no `#[serde(default)]`, so deserialization
/// fails if any of them is missing. The two `*_details` fields are optional
/// and are omitted from serialized output when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    pub completion_tokens: u32,
    pub prompt_tokens: u32,
    pub prompt_cache_hit_tokens: u32,
    pub prompt_cache_miss_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
148
149impl Usage {
150    fn new() -> Self {
151        Self {
152            completion_tokens: 0,
153            prompt_tokens: 0,
154            prompt_cache_hit_tokens: 0,
155            prompt_cache_miss_tokens: 0,
156            total_tokens: 0,
157            completion_tokens_details: None,
158            prompt_tokens_details: None,
159        }
160    }
161}
162
/// Optional breakdown of completion tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Reasoning token count, when reported; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
168
/// Optional breakdown of prompt tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// Cached token count, when reported; this module maps it to `cached_input_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
174
/// One entry of the `choices` array in a completion response.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    pub index: usize,
    /// The message returned for this choice.
    pub message: Message,
    /// Kept as raw JSON; this module never inspects it.
    pub logprobs: Option<serde_json::Value>,
    pub finish_reason: String,
}
182
/// A chat message in DeepSeek wire format.
///
/// Internally tagged on `role`, with lowercase variant names: `system`,
/// `user`, `assistant`, and `tool` (the explicit rename of `ToolResult`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // A missing field yields an empty Vec (`default`); an empty Vec is not serialized.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        /// only exists on `deepseek-reasoner` model at time of addition
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning_content: Option<String>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        /// Identifier of the tool call this result answers.
        tool_call_id: String,
        content: String,
    },
}
216
217impl Message {
218    pub fn system(content: &str) -> Self {
219        Message::System {
220            content: content.to_owned(),
221            name: None,
222        }
223    }
224}
225
226impl From<message::ToolResult> for Message {
227    fn from(tool_result: message::ToolResult) -> Self {
228        let content = match tool_result.content.first() {
229            message::ToolResultContent::Text(text) => text.text,
230            message::ToolResultContent::Image(_) => String::from("[Image]"),
231        };
232
233        Message::ToolResult {
234            tool_call_id: tool_result.id,
235            content,
236        }
237    }
238}
239
240impl From<message::ToolCall> for ToolCall {
241    fn from(tool_call: message::ToolCall) -> Self {
242        Self {
243            id: tool_call.id,
244            // TODO: update index when we have it
245            index: 0,
246            r#type: ToolType::Function,
247            function: Function {
248                name: tool_call.function.name,
249                arguments: tool_call.function.arguments,
250            },
251        }
252    }
253}
254
255impl TryFrom<message::Message> for Vec<Message> {
256    type Error = message::MessageError;
257
258    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
259        match message {
260            message::Message::System { content } => Ok(vec![Message::System {
261                content,
262                name: None,
263            }]),
264            message::Message::User { content } => {
265                // extract tool results
266                let mut messages = vec![];
267
268                let tool_results = content
269                    .clone()
270                    .into_iter()
271                    .filter_map(|content| match content {
272                        message::UserContent::ToolResult(tool_result) => {
273                            Some(Message::from(tool_result))
274                        }
275                        _ => None,
276                    })
277                    .collect::<Vec<_>>();
278
279                messages.extend(tool_results);
280
281                let text_content: String = content
282                    .into_iter()
283                    .filter_map(|content| match content {
284                        message::UserContent::Text(text) => Some(text.text),
285                        message::UserContent::Document(Document {
286                            data:
287                                DocumentSourceKind::Base64(content)
288                                | DocumentSourceKind::String(content),
289                            ..
290                        }) => Some(content),
291                        _ => None,
292                    })
293                    .collect::<Vec<_>>()
294                    .join("\n");
295
296                if !text_content.is_empty() {
297                    messages.push(Message::User {
298                        content: text_content,
299                        name: None,
300                    });
301                }
302
303                Ok(messages)
304            }
305            message::Message::Assistant { content, .. } => {
306                let mut text_content = String::new();
307                let mut reasoning_content = String::new();
308                let mut tool_calls = Vec::new();
309
310                for item in content.iter() {
311                    match item {
312                        message::AssistantContent::Text(text) => {
313                            text_content.push_str(text.text());
314                        }
315                        message::AssistantContent::Reasoning(reasoning) => {
316                            reasoning_content.push_str(&reasoning.display_text());
317                        }
318                        message::AssistantContent::ToolCall(tool_call) => {
319                            tool_calls.push(ToolCall::from(tool_call.clone()));
320                        }
321                        _ => {}
322                    }
323                }
324
325                let reasoning = if reasoning_content.is_empty() {
326                    None
327                } else {
328                    Some(reasoning_content)
329                };
330
331                Ok(vec![Message::Assistant {
332                    content: text_content,
333                    name: None,
334                    tool_calls,
335                    reasoning_content: reasoning,
336                }])
337            }
338        }
339    }
340}
341
/// A tool call as it appears inside an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    pub index: usize,
    /// Falls back to `ToolType::Function` when the field is absent.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
350
/// The function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// (De)serialized through `json_utils::stringified_json`.
    // NOTE(review): presumably the wire value is a JSON-encoded string — confirm in `json_utils`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
357
/// Kind of tool call; serialized in lowercase (`"function"`), and `Function` is the default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
364
/// Request-side wrapper pairing a tool definition with its `type` string.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Set to `"function"` by the `From<completion::ToolDefinition>` conversion.
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
370
371impl From<crate::completion::ToolDefinition> for ToolDefinition {
372    fn from(tool: crate::completion::ToolDefinition) -> Self {
373        Self {
374            r#type: "function".into(),
375            function: tool,
376        }
377    }
378}
379
380impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
381    type Error = CompletionError;
382
383    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
384        let choice = response.choices.first().ok_or_else(|| {
385            CompletionError::ResponseError("Response contained no choices".to_owned())
386        })?;
387        let content = match &choice.message {
388            Message::Assistant {
389                content,
390                tool_calls,
391                reasoning_content,
392                ..
393            } => {
394                let mut content = if content.trim().is_empty() {
395                    vec![]
396                } else {
397                    vec![completion::AssistantContent::text(content)]
398                };
399
400                content.extend(
401                    tool_calls
402                        .iter()
403                        .map(|call| {
404                            completion::AssistantContent::tool_call(
405                                &call.id,
406                                &call.function.name,
407                                call.function.arguments.clone(),
408                            )
409                        })
410                        .collect::<Vec<_>>(),
411                );
412
413                if let Some(reasoning_content) = reasoning_content {
414                    content.push(completion::AssistantContent::reasoning(reasoning_content));
415                }
416
417                Ok(content)
418            }
419            _ => Err(CompletionError::ResponseError(
420                "Response did not contain a valid message or tool call".into(),
421            )),
422        }?;
423
424        let choice = OneOrMany::many(content).map_err(|_| {
425            CompletionError::ResponseError(
426                "Response contained no message or tool call (empty)".to_owned(),
427            )
428        })?;
429
430        let usage = completion::Usage {
431            input_tokens: response.usage.prompt_tokens as u64,
432            output_tokens: response.usage.completion_tokens as u64,
433            total_tokens: response.usage.total_tokens as u64,
434            cached_input_tokens: response
435                .usage
436                .prompt_tokens_details
437                .as_ref()
438                .and_then(|d| d.cached_tokens)
439                .map(|c| c as u64)
440                .unwrap_or(0),
441        };
442
443        Ok(completion::CompletionResponse {
444            choice,
445            usage,
446            raw_response: response,
447            message_id: None,
448        })
449    }
450}
451
/// JSON body sent to the `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct DeepseekCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Left out of the body entirely when no tools are defined.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
    /// Extra parameters, flattened into the top level of the request object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
465
466impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
467    type Error = CompletionError;
468
469    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
470        if req.output_schema.is_some() {
471            tracing::warn!("Structured outputs currently not supported for DeepSeek");
472        }
473        let model = req.model.clone().unwrap_or_else(|| model.to_string());
474        let mut full_history: Vec<Message> = match &req.preamble {
475            Some(preamble) => vec![Message::system(preamble)],
476            None => vec![],
477        };
478
479        if let Some(docs) = req.normalized_documents() {
480            let docs: Vec<Message> = docs.try_into()?;
481            full_history.extend(docs);
482        }
483
484        let chat_history: Vec<Message> = req
485            .chat_history
486            .clone()
487            .into_iter()
488            .map(|message| message.try_into())
489            .collect::<Result<Vec<Vec<Message>>, _>>()?
490            .into_iter()
491            .flatten()
492            .collect();
493
494        full_history.extend(chat_history);
495
496        let tool_choice = req
497            .tool_choice
498            .clone()
499            .map(crate::providers::openrouter::ToolChoice::try_from)
500            .transpose()?;
501
502        Ok(Self {
503            model: model.to_string(),
504            messages: full_history,
505            temperature: req.temperature,
506            tools: req
507                .tools
508                .clone()
509                .into_iter()
510                .map(ToolDefinition::from)
511                .collect::<Vec<_>>(),
512            tool_choice,
513            additional_params: req.additional_params,
514        })
515    }
516}
517
/// The struct implementing the `CompletionModel` trait.
///
/// Pairs a DeepSeek [`Client`] with the model name sent in each request.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    pub client: Client<T>,
    /// Default model name; a request's own `model` field takes precedence when set.
    pub model: String,
}
524
/// Completion and streaming implementation against the `/chat/completions` endpoint.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    type Client = Client<T>;

    /// Creates a model handle from a clone of `client` and the given model name.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self {
            client: client.clone(),
            model: model.into().to_string(),
        }
    }

    /// Sends a non-streaming completion request and converts the response.
    ///
    /// # Errors
    /// Returns an error when the request cannot be built, serialized or sent,
    /// when the body cannot be parsed, when the body parses as an API error
    /// payload, or when the HTTP status is not a success. In the last case the
    /// raw body text is returned as `CompletionError::ProviderError`.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<
        completion::CompletionResponse<CompletionResponse>,
        crate::completion::CompletionError,
    > {
        // Reuse the caller's span when one is active; otherwise open a new "chat" span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "deepseek",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Recorded before the request is consumed by the conversion below.
        span.record("gen_ai.system_instructions", &completion_request.preamble);

        let request =
            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printing is only paid for when TRACE is enabled.
        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "DeepSeek completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        async move {
            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A success status can still carry an error payload, hence `ApiResponse`.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        // This block runs instrumented, so the current span is `span` from above.
                        let span = tracing::Span::current();
                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
                        span.record(
                            "gen_ai.usage.output_tokens",
                            response.usage.completion_tokens,
                        );
                        span.record(
                            "gen_ai.usage.cached_tokens",
                            response
                                .usage
                                .prompt_tokens_details
                                .as_ref()
                                .and_then(|d| d.cached_tokens)
                                .unwrap_or(0),
                        );
                        if enabled!(Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "DeepSeek completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        }
        .instrument(span)
        .await
    }

    /// Sends a streaming completion request.
    ///
    /// Merges `stream: true` and `stream_options.include_usage: true` into the
    /// request's additional parameters, then hands the HTTP request to
    /// [`send_compatible_streaming_request`].
    ///
    /// # Errors
    /// Returns an error when the request cannot be built or serialized, or
    /// when the streaming helper fails.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        // Cloned up front because the request is consumed by the conversion below.
        let preamble = completion_request.preamble.clone();
        let mut request =
            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if enabled!(Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "DeepSeek streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/chat/completions")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;

        // Reuse the caller's span when one is active; otherwise open a new "chat_streaming" span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "deepseek",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.clone(), req),
            span,
        )
        .await
    }
}
686
/// Incremental payload of one streamed choice.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    /// Text fragment, if this chunk carries one.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments; a missing field yields an empty Vec (`default`).
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    /// Reasoning fragment, if this chunk carries one.
    reasoning_content: Option<String>,
}
695
/// One entry of the `choices` array in a streamed chunk; only `delta` is modelled.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
700
/// A single parsed SSE data payload from the streaming endpoint.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    /// The stream handler reads only the first choice.
    choices: Vec<StreamingChoice>,
    /// Usage, when present; the last chunk that carries it is the one reported.
    usage: Option<Usage>,
}
706
/// Final value yielded at the end of a streamed completion, carrying the token usage.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
711
712impl GetTokenUsage for StreamingCompletionResponse {
713    fn token_usage(&self) -> Option<crate::completion::Usage> {
714        let mut usage = crate::completion::Usage::new();
715        usage.input_tokens = self.usage.prompt_tokens as u64;
716        usage.output_tokens = self.usage.completion_tokens as u64;
717        usage.total_tokens = self.usage.total_tokens as u64;
718        usage.cached_input_tokens = self
719            .usage
720            .prompt_tokens_details
721            .as_ref()
722            .and_then(|d| d.cached_tokens)
723            .map(|c| c as u64)
724            .unwrap_or(0);
725
726        Some(usage)
727    }
728}
729
730pub async fn send_compatible_streaming_request<T>(
731    http_client: T,
732    req: Request<Vec<u8>>,
733) -> Result<
734    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
735    CompletionError,
736>
737where
738    T: HttpClientExt + Clone + 'static,
739{
740    let mut event_source = GenericEventSource::new(http_client, req);
741
742    let stream = stream! {
743        let mut final_usage = Usage::new();
744        let mut text_response = String::new();
745        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
746
747        while let Some(event_result) = event_source.next().await {
748            match event_result {
749                Ok(Event::Open) => {
750                    tracing::trace!("SSE connection opened");
751                    continue;
752                }
753                Ok(Event::Message(message)) => {
754                    if message.data.trim().is_empty() || message.data == "[DONE]" {
755                        continue;
756                    }
757
758                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
759                    let Ok(data) = parsed else {
760                        let err = parsed.unwrap_err();
761                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
762                        continue;
763                    };
764
765                    if let Some(choice) = data.choices.first() {
766                        let delta = &choice.delta;
767
768                        if !delta.tool_calls.is_empty() {
769                            for tool_call in &delta.tool_calls {
770                                let function = &tool_call.function;
771
772                                // Start of tool call
773                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
774                                    && empty_or_none(&function.arguments)
775                                {
776                                    let id = tool_call.id.clone().unwrap_or_default();
777                                    let name = function.name.clone().unwrap();
778                                    calls.insert(tool_call.index, (id, name, String::new()));
779                                }
780                                // Continuation of tool call
781                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
782                                    && let Some(arguments) = &function.arguments
783                                    && !arguments.is_empty()
784                                {
785                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
786                                        let combined = format!("{}{}", existing_args, arguments);
787                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
788                                    } else {
789                                        tracing::debug!("Partial tool call received but tool call was never started.");
790                                    }
791                                }
792                                // Complete tool call
793                                else {
794                                    let id = tool_call.id.clone().unwrap_or_default();
795                                    let name = function.name.clone().unwrap_or_default();
796                                    let arguments_str = function.arguments.clone().unwrap_or_default();
797
798                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
799                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
800                                        continue;
801                                    };
802
803                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
804                                        crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
805                                    ));
806                                }
807                            }
808                        }
809
810                        // DeepSeek-specific reasoning stream
811                        if let Some(content) = &delta.reasoning_content {
812                            yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
813                                id: None,
814                                reasoning: content.to_string()
815                            });
816                        }
817
818                        if let Some(content) = &delta.content {
819                            text_response += content;
820                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
821                        }
822                    }
823
824                    if let Some(usage) = data.usage {
825                        final_usage = usage.clone();
826                    }
827                }
828                Err(crate::http_client::Error::StreamEnded) => {
829                    break;
830                }
831                Err(err) => {
832                    tracing::error!(?err, "SSE error");
833                    yield Err(CompletionError::ResponseError(err.to_string()));
834                    break;
835                }
836            }
837        }
838
839        event_source.close();
840
841        let mut tool_calls = Vec::new();
842        // Flush accumulated tool calls
843        for (index, (id, name, arguments)) in calls {
844            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
845                continue;
846            };
847
848            tool_calls.push(ToolCall {
849                id: id.clone(),
850                index,
851                r#type: ToolType::Function,
852                function: Function {
853                    name: name.clone(),
854                    arguments: arguments_json.clone()
855                }
856            });
857            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
858                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
859            ));
860        }
861
862        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
863            StreamingCompletionResponse { usage: final_usage.clone() }
864        ));
865    };
866
867    Ok(crate::streaming::StreamingCompletionResponse::stream(
868        Box::pin(stream),
869    ))
870}
871
872// ================================================================
873// DeepSeek Completion API
874// ================================================================
/// Model name string for the `deepseek-chat` model.
///
/// Pass it to `Client::completion_model` to select this model, as shown in the
/// module-level example.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name string for the `deepseek-reasoner` model.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
877
878// Tests
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text content of an assistant message.
    ///
    /// Panics with "Expected assistant message" for any other variant, so a
    /// test fails loudly when the wrong kind of message was produced.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content.as_str(),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Deserializes `payload` into a [`CompletionResponse`].
    ///
    /// Goes through `serde_path_to_error` so that a failure panics with the
    /// JSON path of the offending field rather than only the serde message.
    fn parse_completion(payload: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(payload);
        serde_path_to_error::deserialize(&mut deserializer).unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err)
        })
    }

    /// A bare JSON array of choices deserializes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let payload = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let parsed = serde_json::from_str::<Vec<Choice>>(payload).unwrap();
        assert_eq!(parsed.len(), 1);

        let only_choice = parsed.first().unwrap();
        assert_eq!(assistant_text(&only_choice.message), "Hello, world!");
    }

    /// A minimal response body (choices plus usage) deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let payload = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_completion(payload);
        let first_choice = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first_choice.message), "Hello, world!");
    }

    /// A full example response body, including fields beyond choices and
    /// usage, deserializes and keeps the message text intact.
    #[test]
    fn test_deserialize_example_response() {
        let payload = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_completion(payload);
        let first_choice = response.choices.first().unwrap();
        assert_eq!(
            assistant_text(&first_choice.message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    /// A choice carrying a tool call deserializes into the expected `Choice`,
    /// with the stringified `arguments` decoded into a JSON value.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let raw_choice = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let actual: Choice = serde_json::from_str(raw_choice).unwrap();

        // Build the expectation from the inside out: tool call, then message, then choice.
        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };
        let expected_message = Message::Assistant {
            content: String::new(),
            name: None,
            tool_calls: vec![expected_call],
            reasoning_content: None,
        };
        let expected = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: expected_message,
        };

        assert_eq!(actual, expected);
    }

    /// Several text items in one user message are converted into a single
    /// provider user message whose text is joined with a newline.
    #[test]
    fn test_user_message_multiple_text_items_merged() {
        use crate::completion::message::{Message as RigMessage, UserContent};

        let parts = vec![
            UserContent::text("first part"),
            UserContent::text("second part"),
        ];
        let rig_msg = RigMessage::User {
            content: OneOrMany::many(parts).expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        // Keep only the content of the user messages; other variants are ignored.
        let user_contents: Vec<_> = messages
            .iter()
            .filter_map(|converted| match converted {
                Message::User { content, .. } => Some(content),
                _ => None,
            })
            .collect();

        assert_eq!(
            user_contents.len(),
            1,
            "multiple text items should produce a single user message"
        );
        assert_eq!(user_contents[0], "first part\nsecond part");
    }

    /// Reasoning, text and a tool call in one assistant message are converted
    /// into a single provider assistant message carrying all three.
    #[test]
    fn test_assistant_message_with_reasoning_and_tool_calls() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let parts = vec![
            AssistantContent::reasoning("thinking about the problem"),
            AssistantContent::text("I'll call the tool"),
            AssistantContent::tool_call(
                "call_1",
                "subtract",
                serde_json::json!({"x": 2, "y": 5}),
            ),
        ];
        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(parts).expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        assert_eq!(messages.len(), 1, "should produce exactly one message");
        let Message::Assistant {
            content,
            tool_calls,
            reasoning_content,
            ..
        } = &messages[0]
        else {
            panic!("Expected assistant message");
        };

        assert_eq!(content, "I'll call the tool");
        assert_eq!(
            reasoning_content.as_deref(),
            Some("thinking about the problem")
        );
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "subtract");
    }

    /// Without a reasoning item, the converted assistant message has no
    /// `reasoning_content` and still carries the tool call.
    #[test]
    fn test_assistant_message_without_reasoning() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let parts = vec![
            AssistantContent::text("calling tool"),
            AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})),
        ];
        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(parts).expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        assert_eq!(messages.len(), 1);
        let Message::Assistant {
            reasoning_content,
            tool_calls,
            ..
        } = &messages[0]
        else {
            panic!("Expected assistant message");
        };

        assert!(reasoning_content.is_none());
        assert_eq!(tool_calls.len(), 1);
    }

    /// Both construction paths, `Client::new` and the builder, succeed with a
    /// placeholder API key.
    #[test]
    fn test_client_initialization() {
        use crate::providers::deepseek::Client;

        let _direct = Client::new("dummy-key").expect("Client::new() failed");
        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}