Skip to main content

rig/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::deepseek;
6//!
7//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
8//!
9//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
10//! ```
11
12use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
37// ================================================================
38// Main DeepSeek Client
39// ================================================================
/// Root URL of the DeepSeek HTTP API; used as the provider builder's `BASE_URL`.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
/// Zero-sized marker selecting DeepSeek as the provider of a generic `client::Client`.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExt;

/// Zero-sized builder marker that produces a [`DeepSeekExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExtBuilder;
46
47type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50    type Builder = DeepSeekExtBuilder;
51    const VERIFY_PATH: &'static str = "/user/balance";
52}
53
54impl<H> Capabilities<H> for DeepSeekExt {
55    type Completion = Capable<CompletionModel<H>>;
56    type Embeddings = Nothing;
57    type Transcription = Nothing;
58    type ModelListing = Nothing;
59    #[cfg(feature = "image")]
60    type ImageGeneration = Nothing;
61    #[cfg(feature = "audio")]
62    type AudioGeneration = Nothing;
63}
64
65impl DebugExt for DeepSeekExt {}
66
67impl ProviderBuilder for DeepSeekExtBuilder {
68    type Extension<H>
69        = DeepSeekExt
70    where
71        H: HttpClientExt;
72    type ApiKey = DeepSeekApiKey;
73
74    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
75
76    fn build<H>(
77        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
78    ) -> http_client::Result<Self::Extension<H>>
79    where
80        H: HttpClientExt,
81    {
82        Ok(DeepSeekExt)
83    }
84}
85
86pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
87pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
88
89impl ProviderClient for Client {
90    type Input = DeepSeekApiKey;
91
92    // If you prefer the environment variable approach:
93    fn from_env() -> Self {
94        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
95        let mut client_builder = Self::builder();
96        client_builder.headers_mut().insert(
97            http::header::CONTENT_TYPE,
98            http::HeaderValue::from_static("application/json"),
99        );
100        let client_builder = client_builder.api_key(&api_key);
101        client_builder.build().unwrap()
102    }
103
104    fn from_val(input: Self::Input) -> Self {
105        Self::new(input).unwrap()
106    }
107}
108
109#[derive(Debug, Deserialize)]
110struct ApiErrorResponse {
111    message: String,
112}
113
114#[derive(Debug, Deserialize)]
115#[serde(untagged)]
116enum ApiResponse<T> {
117    Ok(T),
118    Err(ApiErrorResponse),
119}
120
121impl From<ApiErrorResponse> for CompletionError {
122    fn from(err: ApiErrorResponse) -> Self {
123        CompletionError::ProviderError(err.message)
124    }
125}
126
127/// The response shape from the DeepSeek API
128#[derive(Clone, Debug, Serialize, Deserialize)]
129pub struct CompletionResponse {
130    // We'll match the JSON:
131    pub choices: Vec<Choice>,
132    pub usage: Usage,
133    // you may want other fields
134}
135
136#[derive(Clone, Debug, Serialize, Deserialize, Default)]
137pub struct Usage {
138    pub completion_tokens: u32,
139    pub prompt_tokens: u32,
140    pub prompt_cache_hit_tokens: u32,
141    pub prompt_cache_miss_tokens: u32,
142    pub total_tokens: u32,
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub completion_tokens_details: Option<CompletionTokensDetails>,
145    #[serde(skip_serializing_if = "Option::is_none")]
146    pub prompt_tokens_details: Option<PromptTokensDetails>,
147}
148
149impl Usage {
150    fn new() -> Self {
151        Self {
152            completion_tokens: 0,
153            prompt_tokens: 0,
154            prompt_cache_hit_tokens: 0,
155            prompt_cache_miss_tokens: 0,
156            total_tokens: 0,
157            completion_tokens_details: None,
158            prompt_tokens_details: None,
159        }
160    }
161}
162
163#[derive(Clone, Debug, Serialize, Deserialize, Default)]
164pub struct CompletionTokensDetails {
165    #[serde(skip_serializing_if = "Option::is_none")]
166    pub reasoning_tokens: Option<u32>,
167}
168
169#[derive(Clone, Debug, Serialize, Deserialize, Default)]
170pub struct PromptTokensDetails {
171    #[serde(skip_serializing_if = "Option::is_none")]
172    pub cached_tokens: Option<u32>,
173}
174
175#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
176pub struct Choice {
177    pub index: usize,
178    pub message: Message,
179    pub logprobs: Option<serde_json::Value>,
180    pub finish_reason: String,
181}
182
183#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
184#[serde(tag = "role", rename_all = "lowercase")]
185pub enum Message {
186    System {
187        content: String,
188        #[serde(skip_serializing_if = "Option::is_none")]
189        name: Option<String>,
190    },
191    User {
192        content: String,
193        #[serde(skip_serializing_if = "Option::is_none")]
194        name: Option<String>,
195    },
196    Assistant {
197        content: String,
198        #[serde(skip_serializing_if = "Option::is_none")]
199        name: Option<String>,
200        #[serde(
201            default,
202            deserialize_with = "json_utils::null_or_vec",
203            skip_serializing_if = "Vec::is_empty"
204        )]
205        tool_calls: Vec<ToolCall>,
206        /// only exists on `deepseek-reasoner` model at time of addition
207        #[serde(skip_serializing_if = "Option::is_none")]
208        reasoning_content: Option<String>,
209    },
210    #[serde(rename = "tool")]
211    ToolResult {
212        tool_call_id: String,
213        content: String,
214    },
215}
216
217impl Message {
218    pub fn system(content: &str) -> Self {
219        Message::System {
220            content: content.to_owned(),
221            name: None,
222        }
223    }
224}
225
226impl From<message::ToolResult> for Message {
227    fn from(tool_result: message::ToolResult) -> Self {
228        let content = match tool_result.content.first() {
229            message::ToolResultContent::Text(text) => text.text,
230            message::ToolResultContent::Image(_) => String::from("[Image]"),
231        };
232
233        Message::ToolResult {
234            tool_call_id: tool_result.id,
235            content,
236        }
237    }
238}
239
240impl From<message::ToolCall> for ToolCall {
241    fn from(tool_call: message::ToolCall) -> Self {
242        Self {
243            id: tool_call.id,
244            // TODO: update index when we have it
245            index: 0,
246            r#type: ToolType::Function,
247            function: Function {
248                name: tool_call.function.name,
249                arguments: tool_call.function.arguments,
250            },
251        }
252    }
253}
254
255impl TryFrom<message::Message> for Vec<Message> {
256    type Error = message::MessageError;
257
258    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
259        match message {
260            message::Message::User { content } => {
261                // extract tool results
262                let mut messages = vec![];
263
264                let tool_results = content
265                    .clone()
266                    .into_iter()
267                    .filter_map(|content| match content {
268                        message::UserContent::ToolResult(tool_result) => {
269                            Some(Message::from(tool_result))
270                        }
271                        _ => None,
272                    })
273                    .collect::<Vec<_>>();
274
275                messages.extend(tool_results);
276
277                let text_content: String = content
278                    .into_iter()
279                    .filter_map(|content| match content {
280                        message::UserContent::Text(text) => Some(text.text),
281                        message::UserContent::Document(Document {
282                            data:
283                                DocumentSourceKind::Base64(content)
284                                | DocumentSourceKind::String(content),
285                            ..
286                        }) => Some(content),
287                        _ => None,
288                    })
289                    .collect::<Vec<_>>()
290                    .join("\n");
291
292                if !text_content.is_empty() {
293                    messages.push(Message::User {
294                        content: text_content,
295                        name: None,
296                    });
297                }
298
299                Ok(messages)
300            }
301            message::Message::Assistant { content, .. } => {
302                let mut text_content = String::new();
303                let mut reasoning_content = String::new();
304                let mut tool_calls = Vec::new();
305
306                for item in content.iter() {
307                    match item {
308                        message::AssistantContent::Text(text) => {
309                            text_content.push_str(text.text());
310                        }
311                        message::AssistantContent::Reasoning(reasoning) => {
312                            reasoning_content.push_str(&reasoning.display_text());
313                        }
314                        message::AssistantContent::ToolCall(tool_call) => {
315                            tool_calls.push(ToolCall::from(tool_call.clone()));
316                        }
317                        _ => {}
318                    }
319                }
320
321                let reasoning = if reasoning_content.is_empty() {
322                    None
323                } else {
324                    Some(reasoning_content)
325                };
326
327                Ok(vec![Message::Assistant {
328                    content: text_content,
329                    name: None,
330                    tool_calls,
331                    reasoning_content: reasoning,
332                }])
333            }
334        }
335    }
336}
337
338#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
339pub struct ToolCall {
340    pub id: String,
341    pub index: usize,
342    #[serde(default)]
343    pub r#type: ToolType,
344    pub function: Function,
345}
346
347#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
348pub struct Function {
349    pub name: String,
350    #[serde(with = "json_utils::stringified_json")]
351    pub arguments: serde_json::Value,
352}
353
354#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
355#[serde(rename_all = "lowercase")]
356pub enum ToolType {
357    #[default]
358    Function,
359}
360
361#[derive(Clone, Debug, Deserialize, Serialize)]
362pub struct ToolDefinition {
363    pub r#type: String,
364    pub function: completion::ToolDefinition,
365}
366
367impl From<crate::completion::ToolDefinition> for ToolDefinition {
368    fn from(tool: crate::completion::ToolDefinition) -> Self {
369        Self {
370            r#type: "function".into(),
371            function: tool,
372        }
373    }
374}
375
376impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
377    type Error = CompletionError;
378
379    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
380        let choice = response.choices.first().ok_or_else(|| {
381            CompletionError::ResponseError("Response contained no choices".to_owned())
382        })?;
383        let content = match &choice.message {
384            Message::Assistant {
385                content,
386                tool_calls,
387                reasoning_content,
388                ..
389            } => {
390                let mut content = if content.trim().is_empty() {
391                    vec![]
392                } else {
393                    vec![completion::AssistantContent::text(content)]
394                };
395
396                content.extend(
397                    tool_calls
398                        .iter()
399                        .map(|call| {
400                            completion::AssistantContent::tool_call(
401                                &call.id,
402                                &call.function.name,
403                                call.function.arguments.clone(),
404                            )
405                        })
406                        .collect::<Vec<_>>(),
407                );
408
409                if let Some(reasoning_content) = reasoning_content {
410                    content.push(completion::AssistantContent::reasoning(reasoning_content));
411                }
412
413                Ok(content)
414            }
415            _ => Err(CompletionError::ResponseError(
416                "Response did not contain a valid message or tool call".into(),
417            )),
418        }?;
419
420        let choice = OneOrMany::many(content).map_err(|_| {
421            CompletionError::ResponseError(
422                "Response contained no message or tool call (empty)".to_owned(),
423            )
424        })?;
425
426        let usage = completion::Usage {
427            input_tokens: response.usage.prompt_tokens as u64,
428            output_tokens: response.usage.completion_tokens as u64,
429            total_tokens: response.usage.total_tokens as u64,
430            cached_input_tokens: response
431                .usage
432                .prompt_tokens_details
433                .as_ref()
434                .and_then(|d| d.cached_tokens)
435                .map(|c| c as u64)
436                .unwrap_or(0),
437        };
438
439        Ok(completion::CompletionResponse {
440            choice,
441            usage,
442            raw_response: response,
443            message_id: None,
444        })
445    }
446}
447
448#[derive(Debug, Serialize, Deserialize)]
449pub(super) struct DeepseekCompletionRequest {
450    model: String,
451    pub messages: Vec<Message>,
452    #[serde(skip_serializing_if = "Option::is_none")]
453    temperature: Option<f64>,
454    #[serde(skip_serializing_if = "Vec::is_empty")]
455    tools: Vec<ToolDefinition>,
456    #[serde(skip_serializing_if = "Option::is_none")]
457    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
458    #[serde(flatten, skip_serializing_if = "Option::is_none")]
459    pub additional_params: Option<serde_json::Value>,
460}
461
462impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
463    type Error = CompletionError;
464
465    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
466        if req.output_schema.is_some() {
467            tracing::warn!("Structured outputs currently not supported for DeepSeek");
468        }
469        let model = req.model.clone().unwrap_or_else(|| model.to_string());
470        let mut full_history: Vec<Message> = match &req.preamble {
471            Some(preamble) => vec![Message::system(preamble)],
472            None => vec![],
473        };
474
475        if let Some(docs) = req.normalized_documents() {
476            let docs: Vec<Message> = docs.try_into()?;
477            full_history.extend(docs);
478        }
479
480        let chat_history: Vec<Message> = req
481            .chat_history
482            .clone()
483            .into_iter()
484            .map(|message| message.try_into())
485            .collect::<Result<Vec<Vec<Message>>, _>>()?
486            .into_iter()
487            .flatten()
488            .collect();
489
490        full_history.extend(chat_history);
491
492        let tool_choice = req
493            .tool_choice
494            .clone()
495            .map(crate::providers::openrouter::ToolChoice::try_from)
496            .transpose()?;
497
498        Ok(Self {
499            model: model.to_string(),
500            messages: full_history,
501            temperature: req.temperature,
502            tools: req
503                .tools
504                .clone()
505                .into_iter()
506                .map(ToolDefinition::from)
507                .collect::<Vec<_>>(),
508            tool_choice,
509            additional_params: req.additional_params,
510        })
511    }
512}
513
514/// The struct implementing the `CompletionModel` trait
515#[derive(Clone)]
516pub struct CompletionModel<T = reqwest::Client> {
517    pub client: Client<T>,
518    pub model: String,
519}
520
521impl<T> completion::CompletionModel for CompletionModel<T>
522where
523    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
524{
525    type Response = CompletionResponse;
526    type StreamingResponse = StreamingCompletionResponse;
527
528    type Client = Client<T>;
529
530    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
531        Self {
532            client: client.clone(),
533            model: model.into().to_string(),
534        }
535    }
536
537    async fn completion(
538        &self,
539        completion_request: CompletionRequest,
540    ) -> Result<
541        completion::CompletionResponse<CompletionResponse>,
542        crate::completion::CompletionError,
543    > {
544        let span = if tracing::Span::current().is_disabled() {
545            info_span!(
546                target: "rig::completions",
547                "chat",
548                gen_ai.operation.name = "chat",
549                gen_ai.provider.name = "deepseek",
550                gen_ai.request.model = self.model,
551                gen_ai.system_instructions = tracing::field::Empty,
552                gen_ai.response.id = tracing::field::Empty,
553                gen_ai.response.model = tracing::field::Empty,
554                gen_ai.usage.output_tokens = tracing::field::Empty,
555                gen_ai.usage.input_tokens = tracing::field::Empty,
556            )
557        } else {
558            tracing::Span::current()
559        };
560
561        span.record("gen_ai.system_instructions", &completion_request.preamble);
562
563        let request =
564            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
565
566        if enabled!(Level::TRACE) {
567            tracing::trace!(target: "rig::completions",
568                "DeepSeek completion request: {}",
569                serde_json::to_string_pretty(&request)?
570            );
571        }
572
573        let body = serde_json::to_vec(&request)?;
574        let req = self
575            .client
576            .post("/chat/completions")?
577            .body(body)
578            .map_err(|e| CompletionError::HttpError(e.into()))?;
579
580        async move {
581            let response = self.client.send::<_, Bytes>(req).await?;
582            let status = response.status();
583            let response_body = response.into_body().into_future().await?.to_vec();
584
585            if status.is_success() {
586                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
587                    ApiResponse::Ok(response) => {
588                        let span = tracing::Span::current();
589                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
590                        span.record(
591                            "gen_ai.usage.output_tokens",
592                            response.usage.completion_tokens,
593                        );
594                        if enabled!(Level::TRACE) {
595                            tracing::trace!(target: "rig::completions",
596                                "DeepSeek completion response: {}",
597                                serde_json::to_string_pretty(&response)?
598                            );
599                        }
600                        response.try_into()
601                    }
602                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
603                }
604            } else {
605                Err(CompletionError::ProviderError(
606                    String::from_utf8_lossy(&response_body).to_string(),
607                ))
608            }
609        }
610        .instrument(span)
611        .await
612    }
613
614    async fn stream(
615        &self,
616        completion_request: CompletionRequest,
617    ) -> Result<
618        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
619        CompletionError,
620    > {
621        let preamble = completion_request.preamble.clone();
622        let mut request =
623            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
624
625        let params = json_utils::merge(
626            request.additional_params.unwrap_or(serde_json::json!({})),
627            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
628        );
629
630        request.additional_params = Some(params);
631
632        if enabled!(Level::TRACE) {
633            tracing::trace!(target: "rig::completions",
634                "DeepSeek streaming completion request: {}",
635                serde_json::to_string_pretty(&request)?
636            );
637        }
638
639        let body = serde_json::to_vec(&request)?;
640
641        let req = self
642            .client
643            .post("/chat/completions")?
644            .body(body)
645            .map_err(|e| CompletionError::HttpError(e.into()))?;
646
647        let span = if tracing::Span::current().is_disabled() {
648            info_span!(
649                target: "rig::completions",
650                "chat_streaming",
651                gen_ai.operation.name = "chat_streaming",
652                gen_ai.provider.name = "deepseek",
653                gen_ai.request.model = self.model,
654                gen_ai.system_instructions = preamble,
655                gen_ai.response.id = tracing::field::Empty,
656                gen_ai.response.model = tracing::field::Empty,
657                gen_ai.usage.output_tokens = tracing::field::Empty,
658                gen_ai.usage.input_tokens = tracing::field::Empty,
659            )
660        } else {
661            tracing::Span::current()
662        };
663
664        tracing::Instrument::instrument(
665            send_compatible_streaming_request(self.client.clone(), req),
666            span,
667        )
668        .await
669    }
670}
671
672#[derive(Deserialize, Debug)]
673pub struct StreamingDelta {
674    #[serde(default)]
675    content: Option<String>,
676    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
677    tool_calls: Vec<StreamingToolCall>,
678    reasoning_content: Option<String>,
679}
680
681#[derive(Deserialize, Debug)]
682struct StreamingChoice {
683    delta: StreamingDelta,
684}
685
686#[derive(Deserialize, Debug)]
687struct StreamingCompletionChunk {
688    choices: Vec<StreamingChoice>,
689    usage: Option<Usage>,
690}
691
692#[derive(Clone, Deserialize, Serialize, Debug)]
693pub struct StreamingCompletionResponse {
694    pub usage: Usage,
695}
696
697impl GetTokenUsage for StreamingCompletionResponse {
698    fn token_usage(&self) -> Option<crate::completion::Usage> {
699        let mut usage = crate::completion::Usage::new();
700        usage.input_tokens = self.usage.prompt_tokens as u64;
701        usage.output_tokens = self.usage.completion_tokens as u64;
702        usage.total_tokens = self.usage.total_tokens as u64;
703        usage.cached_input_tokens = self
704            .usage
705            .prompt_tokens_details
706            .as_ref()
707            .and_then(|d| d.cached_tokens)
708            .map(|c| c as u64)
709            .unwrap_or(0);
710
711        Some(usage)
712    }
713}
714
715pub async fn send_compatible_streaming_request<T>(
716    http_client: T,
717    req: Request<Vec<u8>>,
718) -> Result<
719    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
720    CompletionError,
721>
722where
723    T: HttpClientExt + Clone + 'static,
724{
725    let mut event_source = GenericEventSource::new(http_client, req);
726
727    let stream = stream! {
728        let mut final_usage = Usage::new();
729        let mut text_response = String::new();
730        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
731
732        while let Some(event_result) = event_source.next().await {
733            match event_result {
734                Ok(Event::Open) => {
735                    tracing::trace!("SSE connection opened");
736                    continue;
737                }
738                Ok(Event::Message(message)) => {
739                    if message.data.trim().is_empty() || message.data == "[DONE]" {
740                        continue;
741                    }
742
743                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
744                    let Ok(data) = parsed else {
745                        let err = parsed.unwrap_err();
746                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
747                        continue;
748                    };
749
750                    if let Some(choice) = data.choices.first() {
751                        let delta = &choice.delta;
752
753                        if !delta.tool_calls.is_empty() {
754                            for tool_call in &delta.tool_calls {
755                                let function = &tool_call.function;
756
757                                // Start of tool call
758                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
759                                    && empty_or_none(&function.arguments)
760                                {
761                                    let id = tool_call.id.clone().unwrap_or_default();
762                                    let name = function.name.clone().unwrap();
763                                    calls.insert(tool_call.index, (id, name, String::new()));
764                                }
765                                // Continuation of tool call
766                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
767                                    && let Some(arguments) = &function.arguments
768                                    && !arguments.is_empty()
769                                {
770                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
771                                        let combined = format!("{}{}", existing_args, arguments);
772                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
773                                    } else {
774                                        tracing::debug!("Partial tool call received but tool call was never started.");
775                                    }
776                                }
777                                // Complete tool call
778                                else {
779                                    let id = tool_call.id.clone().unwrap_or_default();
780                                    let name = function.name.clone().unwrap_or_default();
781                                    let arguments_str = function.arguments.clone().unwrap_or_default();
782
783                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
784                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
785                                        continue;
786                                    };
787
788                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
789                                        crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
790                                    ));
791                                }
792                            }
793                        }
794
795                        // DeepSeek-specific reasoning stream
796                        if let Some(content) = &delta.reasoning_content {
797                            yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
798                                id: None,
799                                reasoning: content.to_string()
800                            });
801                        }
802
803                        if let Some(content) = &delta.content {
804                            text_response += content;
805                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
806                        }
807                    }
808
809                    if let Some(usage) = data.usage {
810                        final_usage = usage.clone();
811                    }
812                }
813                Err(crate::http_client::Error::StreamEnded) => {
814                    break;
815                }
816                Err(err) => {
817                    tracing::error!(?err, "SSE error");
818                    yield Err(CompletionError::ResponseError(err.to_string()));
819                    break;
820                }
821            }
822        }
823
824        event_source.close();
825
826        let mut tool_calls = Vec::new();
827        // Flush accumulated tool calls
828        for (index, (id, name, arguments)) in calls {
829            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
830                continue;
831            };
832
833            tool_calls.push(ToolCall {
834                id: id.clone(),
835                index,
836                r#type: ToolType::Function,
837                function: Function {
838                    name: name.clone(),
839                    arguments: arguments_json.clone()
840                }
841            });
842            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
843                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
844            ));
845        }
846
847        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
848            StreamingCompletionResponse { usage: final_usage.clone() }
849        ));
850    };
851
852    Ok(crate::streaming::StreamingCompletionResponse::stream(
853        Box::pin(stream),
854    ))
855}
856
857// ================================================================
858// DeepSeek Completion API
859// ================================================================
/// Model name of DeepSeek's general-purpose chat model.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name of DeepSeek's reasoning model (its responses may carry `reasoning_content`).
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
862
863// Tests
#[cfg(test)]
mod tests {
    // Unit tests for DeepSeek response deserialization, rig -> DeepSeek message conversion,
    // and client construction.
    use super::*;

