rig/providers/
deepseek.rs

//! DeepSeek API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::deepseek;
//!
//! let client = deepseek::Client::new("DEEPSEEK_API_KEY").expect("failed to build DeepSeek client");
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```
11
12use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
37// ================================================================
38// Main DeepSeek Client
39// ================================================================
40const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
42#[derive(Debug, Default, Clone, Copy)]
43pub struct DeepSeekExt;
44#[derive(Debug, Default, Clone, Copy)]
45pub struct DeepSeekExtBuilder;
46
47type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50    type Builder = DeepSeekExtBuilder;
51
52    const VERIFY_PATH: &'static str = "/user/balance";
53
54    fn build<H>(
55        _: &crate::client::ClientBuilder<
56            Self::Builder,
57            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58            H,
59        >,
60    ) -> http_client::Result<Self> {
61        Ok(Self)
62    }
63}
64
65impl<H> Capabilities<H> for DeepSeekExt {
66    type Completion = Capable<CompletionModel<H>>;
67    type Embeddings = Nothing;
68    type Transcription = Nothing;
69    #[cfg(feature = "image")]
70    type ImageGeneration = Nothing;
71    #[cfg(feature = "audio")]
72    type AudioGeneration = Nothing;
73}
74
75impl DebugExt for DeepSeekExt {}
76
77impl ProviderBuilder for DeepSeekExtBuilder {
78    type Output = DeepSeekExt;
79    type ApiKey = DeepSeekApiKey;
80
81    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
82}
83
84pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
85pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88    type Input = DeepSeekApiKey;
89
90    // If you prefer the environment variable approach:
91    fn from_env() -> Self {
92        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93        let mut client_builder = Self::builder();
94        client_builder.headers_mut().insert(
95            http::header::CONTENT_TYPE,
96            http::HeaderValue::from_static("application/json"),
97        );
98        let client_builder = client_builder.api_key(&api_key);
99        client_builder.build().unwrap()
100    }
101
102    fn from_val(input: Self::Input) -> Self {
103        Self::new(input).unwrap()
104    }
105}
106
107#[derive(Debug, Deserialize)]
108struct ApiErrorResponse {
109    message: String,
110}
111
112#[derive(Debug, Deserialize)]
113#[serde(untagged)]
114enum ApiResponse<T> {
115    Ok(T),
116    Err(ApiErrorResponse),
117}
118
119impl From<ApiErrorResponse> for CompletionError {
120    fn from(err: ApiErrorResponse) -> Self {
121        CompletionError::ProviderError(err.message)
122    }
123}
124
125/// The response shape from the DeepSeek API
126#[derive(Clone, Debug, Serialize, Deserialize)]
127pub struct CompletionResponse {
128    // We'll match the JSON:
129    pub choices: Vec<Choice>,
130    pub usage: Usage,
131    // you may want other fields
132}
133
134#[derive(Clone, Debug, Serialize, Deserialize, Default)]
135pub struct Usage {
136    pub completion_tokens: u32,
137    pub prompt_tokens: u32,
138    pub prompt_cache_hit_tokens: u32,
139    pub prompt_cache_miss_tokens: u32,
140    pub total_tokens: u32,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub completion_tokens_details: Option<CompletionTokensDetails>,
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub prompt_tokens_details: Option<PromptTokensDetails>,
145}
146
147impl Usage {
148    fn new() -> Self {
149        Self {
150            completion_tokens: 0,
151            prompt_tokens: 0,
152            prompt_cache_hit_tokens: 0,
153            prompt_cache_miss_tokens: 0,
154            total_tokens: 0,
155            completion_tokens_details: None,
156            prompt_tokens_details: None,
157        }
158    }
159}
160
161#[derive(Clone, Debug, Serialize, Deserialize, Default)]
162pub struct CompletionTokensDetails {
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub reasoning_tokens: Option<u32>,
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize, Default)]
168pub struct PromptTokensDetails {
169    #[serde(skip_serializing_if = "Option::is_none")]
170    pub cached_tokens: Option<u32>,
171}
172
173#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
174pub struct Choice {
175    pub index: usize,
176    pub message: Message,
177    pub logprobs: Option<serde_json::Value>,
178    pub finish_reason: String,
179}
180
181#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
182#[serde(tag = "role", rename_all = "lowercase")]
183pub enum Message {
184    System {
185        content: String,
186        #[serde(skip_serializing_if = "Option::is_none")]
187        name: Option<String>,
188    },
189    User {
190        content: String,
191        #[serde(skip_serializing_if = "Option::is_none")]
192        name: Option<String>,
193    },
194    Assistant {
195        content: String,
196        #[serde(skip_serializing_if = "Option::is_none")]
197        name: Option<String>,
198        #[serde(
199            default,
200            deserialize_with = "json_utils::null_or_vec",
201            skip_serializing_if = "Vec::is_empty"
202        )]
203        tool_calls: Vec<ToolCall>,
204        /// only exists on `deepseek-reasoner` model at time of addition
205        #[serde(skip_serializing_if = "Option::is_none")]
206        reasoning_content: Option<String>,
207    },
208    #[serde(rename = "tool")]
209    ToolResult {
210        tool_call_id: String,
211        content: String,
212    },
213}
214
215impl Message {
216    pub fn system(content: &str) -> Self {
217        Message::System {
218            content: content.to_owned(),
219            name: None,
220        }
221    }
222}
223
224impl From<message::ToolResult> for Message {
225    fn from(tool_result: message::ToolResult) -> Self {
226        let content = match tool_result.content.first() {
227            message::ToolResultContent::Text(text) => text.text,
228            message::ToolResultContent::Image(_) => String::from("[Image]"),
229        };
230
231        Message::ToolResult {
232            tool_call_id: tool_result.id,
233            content,
234        }
235    }
236}
237
238impl From<message::ToolCall> for ToolCall {
239    fn from(tool_call: message::ToolCall) -> Self {
240        Self {
241            id: tool_call.id,
242            // TODO: update index when we have it
243            index: 0,
244            r#type: ToolType::Function,
245            function: Function {
246                name: tool_call.function.name,
247                arguments: tool_call.function.arguments,
248            },
249        }
250    }
251}
252
253impl TryFrom<message::Message> for Vec<Message> {
254    type Error = message::MessageError;
255
256    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
257        match message {
258            message::Message::User { content } => {
259                // extract tool results
260                let mut messages = vec![];
261
262                let tool_results = content
263                    .clone()
264                    .into_iter()
265                    .filter_map(|content| match content {
266                        message::UserContent::ToolResult(tool_result) => {
267                            Some(Message::from(tool_result))
268                        }
269                        _ => None,
270                    })
271                    .collect::<Vec<_>>();
272
273                messages.extend(tool_results);
274
275                // extract text results
276                let text_messages = content
277                    .into_iter()
278                    .filter_map(|content| match content {
279                        message::UserContent::Text(text) => Some(Message::User {
280                            content: text.text,
281                            name: None,
282                        }),
283                        message::UserContent::Document(Document {
284                            data:
285                                DocumentSourceKind::Base64(content)
286                                | DocumentSourceKind::String(content),
287                            ..
288                        }) => Some(Message::User {
289                            content,
290                            name: None,
291                        }),
292                        _ => None,
293                    })
294                    .collect::<Vec<_>>();
295                messages.extend(text_messages);
296
297                Ok(messages)
298            }
299            message::Message::Assistant { content, .. } => {
300                let mut messages: Vec<Message> = vec![];
301                let mut text_content = String::new();
302                let mut reasoning_content = String::new();
303
304                content.iter().for_each(|content| match content {
305                    message::AssistantContent::Text(text) => {
306                        text_content.push_str(text.text());
307                    }
308                    message::AssistantContent::Reasoning(reasoning) => {
309                        reasoning_content.push_str(&reasoning.reasoning.join("\n"));
310                    }
311                    _ => {}
312                });
313
314                messages.push(Message::Assistant {
315                    content: text_content,
316                    name: None,
317                    tool_calls: vec![],
318                    reasoning_content: if reasoning_content.is_empty() {
319                        None
320                    } else {
321                        Some(reasoning_content)
322                    },
323                });
324
325                // extract tool calls
326                let tool_calls = content
327                    .clone()
328                    .into_iter()
329                    .filter_map(|content| match content {
330                        message::AssistantContent::ToolCall(tool_call) => {
331                            Some(ToolCall::from(tool_call))
332                        }
333                        _ => None,
334                    })
335                    .collect::<Vec<_>>();
336
337                // if we have tool calls, we add a new Assistant message with them
338                if !tool_calls.is_empty() {
339                    messages.push(Message::Assistant {
340                        content: "".to_string(),
341                        name: None,
342                        tool_calls,
343                        reasoning_content: None,
344                    });
345                }
346
347                Ok(messages)
348            }
349        }
350    }
351}
352
353#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
354pub struct ToolCall {
355    pub id: String,
356    pub index: usize,
357    #[serde(default)]
358    pub r#type: ToolType,
359    pub function: Function,
360}
361
362#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
363pub struct Function {
364    pub name: String,
365    #[serde(with = "json_utils::stringified_json")]
366    pub arguments: serde_json::Value,
367}
368
369#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
370#[serde(rename_all = "lowercase")]
371pub enum ToolType {
372    #[default]
373    Function,
374}
375
376#[derive(Clone, Debug, Deserialize, Serialize)]
377pub struct ToolDefinition {
378    pub r#type: String,
379    pub function: completion::ToolDefinition,
380}
381
382impl From<crate::completion::ToolDefinition> for ToolDefinition {
383    fn from(tool: crate::completion::ToolDefinition) -> Self {
384        Self {
385            r#type: "function".into(),
386            function: tool,
387        }
388    }
389}
390
391impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
392    type Error = CompletionError;
393
394    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
395        let choice = response.choices.first().ok_or_else(|| {
396            CompletionError::ResponseError("Response contained no choices".to_owned())
397        })?;
398        let content = match &choice.message {
399            Message::Assistant {
400                content,
401                tool_calls,
402                ..
403            } => {
404                let mut content = if content.trim().is_empty() {
405                    vec![]
406                } else {
407                    vec![completion::AssistantContent::text(content)]
408                };
409
410                content.extend(
411                    tool_calls
412                        .iter()
413                        .map(|call| {
414                            completion::AssistantContent::tool_call(
415                                &call.id,
416                                &call.function.name,
417                                call.function.arguments.clone(),
418                            )
419                        })
420                        .collect::<Vec<_>>(),
421                );
422                Ok(content)
423            }
424            _ => Err(CompletionError::ResponseError(
425                "Response did not contain a valid message or tool call".into(),
426            )),
427        }?;
428
429        let choice = OneOrMany::many(content).map_err(|_| {
430            CompletionError::ResponseError(
431                "Response contained no message or tool call (empty)".to_owned(),
432            )
433        })?;
434
435        let usage = completion::Usage {
436            input_tokens: response.usage.prompt_tokens as u64,
437            output_tokens: response.usage.completion_tokens as u64,
438            total_tokens: response.usage.total_tokens as u64,
439        };
440
441        Ok(completion::CompletionResponse {
442            choice,
443            usage,
444            raw_response: response,
445        })
446    }
447}
448
449#[derive(Debug, Serialize, Deserialize)]
450pub(super) struct DeepseekCompletionRequest {
451    model: String,
452    pub messages: Vec<Message>,
453    #[serde(skip_serializing_if = "Option::is_none")]
454    temperature: Option<f64>,
455    #[serde(skip_serializing_if = "Vec::is_empty")]
456    tools: Vec<ToolDefinition>,
457    #[serde(skip_serializing_if = "Option::is_none")]
458    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
459    #[serde(flatten, skip_serializing_if = "Option::is_none")]
460    pub additional_params: Option<serde_json::Value>,
461}
462
463impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
464    type Error = CompletionError;
465
466    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
467        let mut full_history: Vec<Message> = match &req.preamble {
468            Some(preamble) => vec![Message::system(preamble)],
469            None => vec![],
470        };
471
472        if let Some(docs) = req.normalized_documents() {
473            let docs: Vec<Message> = docs.try_into()?;
474            full_history.extend(docs);
475        }
476
477        let chat_history: Vec<Message> = req
478            .chat_history
479            .clone()
480            .into_iter()
481            .map(|message| message.try_into())
482            .collect::<Result<Vec<Vec<Message>>, _>>()?
483            .into_iter()
484            .flatten()
485            .collect();
486
487        full_history.extend(chat_history);
488
489        let tool_choice = req
490            .tool_choice
491            .clone()
492            .map(crate::providers::openrouter::ToolChoice::try_from)
493            .transpose()?;
494
495        Ok(Self {
496            model: model.to_string(),
497            messages: full_history,
498            temperature: req.temperature,
499            tools: req
500                .tools
501                .clone()
502                .into_iter()
503                .map(ToolDefinition::from)
504                .collect::<Vec<_>>(),
505            tool_choice,
506            additional_params: req.additional_params,
507        })
508    }
509}
510
511/// The struct implementing the `CompletionModel` trait
512#[derive(Clone)]
513pub struct CompletionModel<T = reqwest::Client> {
514    pub client: Client<T>,
515    pub model: String,
516}
517
518impl<T> completion::CompletionModel for CompletionModel<T>
519where
520    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
521{
522    type Response = CompletionResponse;
523    type StreamingResponse = StreamingCompletionResponse;
524
525    type Client = Client<T>;
526
527    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
528        Self {
529            client: client.clone(),
530            model: model.into().to_string(),
531        }
532    }
533
534    async fn completion(
535        &self,
536        completion_request: CompletionRequest,
537    ) -> Result<
538        completion::CompletionResponse<CompletionResponse>,
539        crate::completion::CompletionError,
540    > {
541        let span = if tracing::Span::current().is_disabled() {
542            info_span!(
543                target: "rig::completions",
544                "chat",
545                gen_ai.operation.name = "chat",
546                gen_ai.provider.name = "deepseek",
547                gen_ai.request.model = self.model,
548                gen_ai.system_instructions = tracing::field::Empty,
549                gen_ai.response.id = tracing::field::Empty,
550                gen_ai.response.model = tracing::field::Empty,
551                gen_ai.usage.output_tokens = tracing::field::Empty,
552                gen_ai.usage.input_tokens = tracing::field::Empty,
553            )
554        } else {
555            tracing::Span::current()
556        };
557
558        span.record("gen_ai.system_instructions", &completion_request.preamble);
559
560        let request =
561            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
562
563        if enabled!(Level::TRACE) {
564            tracing::trace!(target: "rig::completions",
565                "DeepSeek completion request: {}",
566                serde_json::to_string_pretty(&request)?
567            );
568        }
569
570        let body = serde_json::to_vec(&request)?;
571        let req = self
572            .client
573            .post("/chat/completions")?
574            .body(body)
575            .map_err(|e| CompletionError::HttpError(e.into()))?;
576
577        async move {
578            let response = self.client.send::<_, Bytes>(req).await?;
579            let status = response.status();
580            let response_body = response.into_body().into_future().await?.to_vec();
581
582            if status.is_success() {
583                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
584                    ApiResponse::Ok(response) => {
585                        let span = tracing::Span::current();
586                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
587                        span.record(
588                            "gen_ai.usage.output_tokens",
589                            response.usage.completion_tokens,
590                        );
591                        if enabled!(Level::TRACE) {
592                            tracing::trace!(target: "rig::completions",
593                                "DeepSeek completion response: {}",
594                                serde_json::to_string_pretty(&response)?
595                            );
596                        }
597                        response.try_into()
598                    }
599                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
600                }
601            } else {
602                Err(CompletionError::ProviderError(
603                    String::from_utf8_lossy(&response_body).to_string(),
604                ))
605            }
606        }
607        .instrument(span)
608        .await
609    }
610
611    async fn stream(
612        &self,
613        completion_request: CompletionRequest,
614    ) -> Result<
615        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
616        CompletionError,
617    > {
618        let preamble = completion_request.preamble.clone();
619        let mut request =
620            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
621
622        let params = json_utils::merge(
623            request.additional_params.unwrap_or(serde_json::json!({})),
624            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
625        );
626
627        request.additional_params = Some(params);
628
629        if enabled!(Level::TRACE) {
630            tracing::trace!(target: "rig::completions",
631                "DeepSeek streaming completion request: {}",
632                serde_json::to_string_pretty(&request)?
633            );
634        }
635
636        let body = serde_json::to_vec(&request)?;
637
638        let req = self
639            .client
640            .post("/chat/completions")?
641            .body(body)
642            .map_err(|e| CompletionError::HttpError(e.into()))?;
643
644        let span = if tracing::Span::current().is_disabled() {
645            info_span!(
646                target: "rig::completions",
647                "chat_streaming",
648                gen_ai.operation.name = "chat_streaming",
649                gen_ai.provider.name = "deepseek",
650                gen_ai.request.model = self.model,
651                gen_ai.system_instructions = preamble,
652                gen_ai.response.id = tracing::field::Empty,
653                gen_ai.response.model = tracing::field::Empty,
654                gen_ai.usage.output_tokens = tracing::field::Empty,
655                gen_ai.usage.input_tokens = tracing::field::Empty,
656            )
657        } else {
658            tracing::Span::current()
659        };
660
661        tracing::Instrument::instrument(
662            send_compatible_streaming_request(self.client.clone(), req),
663            span,
664        )
665        .await
666    }
667}
668
669#[derive(Deserialize, Debug)]
670pub struct StreamingDelta {
671    #[serde(default)]
672    content: Option<String>,
673    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
674    tool_calls: Vec<StreamingToolCall>,
675    reasoning_content: Option<String>,
676}
677
678#[derive(Deserialize, Debug)]
679struct StreamingChoice {
680    delta: StreamingDelta,
681}
682
683#[derive(Deserialize, Debug)]
684struct StreamingCompletionChunk {
685    choices: Vec<StreamingChoice>,
686    usage: Option<Usage>,
687}
688
689#[derive(Clone, Deserialize, Serialize, Debug)]
690pub struct StreamingCompletionResponse {
691    pub usage: Usage,
692}
693
694impl GetTokenUsage for StreamingCompletionResponse {
695    fn token_usage(&self) -> Option<crate::completion::Usage> {
696        let mut usage = crate::completion::Usage::new();
697        usage.input_tokens = self.usage.prompt_tokens as u64;
698        usage.output_tokens = self.usage.completion_tokens as u64;
699        usage.total_tokens = self.usage.total_tokens as u64;
700
701        Some(usage)
702    }
703}
704
705pub async fn send_compatible_streaming_request<T>(
706    http_client: T,
707    req: Request<Vec<u8>>,
708) -> Result<
709    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
710    CompletionError,
711>
712where
713    T: HttpClientExt + Clone + 'static,
714{
715    let mut event_source = GenericEventSource::new(http_client, req);
716
717    let stream = stream! {
718        let mut final_usage = Usage::new();
719        let mut text_response = String::new();
720        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
721
722        while let Some(event_result) = event_source.next().await {
723            match event_result {
724                Ok(Event::Open) => {
725                    tracing::trace!("SSE connection opened");
726                    continue;
727                }
728                Ok(Event::Message(message)) => {
729                    if message.data.trim().is_empty() || message.data == "[DONE]" {
730                        continue;
731                    }
732
733                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
734                    let Ok(data) = parsed else {
735                        let err = parsed.unwrap_err();
736                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
737                        continue;
738                    };
739
740                    if let Some(choice) = data.choices.first() {
741                        let delta = &choice.delta;
742
743                        if !delta.tool_calls.is_empty() {
744                            for tool_call in &delta.tool_calls {
745                                let function = &tool_call.function;
746
747                                // Start of tool call
748                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
749                                    && empty_or_none(&function.arguments)
750                                {
751                                    let id = tool_call.id.clone().unwrap_or_default();
752                                    let name = function.name.clone().unwrap();
753                                    calls.insert(tool_call.index, (id, name, String::new()));
754                                }
755                                // Continuation of tool call
756                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
757                                    && let Some(arguments) = &function.arguments
758                                    && !arguments.is_empty()
759                                {
760                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
761                                        let combined = format!("{}{}", existing_args, arguments);
762                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
763                                    } else {
764                                        tracing::debug!("Partial tool call received but tool call was never started.");
765                                    }
766                                }
767                                // Complete tool call
768                                else {
769                                    let id = tool_call.id.clone().unwrap_or_default();
770                                    let name = function.name.clone().unwrap_or_default();
771                                    let arguments_str = function.arguments.clone().unwrap_or_default();
772
773                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
774                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
775                                        continue;
776                                    };
777
778                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
779                                        crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
780                                    ));
781                                }
782                            }
783                        }
784
785                        // DeepSeek-specific reasoning stream
786                        if let Some(content) = &delta.reasoning_content {
787                            yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
788                                id: None,
789                                reasoning: content.to_string()
790                            });
791                        }
792
793                        if let Some(content) = &delta.content {
794                            text_response += content;
795                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
796                        }
797                    }
798
799                    if let Some(usage) = data.usage {
800                        final_usage = usage.clone();
801                    }
802                }
803                Err(crate::http_client::Error::StreamEnded) => {
804                    break;
805                }
806                Err(err) => {
807                    tracing::error!(?err, "SSE error");
808                    yield Err(CompletionError::ResponseError(err.to_string()));
809                    break;
810                }
811            }
812        }
813
814        event_source.close();
815
816        let mut tool_calls = Vec::new();
817        // Flush accumulated tool calls
818        for (index, (id, name, arguments)) in calls {
819            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
820                continue;
821            };
822
823            tool_calls.push(ToolCall {
824                id: id.clone(),
825                index,
826                r#type: ToolType::Function,
827                function: Function {
828                    name: name.clone(),
829                    arguments: arguments_json.clone()
830                }
831            });
832            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
833                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
834            ));
835        }
836
837        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
838            StreamingCompletionResponse { usage: final_usage.clone() }
839        ));
840    };
841
842    Ok(crate::streaming::StreamingCompletionResponse::stream(
843        Box::pin(stream),
844    ))
845}
846
// ================================================================
// DeepSeek Completion API
// ================================================================

/// General-purpose DeepSeek chat model.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// DeepSeek reasoning model; its responses can carry `reasoning_content`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
852
// Deserialization tests for the DeepSeek wire types (`Choice`, `CompletionResponse`).
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `data` into a [`CompletionResponse`], panicking with the JSON path of
    /// the offending field when deserialization fails.
    fn parse_completion_response(data: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(data);
        serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    /// Returns the text content of an assistant message; any other variant is a test failure.
    fn assistant_text(message: &Message) -> &str {
        let Message::Assistant { content, .. } = message else {
            panic!("Expected assistant message");
        };
        content
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let parsed: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(parsed.len(), 1);

        // Length was checked above, so indexing the single element cannot panic.
        let only = &parsed[0];
        assert_eq!(assistant_text(&only.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_completion_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_completion_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(
            assistant_text(&first.message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        // The stringified `arguments` in the payload is expected to come back as a JSON object.
        let subtract_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::json!({ "x": 2, "y": 5 }),
            },
        };

        let expected_choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: String::new(),
                name: None,
                tool_calls: vec![subtract_call],
                reasoning_content: None,
            },
        };

        assert_eq!(choice, expected_choice);
    }
}