Skip to main content

rig/providers/
deepseek.rs

//! DeepSeek API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::deepseek;
//!
//! // `Client::new` is fallible, so unwrap the result before using the client.
//! let client = deepseek::Client::new("DEEPSEEK_API_KEY")
//!     .expect("failed to build the DeepSeek client");
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```
11
12use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
37// ================================================================
38// Main DeepSeek Client
39// ================================================================
/// Base URL that the DeepSeek provider builder uses as its `BASE_URL`.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
/// Marker type identifying the DeepSeek provider extension of the generic client.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExt;

/// Marker type identifying the builder that produces a [`DeepSeekExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExtBuilder;
46
47type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50    type Builder = DeepSeekExtBuilder;
51    const VERIFY_PATH: &'static str = "/user/balance";
52}
53
54impl<H> Capabilities<H> for DeepSeekExt {
55    type Completion = Capable<CompletionModel<H>>;
56    type Embeddings = Nothing;
57    type Transcription = Nothing;
58    type ModelListing = Nothing;
59    #[cfg(feature = "image")]
60    type ImageGeneration = Nothing;
61    #[cfg(feature = "audio")]
62    type AudioGeneration = Nothing;
63}
64
65impl DebugExt for DeepSeekExt {}
66
67impl ProviderBuilder for DeepSeekExtBuilder {
68    type Extension<H>
69        = DeepSeekExt
70    where
71        H: HttpClientExt;
72    type ApiKey = DeepSeekApiKey;
73
74    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
75
76    fn build<H>(
77        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
78    ) -> http_client::Result<Self::Extension<H>>
79    where
80        H: HttpClientExt,
81    {
82        Ok(DeepSeekExt)
83    }
84}
85
86pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
87pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
88
89impl ProviderClient for Client {
90    type Input = DeepSeekApiKey;
91
92    // If you prefer the environment variable approach:
93    fn from_env() -> Self {
94        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
95        let mut client_builder = Self::builder();
96        client_builder.headers_mut().insert(
97            http::header::CONTENT_TYPE,
98            http::HeaderValue::from_static("application/json"),
99        );
100        let client_builder = client_builder.api_key(&api_key);
101        client_builder.build().unwrap()
102    }
103
104    fn from_val(input: Self::Input) -> Self {
105        Self::new(input).unwrap()
106    }
107}
108
109#[derive(Debug, Deserialize)]
110struct ApiErrorResponse {
111    message: String,
112}
113
114#[derive(Debug, Deserialize)]
115#[serde(untagged)]
116enum ApiResponse<T> {
117    Ok(T),
118    Err(ApiErrorResponse),
119}
120
121impl From<ApiErrorResponse> for CompletionError {
122    fn from(err: ApiErrorResponse) -> Self {
123        CompletionError::ProviderError(err.message)
124    }
125}
126
127/// The response shape from the DeepSeek API
128#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct CompletionResponse {
130    // We'll match the JSON:
131    pub choices: Vec<Choice>,
132    pub usage: Usage,
133    // you may want other fields
134}
135
136#[derive(Clone, Debug, Serialize, Deserialize, Default)]
137pub struct Usage {
138    pub completion_tokens: u32,
139    pub prompt_tokens: u32,
140    pub prompt_cache_hit_tokens: u32,
141    pub prompt_cache_miss_tokens: u32,
142    pub total_tokens: u32,
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub completion_tokens_details: Option<CompletionTokensDetails>,
145    #[serde(skip_serializing_if = "Option::is_none")]
146    pub prompt_tokens_details: Option<PromptTokensDetails>,
147}
148
149impl Usage {
150    fn new() -> Self {
151        Self {
152            completion_tokens: 0,
153            prompt_tokens: 0,
154            prompt_cache_hit_tokens: 0,
155            prompt_cache_miss_tokens: 0,
156            total_tokens: 0,
157            completion_tokens_details: None,
158            prompt_tokens_details: None,
159        }
160    }
161}
162
163#[derive(Clone, Debug, Serialize, Deserialize, Default)]
164pub struct CompletionTokensDetails {
165    #[serde(skip_serializing_if = "Option::is_none")]
166    pub reasoning_tokens: Option<u32>,
167}
168
169#[derive(Clone, Debug, Serialize, Deserialize, Default)]
170pub struct PromptTokensDetails {
171    #[serde(skip_serializing_if = "Option::is_none")]
172    pub cached_tokens: Option<u32>,
173}
174
175#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
176pub struct Choice {
177    pub index: usize,
178    pub message: Message,
179    pub logprobs: Option<serde_json::Value>,
180    pub finish_reason: String,
181}
182
183#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
184#[serde(tag = "role", rename_all = "lowercase")]
185pub enum Message {
186    System {
187        content: String,
188        #[serde(skip_serializing_if = "Option::is_none")]
189        name: Option<String>,
190    },
191    User {
192        content: String,
193        #[serde(skip_serializing_if = "Option::is_none")]
194        name: Option<String>,
195    },
196    Assistant {
197        content: String,
198        #[serde(skip_serializing_if = "Option::is_none")]
199        name: Option<String>,
200        #[serde(
201            default,
202            deserialize_with = "json_utils::null_or_vec",
203            skip_serializing_if = "Vec::is_empty"
204        )]
205        tool_calls: Vec<ToolCall>,
206        /// only exists on `deepseek-reasoner` model at time of addition
207        #[serde(skip_serializing_if = "Option::is_none")]
208        reasoning_content: Option<String>,
209    },
210    #[serde(rename = "tool")]
211    ToolResult {
212        tool_call_id: String,
213        content: String,
214    },
215}
216
217impl Message {
218    pub fn system(content: &str) -> Self {
219        Message::System {
220            content: content.to_owned(),
221            name: None,
222        }
223    }
224}
225
226impl From<message::ToolResult> for Message {
227    fn from(tool_result: message::ToolResult) -> Self {
228        let content = match tool_result.content.first() {
229            message::ToolResultContent::Text(text) => text.text,
230            message::ToolResultContent::Image(_) => String::from("[Image]"),
231        };
232
233        Message::ToolResult {
234            tool_call_id: tool_result.id,
235            content,
236        }
237    }
238}
239
240impl From<message::ToolCall> for ToolCall {
241    fn from(tool_call: message::ToolCall) -> Self {
242        Self {
243            id: tool_call.id,
244            // TODO: update index when we have it
245            index: 0,
246            r#type: ToolType::Function,
247            function: Function {
248                name: tool_call.function.name,
249                arguments: tool_call.function.arguments,
250            },
251        }
252    }
253}
254
255impl TryFrom<message::Message> for Vec<Message> {
256    type Error = message::MessageError;
257
258    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
259        match message {
260            message::Message::System { content } => Ok(vec![Message::System {
261                content,
262                name: None,
263            }]),
264            message::Message::User { content } => {
265                // extract tool results
266                let mut messages = vec![];
267
268                let tool_results = content
269                    .clone()
270                    .into_iter()
271                    .filter_map(|content| match content {
272                        message::UserContent::ToolResult(tool_result) => {
273                            Some(Message::from(tool_result))
274                        }
275                        _ => None,
276                    })
277                    .collect::<Vec<_>>();
278
279                messages.extend(tool_results);
280
281                let text_content: String = content
282                    .into_iter()
283                    .filter_map(|content| match content {
284                        message::UserContent::Text(text) => Some(text.text),
285                        message::UserContent::Document(Document {
286                            data:
287                                DocumentSourceKind::Base64(content)
288                                | DocumentSourceKind::String(content),
289                            ..
290                        }) => Some(content),
291                        _ => None,
292                    })
293                    .collect::<Vec<_>>()
294                    .join("\n");
295
296                if !text_content.is_empty() {
297                    messages.push(Message::User {
298                        content: text_content,
299                        name: None,
300                    });
301                }
302
303                Ok(messages)
304            }
305            message::Message::Assistant { content, .. } => {
306                let mut text_content = String::new();
307                let mut reasoning_content = String::new();
308                let mut tool_calls = Vec::new();
309
310                for item in content.iter() {
311                    match item {
312                        message::AssistantContent::Text(text) => {
313                            text_content.push_str(text.text());
314                        }
315                        message::AssistantContent::Reasoning(reasoning) => {
316                            reasoning_content.push_str(&reasoning.display_text());
317                        }
318                        message::AssistantContent::ToolCall(tool_call) => {
319                            tool_calls.push(ToolCall::from(tool_call.clone()));
320                        }
321                        _ => {}
322                    }
323                }
324
325                let reasoning = if reasoning_content.is_empty() {
326                    None
327                } else {
328                    Some(reasoning_content)
329                };
330
331                Ok(vec![Message::Assistant {
332                    content: text_content,
333                    name: None,
334                    tool_calls,
335                    reasoning_content: reasoning,
336                }])
337            }
338        }
339    }
340}
341
342#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
343pub struct ToolCall {
344    pub id: String,
345    pub index: usize,
346    #[serde(default)]
347    pub r#type: ToolType,
348    pub function: Function,
349}
350
351#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
352pub struct Function {
353    pub name: String,
354    #[serde(with = "json_utils::stringified_json")]
355    pub arguments: serde_json::Value,
356}
357
358#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
359#[serde(rename_all = "lowercase")]
360pub enum ToolType {
361    #[default]
362    Function,
363}
364
365#[derive(Clone, Debug, Deserialize, Serialize)]
366pub struct ToolDefinition {
367    pub r#type: String,
368    pub function: completion::ToolDefinition,
369}
370
371impl From<crate::completion::ToolDefinition> for ToolDefinition {
372    fn from(tool: crate::completion::ToolDefinition) -> Self {
373        Self {
374            r#type: "function".into(),
375            function: tool,
376        }
377    }
378}
379
380impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
381    type Error = CompletionError;
382
383    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
384        let choice = response.choices.first().ok_or_else(|| {
385            CompletionError::ResponseError("Response contained no choices".to_owned())
386        })?;
387        let content = match &choice.message {
388            Message::Assistant {
389                content,
390                tool_calls,
391                reasoning_content,
392                ..
393            } => {
394                let mut content = if content.trim().is_empty() {
395                    vec![]
396                } else {
397                    vec![completion::AssistantContent::text(content)]
398                };
399
400                content.extend(
401                    tool_calls
402                        .iter()
403                        .map(|call| {
404                            completion::AssistantContent::tool_call(
405                                &call.id,
406                                &call.function.name,
407                                call.function.arguments.clone(),
408                            )
409                        })
410                        .collect::<Vec<_>>(),
411                );
412
413                if let Some(reasoning_content) = reasoning_content {
414                    content.push(completion::AssistantContent::reasoning(reasoning_content));
415                }
416
417                Ok(content)
418            }
419            _ => Err(CompletionError::ResponseError(
420                "Response did not contain a valid message or tool call".into(),
421            )),
422        }?;
423
424        let choice = OneOrMany::many(content).map_err(|_| {
425            CompletionError::ResponseError(
426                "Response contained no message or tool call (empty)".to_owned(),
427            )
428        })?;
429
430        let usage = completion::Usage {
431            input_tokens: response.usage.prompt_tokens as u64,
432            output_tokens: response.usage.completion_tokens as u64,
433            total_tokens: response.usage.total_tokens as u64,
434            cached_input_tokens: response
435                .usage
436                .prompt_tokens_details
437                .as_ref()
438                .and_then(|d| d.cached_tokens)
439                .map(|c| c as u64)
440                .unwrap_or(0),
441            cache_creation_input_tokens: 0,
442        };
443
444        Ok(completion::CompletionResponse {
445            choice,
446            usage,
447            raw_response: response,
448            message_id: None,
449        })
450    }
451}
452
453#[derive(Debug, Serialize, Deserialize)]
454pub(super) struct DeepseekCompletionRequest {
455    model: String,
456    pub messages: Vec<Message>,
457    #[serde(skip_serializing_if = "Option::is_none")]
458    temperature: Option<f64>,
459    #[serde(skip_serializing_if = "Vec::is_empty")]
460    tools: Vec<ToolDefinition>,
461    #[serde(skip_serializing_if = "Option::is_none")]
462    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
463    #[serde(flatten, skip_serializing_if = "Option::is_none")]
464    pub additional_params: Option<serde_json::Value>,
465}
466
467impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
468    type Error = CompletionError;
469
470    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
471        if req.output_schema.is_some() {
472            tracing::warn!("Structured outputs currently not supported for DeepSeek");
473        }
474        let model = req.model.clone().unwrap_or_else(|| model.to_string());
475        let mut full_history: Vec<Message> = match &req.preamble {
476            Some(preamble) => vec![Message::system(preamble)],
477            None => vec![],
478        };
479
480        if let Some(docs) = req.normalized_documents() {
481            let docs: Vec<Message> = docs.try_into()?;
482            full_history.extend(docs);
483        }
484
485        let chat_history: Vec<Message> = req
486            .chat_history
487            .clone()
488            .into_iter()
489            .map(|message| message.try_into())
490            .collect::<Result<Vec<Vec<Message>>, _>>()?
491            .into_iter()
492            .flatten()
493            .collect();
494
495        full_history.extend(chat_history);
496
497        let tool_choice = req
498            .tool_choice
499            .clone()
500            .map(crate::providers::openrouter::ToolChoice::try_from)
501            .transpose()?;
502
503        Ok(Self {
504            model: model.to_string(),
505            messages: full_history,
506            temperature: req.temperature,
507            tools: req
508                .tools
509                .clone()
510                .into_iter()
511                .map(ToolDefinition::from)
512                .collect::<Vec<_>>(),
513            tool_choice,
514            additional_params: req.additional_params,
515        })
516    }
517}
518
519/// The struct implementing the `CompletionModel` trait
520#[derive(Clone)]
521pub struct CompletionModel<T = reqwest::Client> {
522    pub client: Client<T>,
523    pub model: String,
524}
525
526impl<T> completion::CompletionModel for CompletionModel<T>
527where
528    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
529{
530    type Response = CompletionResponse;
531    type StreamingResponse = StreamingCompletionResponse;
532
533    type Client = Client<T>;
534
535    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
536        Self {
537            client: client.clone(),
538            model: model.into().to_string(),
539        }
540    }
541
542    async fn completion(
543        &self,
544        completion_request: CompletionRequest,
545    ) -> Result<
546        completion::CompletionResponse<CompletionResponse>,
547        crate::completion::CompletionError,
548    > {
549        let span = if tracing::Span::current().is_disabled() {
550            info_span!(
551                target: "rig::completions",
552                "chat",
553                gen_ai.operation.name = "chat",
554                gen_ai.provider.name = "deepseek",
555                gen_ai.request.model = self.model,
556                gen_ai.system_instructions = tracing::field::Empty,
557                gen_ai.response.id = tracing::field::Empty,
558                gen_ai.response.model = tracing::field::Empty,
559                gen_ai.usage.output_tokens = tracing::field::Empty,
560                gen_ai.usage.input_tokens = tracing::field::Empty,
561                gen_ai.usage.cached_tokens = tracing::field::Empty,
562            )
563        } else {
564            tracing::Span::current()
565        };
566
567        span.record("gen_ai.system_instructions", &completion_request.preamble);
568
569        let request =
570            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
571
572        if enabled!(Level::TRACE) {
573            tracing::trace!(target: "rig::completions",
574                "DeepSeek completion request: {}",
575                serde_json::to_string_pretty(&request)?
576            );
577        }
578
579        let body = serde_json::to_vec(&request)?;
580        let req = self
581            .client
582            .post("/chat/completions")?
583            .body(body)
584            .map_err(|e| CompletionError::HttpError(e.into()))?;
585
586        async move {
587            let response = self.client.send::<_, Bytes>(req).await?;
588            let status = response.status();
589            let response_body = response.into_body().into_future().await?.to_vec();
590
591            if status.is_success() {
592                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
593                    ApiResponse::Ok(response) => {
594                        let span = tracing::Span::current();
595                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
596                        span.record(
597                            "gen_ai.usage.output_tokens",
598                            response.usage.completion_tokens,
599                        );
600                        span.record(
601                            "gen_ai.usage.cached_tokens",
602                            response
603                                .usage
604                                .prompt_tokens_details
605                                .as_ref()
606                                .and_then(|d| d.cached_tokens)
607                                .unwrap_or(0),
608                        );
609                        if enabled!(Level::TRACE) {
610                            tracing::trace!(target: "rig::completions",
611                                "DeepSeek completion response: {}",
612                                serde_json::to_string_pretty(&response)?
613                            );
614                        }
615                        response.try_into()
616                    }
617                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
618                }
619            } else {
620                Err(CompletionError::ProviderError(
621                    String::from_utf8_lossy(&response_body).to_string(),
622                ))
623            }
624        }
625        .instrument(span)
626        .await
627    }
628
629    async fn stream(
630        &self,
631        completion_request: CompletionRequest,
632    ) -> Result<
633        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
634        CompletionError,
635    > {
636        let preamble = completion_request.preamble.clone();
637        let mut request =
638            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
639
640        let params = json_utils::merge(
641            request.additional_params.unwrap_or(serde_json::json!({})),
642            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
643        );
644
645        request.additional_params = Some(params);
646
647        if enabled!(Level::TRACE) {
648            tracing::trace!(target: "rig::completions",
649                "DeepSeek streaming completion request: {}",
650                serde_json::to_string_pretty(&request)?
651            );
652        }
653
654        let body = serde_json::to_vec(&request)?;
655
656        let req = self
657            .client
658            .post("/chat/completions")?
659            .body(body)
660            .map_err(|e| CompletionError::HttpError(e.into()))?;
661
662        let span = if tracing::Span::current().is_disabled() {
663            info_span!(
664                target: "rig::completions",
665                "chat_streaming",
666                gen_ai.operation.name = "chat_streaming",
667                gen_ai.provider.name = "deepseek",
668                gen_ai.request.model = self.model,
669                gen_ai.system_instructions = preamble,
670                gen_ai.response.id = tracing::field::Empty,
671                gen_ai.response.model = tracing::field::Empty,
672                gen_ai.usage.output_tokens = tracing::field::Empty,
673                gen_ai.usage.input_tokens = tracing::field::Empty,
674                gen_ai.usage.cached_tokens = tracing::field::Empty,
675            )
676        } else {
677            tracing::Span::current()
678        };
679
680        tracing::Instrument::instrument(
681            send_compatible_streaming_request(self.client.clone(), req),
682            span,
683        )
684        .await
685    }
686}
687
688#[derive(Deserialize, Debug)]
689pub struct StreamingDelta {
690    #[serde(default)]
691    content: Option<String>,
692    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
693    tool_calls: Vec<StreamingToolCall>,
694    reasoning_content: Option<String>,
695}
696
697#[derive(Deserialize, Debug)]
698struct StreamingChoice {
699    delta: StreamingDelta,
700}
701
702#[derive(Deserialize, Debug)]
703struct StreamingCompletionChunk {
704    choices: Vec<StreamingChoice>,
705    usage: Option<Usage>,
706}
707
708#[derive(Clone, Deserialize, Serialize, Debug)]
709pub struct StreamingCompletionResponse {
710    pub usage: Usage,
711}
712
713impl GetTokenUsage for StreamingCompletionResponse {
714    fn token_usage(&self) -> Option<crate::completion::Usage> {
715        let mut usage = crate::completion::Usage::new();
716        usage.input_tokens = self.usage.prompt_tokens as u64;
717        usage.output_tokens = self.usage.completion_tokens as u64;
718        usage.total_tokens = self.usage.total_tokens as u64;
719        usage.cached_input_tokens = self
720            .usage
721            .prompt_tokens_details
722            .as_ref()
723            .and_then(|d| d.cached_tokens)
724            .map(|c| c as u64)
725            .unwrap_or(0);
726
727        Some(usage)
728    }
729}
730
731pub async fn send_compatible_streaming_request<T>(
732    http_client: T,
733    req: Request<Vec<u8>>,
734) -> Result<
735    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
736    CompletionError,
737>
738where
739    T: HttpClientExt + Clone + 'static,
740{
741    let mut event_source = GenericEventSource::new(http_client, req);
742
743    let stream = stream! {
744        let mut final_usage = Usage::new();
745        let mut text_response = String::new();
746        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
747
748        while let Some(event_result) = event_source.next().await {
749            match event_result {
750                Ok(Event::Open) => {
751                    tracing::trace!("SSE connection opened");
752                    continue;
753                }
754                Ok(Event::Message(message)) => {
755                    if message.data.trim().is_empty() || message.data == "[DONE]" {
756                        continue;
757                    }
758
759                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
760                    let Ok(data) = parsed else {
761                        let err = parsed.unwrap_err();
762                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
763                        continue;
764                    };
765
766                    if let Some(choice) = data.choices.first() {
767                        let delta = &choice.delta;
768
769                        if !delta.tool_calls.is_empty() {
770                            for tool_call in &delta.tool_calls {
771                                let function = &tool_call.function;
772
773                                // Start of tool call
774                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
775                                    && empty_or_none(&function.arguments)
776                                {
777                                    let id = tool_call.id.clone().unwrap_or_default();
778                                    let name = function.name.clone().unwrap();
779                                    calls.insert(tool_call.index, (id, name, String::new()));
780                                }
781                                // Continuation of tool call
782                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
783                                    && let Some(arguments) = &function.arguments
784                                    && !arguments.is_empty()
785                                {
786                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
787                                        let combined = format!("{}{}", existing_args, arguments);
788                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
789                                    } else {
790                                        tracing::debug!("Partial tool call received but tool call was never started.");
791                                    }
792                                }
793                                // Complete tool call
794                                else {
795                                    let id = tool_call.id.clone().unwrap_or_default();
796                                    let name = function.name.clone().unwrap_or_default();
797                                    let arguments_str = function.arguments.clone().unwrap_or_default();
798
799                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
800                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
801                                        continue;
802                                    };
803
804                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
805                                        crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
806                                    ));
807                                }
808                            }
809                        }
810
811                        // DeepSeek-specific reasoning stream
812                        if let Some(content) = &delta.reasoning_content {
813                            yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
814                                id: None,
815                                reasoning: content.to_string()
816                            });
817                        }
818
819                        if let Some(content) = &delta.content {
820                            text_response += content;
821                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
822                        }
823                    }
824
825                    if let Some(usage) = data.usage {
826                        final_usage = usage.clone();
827                    }
828                }
829                Err(crate::http_client::Error::StreamEnded) => {
830                    break;
831                }
832                Err(err) => {
833                    tracing::error!(?err, "SSE error");
834                    yield Err(CompletionError::ResponseError(err.to_string()));
835                    break;
836                }
837            }
838        }
839
840        event_source.close();
841
842        let mut tool_calls = Vec::new();
843        // Flush accumulated tool calls
844        for (index, (id, name, arguments)) in calls {
845            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
846                continue;
847            };
848
849            tool_calls.push(ToolCall {
850                id: id.clone(),
851                index,
852                r#type: ToolType::Function,
853                function: Function {
854                    name: name.clone(),
855                    arguments: arguments_json.clone()
856                }
857            });
858            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
859                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
860            ));
861        }
862
863        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
864            StreamingCompletionResponse { usage: final_usage.clone() }
865        ));
866    };
867
868    Ok(crate::streaming::StreamingCompletionResponse::stream(
869        Box::pin(stream),
870    ))
871}
872
// ================================================================
// DeepSeek Completion API
// ================================================================