    /// Deserializes `json` into a [`CompletionResponse`] via `serde_path_to_error`, panicking
    /// with the JSON path of the first failing field if deserialization fails.
    fn parse_completion_response(json: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(json);
        serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    /// Returns the text content of `message`, panicking if it is not an assistant message.
    fn assistant_text(message: &Message) -> &str {
        let Message::Assistant { content, .. } = message else {
            panic!("Expected assistant message");
        };
        content
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let json = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let parsed: Vec<Choice> = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.len(), 1);

        let first = parsed.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let json = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_completion_response(json);
        let first = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_example_response() {
        let json = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_completion_response(json);
        let first = response.choices.first().unwrap();
        assert_eq!(
            assistant_text(&first.message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        // The stringified `arguments` in the payload should come back as structured JSON.
        let subtract_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::json!({"x": 2, "y": 5}),
            },
        };
        let expected_choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: String::new(),
                name: None,
                tool_calls: vec![subtract_call],
                reasoning_content: None,
            },
        };

        let parsed: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
        assert_eq!(parsed, expected_choice);
    }

    #[test]
    fn test_user_message_multiple_text_items_merged() {
        use crate::completion::message::{Message as RigMessage, UserContent};

        let parts = vec![
            UserContent::text("first part"),
            UserContent::text("second part"),
        ];
        let rig_msg = RigMessage::User {
            content: OneOrMany::many(parts).expect("content should not be empty"),
        };

        let converted: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");

        let user_messages: Vec<&Message> = converted
            .iter()
            .filter(|msg| matches!(msg, Message::User { .. }))
            .collect();
        assert_eq!(
            user_messages.len(),
            1,
            "multiple text items should produce a single user message"
        );

        // The filter above guarantees the remaining message is a user message.
        let Message::User { content, .. } = user_messages[0] else {
            unreachable!()
        };
        assert_eq!(content, "first part\nsecond part");
    }

    #[test]
    fn test_assistant_message_with_reasoning_and_tool_calls() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let parts = vec![
            AssistantContent::reasoning("thinking about the problem"),
            AssistantContent::text("I'll call the tool"),
            AssistantContent::tool_call(
                "call_1",
                "subtract",
                serde_json::json!({"x": 2, "y": 5}),
            ),
        ];
        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(parts).expect("content should not be empty"),
        };

        let converted: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
        assert_eq!(converted.len(), 1, "should produce exactly one message");

        let Message::Assistant {
            content,
            tool_calls,
            reasoning_content,
            ..
        } = &converted[0]
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(content, "I'll call the tool");
        assert_eq!(
            reasoning_content.as_deref(),
            Some("thinking about the problem")
        );
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "subtract");
    }

    #[test]
    fn test_assistant_message_without_reasoning() {
        use crate::completion::message::{AssistantContent, Message as RigMessage};

        let parts = vec![
            AssistantContent::text("calling tool"),
            AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})),
        ];
        let rig_msg = RigMessage::Assistant {
            id: None,
            content: OneOrMany::many(parts).expect("content should not be empty"),
        };

        let converted: Vec<Message> = rig_msg.try_into().expect("conversion should succeed");
        assert_eq!(converted.len(), 1);

        let Message::Assistant {
            reasoning_content,
            tool_calls,
            ..
        } = &converted[0]
        else {
            panic!("Expected assistant message");
        };
        assert!(reasoning_content.is_none());
        assert_eq!(tool_calls.len(), 1);
    }

    #[test]
    fn test_client_initialization() {
        use crate::providers::deepseek::Client;

        let _client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}