/// Model identifier string `"deepseek-chat"`.
///
/// Pass it to `Client::completion_model` to select this model, as in the
/// module-level example.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";

/// Model identifier string `"deepseek-reasoner"`.
///
/// Used the same way as [`DEEPSEEK_CHAT`] when building a completion model.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
878
// Tests
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON array containing a single choice deserializes into a
    /// one-element `Vec<Choice>` whose message is the assistant variant
    /// carrying the expected text.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    /// A minimal response body (only `choices` and `usage`) deserializes
    /// into `CompletionResponse`.
    ///
    /// `serde_path_to_error` is used so that a failure reports the JSON path
    /// of the offending field rather than just the serde error.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    /// A fuller response payload deserializes into `CompletionResponse`.
    ///
    /// Beyond `choices` and `usage`, the payload also carries `id`, `object`,
    /// `created`, `model`, `system_fingerprint` and
    /// `usage.prompt_tokens_details`. The message text contains non-ASCII
    /// characters and a JSON `\n` escape, which must come back as a real
    /// newline.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    /// A `tool_calls` choice deserializes into the hand-built expected
    /// `Choice`.
    ///
    /// The payload carries `arguments` as a JSON-encoded string; the expected
    /// value holds the parsed JSON, so equality only holds if the string is
    /// decoded during deserialization.
    ///
    /// NOTE(review): despite the name, only the deserialization direction is
    /// exercised here; nothing is serialized back.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: String::new(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
                reasoning_content: None,
            },
        };

        assert_eq!(choice, expected_choice);
    }

    /// Converting a rig user message that holds two text items yields exactly
    /// one `Message::User`, whose content is the two texts joined by a
    /// newline.
    #[test]
    fn test_user_message_multiple_text_items_merged() {
        use crate::completion::message::{Message as RigMessage, UserContent};

        let rig_msg = RigMessage::User {
            content: OneOrMany::many(vec![
                UserContent::text("first part"),
                UserContent::text("second part"),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        let user_messages: Vec<&Message> = messages
            .iter()
            .filter(|m| matches!(m, Message::User { .. }))
            .collect();

        assert_eq!(
            user_messages.len(),
            1,
            "multiple text items should produce a single user message"
        );
        match &user_messages[0] {
            Message::User { content, .. } => {
                assert_eq!(content, "first part\nsecond part");
            }
            // The filter above keeps only `Message::User` values.
            _ => unreachable!(),
        }
    }

    /// A rig assistant message with reasoning, text and a tool call converts
    /// to a single `Message::Assistant` that keeps all three parts: the text
    /// in `content`, the reasoning in `reasoning_content`, and the call in
    /// `tool_calls`.
    #[test]
    fn test_assistant_message_with_reasoning_and_tool_calls() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                AssistantContent::reasoning("thinking about the problem"),
                AssistantContent::text("I'll call the tool"),
                AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 5}),
                ),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        assert_eq!(messages.len(), 1, "should produce exactly one message");
        match &messages[0] {
            Message::Assistant {
                content,
                tool_calls,
                reasoning_content,
                ..
            } => {
                assert_eq!(content, "I'll call the tool");
                assert_eq!(
                    reasoning_content.as_deref(),
                    Some("thinking about the problem")
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].function.name, "subtract");
            }
            _ => panic!("Expected assistant message"),
        }
    }

    /// A rig assistant message with text and a tool call but no reasoning
    /// converts to a single `Message::Assistant` whose `reasoning_content` is
    /// `None` and which carries the one tool call.
    #[test]
    fn test_assistant_message_without_reasoning() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                AssistantContent::text("calling tool"),
                AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})),
            ])
            .expect("content should not be empty"),
        };

        let messages: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        assert_eq!(messages.len(), 1);
        match &messages[0] {
            Message::Assistant {
                reasoning_content,
                tool_calls,
                ..
            } => {
                assert!(reasoning_content.is_none());
                assert_eq!(tool_calls.len(), 1);
            }
            _ => panic!("Expected assistant message"),
        }
    }

    /// Both construction paths, `Client::new` and `Client::builder()`,
    /// succeed when given a placeholder API key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::deepseek::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::deepseek::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}