Skip to main content

rig/providers/
deepseek.rs

//! DeepSeek API client and Rig integration
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::deepseek;
//!
//! let client = deepseek::Client::new("DEEPSEEK_API_KEY")
//!     .expect("failed to build DeepSeek client");
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```
11
12use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
37// ================================================================
38// Main DeepSeek Client
39// ================================================================
40const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
42#[derive(Debug, Default, Clone, Copy)]
43pub struct DeepSeekExt;
44#[derive(Debug, Default, Clone, Copy)]
45pub struct DeepSeekExtBuilder;
46
47type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50    type Builder = DeepSeekExtBuilder;
51
52    const VERIFY_PATH: &'static str = "/user/balance";
53
54    fn build<H>(
55        _: &crate::client::ClientBuilder<
56            Self::Builder,
57            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58            H,
59        >,
60    ) -> http_client::Result<Self> {
61        Ok(Self)
62    }
63}
64
65impl<H> Capabilities<H> for DeepSeekExt {
66    type Completion = Capable<CompletionModel<H>>;
67    type Embeddings = Nothing;
68    type Transcription = Nothing;
69    #[cfg(feature = "image")]
70    type ImageGeneration = Nothing;
71    #[cfg(feature = "audio")]
72    type AudioGeneration = Nothing;
73}
74
75impl DebugExt for DeepSeekExt {}
76
77impl ProviderBuilder for DeepSeekExtBuilder {
78    type Output = DeepSeekExt;
79    type ApiKey = DeepSeekApiKey;
80
81    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
82}
83
84pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
85pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88    type Input = DeepSeekApiKey;
89
90    // If you prefer the environment variable approach:
91    fn from_env() -> Self {
92        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93        let mut client_builder = Self::builder();
94        client_builder.headers_mut().insert(
95            http::header::CONTENT_TYPE,
96            http::HeaderValue::from_static("application/json"),
97        );
98        let client_builder = client_builder.api_key(&api_key);
99        client_builder.build().unwrap()
100    }
101
102    fn from_val(input: Self::Input) -> Self {
103        Self::new(input).unwrap()
104    }
105}
106
107#[derive(Debug, Deserialize)]
108struct ApiErrorResponse {
109    message: String,
110}
111
112#[derive(Debug, Deserialize)]
113#[serde(untagged)]
114enum ApiResponse<T> {
115    Ok(T),
116    Err(ApiErrorResponse),
117}
118
119impl From<ApiErrorResponse> for CompletionError {
120    fn from(err: ApiErrorResponse) -> Self {
121        CompletionError::ProviderError(err.message)
122    }
123}
124
125/// The response shape from the DeepSeek API
126#[derive(Clone, Debug, Serialize, Deserialize)]
127pub struct CompletionResponse {
128    // We'll match the JSON:
129    pub choices: Vec<Choice>,
130    pub usage: Usage,
131    // you may want other fields
132}
133
134#[derive(Clone, Debug, Serialize, Deserialize, Default)]
135pub struct Usage {
136    pub completion_tokens: u32,
137    pub prompt_tokens: u32,
138    pub prompt_cache_hit_tokens: u32,
139    pub prompt_cache_miss_tokens: u32,
140    pub total_tokens: u32,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub completion_tokens_details: Option<CompletionTokensDetails>,
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub prompt_tokens_details: Option<PromptTokensDetails>,
145}
146
147impl Usage {
148    fn new() -> Self {
149        Self {
150            completion_tokens: 0,
151            prompt_tokens: 0,
152            prompt_cache_hit_tokens: 0,
153            prompt_cache_miss_tokens: 0,
154            total_tokens: 0,
155            completion_tokens_details: None,
156            prompt_tokens_details: None,
157        }
158    }
159}
160
161#[derive(Clone, Debug, Serialize, Deserialize, Default)]
162pub struct CompletionTokensDetails {
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub reasoning_tokens: Option<u32>,
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize, Default)]
168pub struct PromptTokensDetails {
169    #[serde(skip_serializing_if = "Option::is_none")]
170    pub cached_tokens: Option<u32>,
171}
172
173#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
174pub struct Choice {
175    pub index: usize,
176    pub message: Message,
177    pub logprobs: Option<serde_json::Value>,
178    pub finish_reason: String,
179}
180
181#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
182#[serde(tag = "role", rename_all = "lowercase")]
183pub enum Message {
184    System {
185        content: String,
186        #[serde(skip_serializing_if = "Option::is_none")]
187        name: Option<String>,
188    },
189    User {
190        content: String,
191        #[serde(skip_serializing_if = "Option::is_none")]
192        name: Option<String>,
193    },
194    Assistant {
195        content: String,
196        #[serde(skip_serializing_if = "Option::is_none")]
197        name: Option<String>,
198        #[serde(
199            default,
200            deserialize_with = "json_utils::null_or_vec",
201            skip_serializing_if = "Vec::is_empty"
202        )]
203        tool_calls: Vec<ToolCall>,
204        /// only exists on `deepseek-reasoner` model at time of addition
205        #[serde(skip_serializing_if = "Option::is_none")]
206        reasoning_content: Option<String>,
207    },
208    #[serde(rename = "tool")]
209    ToolResult {
210        tool_call_id: String,
211        content: String,
212    },
213}
214
215impl Message {
216    pub fn system(content: &str) -> Self {
217        Message::System {
218            content: content.to_owned(),
219            name: None,
220        }
221    }
222}
223
224impl From<message::ToolResult> for Message {
225    fn from(tool_result: message::ToolResult) -> Self {
226        let content = match tool_result.content.first() {
227            message::ToolResultContent::Text(text) => text.text,
228            message::ToolResultContent::Image(_) => String::from("[Image]"),
229        };
230
231        Message::ToolResult {
232            tool_call_id: tool_result.id,
233            content,
234        }
235    }
236}
237
238impl From<message::ToolCall> for ToolCall {
239    fn from(tool_call: message::ToolCall) -> Self {
240        Self {
241            id: tool_call.id,
242            // TODO: update index when we have it
243            index: 0,
244            r#type: ToolType::Function,
245            function: Function {
246                name: tool_call.function.name,
247                arguments: tool_call.function.arguments,
248            },
249        }
250    }
251}
252
253impl TryFrom<message::Message> for Vec<Message> {
254    type Error = message::MessageError;
255
256    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
257        match message {
258            message::Message::User { content } => {
259                // extract tool results
260                let mut messages = vec![];
261
262                let tool_results = content
263                    .clone()
264                    .into_iter()
265                    .filter_map(|content| match content {
266                        message::UserContent::ToolResult(tool_result) => {
267                            Some(Message::from(tool_result))
268                        }
269                        _ => None,
270                    })
271                    .collect::<Vec<_>>();
272
273                messages.extend(tool_results);
274
275                // extract text results
276                let text_messages = content
277                    .into_iter()
278                    .filter_map(|content| match content {
279                        message::UserContent::Text(text) => Some(Message::User {
280                            content: text.text,
281                            name: None,
282                        }),
283                        message::UserContent::Document(Document {
284                            data:
285                                DocumentSourceKind::Base64(content)
286                                | DocumentSourceKind::String(content),
287                            ..
288                        }) => Some(Message::User {
289                            content,
290                            name: None,
291                        }),
292                        _ => None,
293                    })
294                    .collect::<Vec<_>>();
295                messages.extend(text_messages);
296
297                Ok(messages)
298            }
299            message::Message::Assistant { content, .. } => {
300                let mut messages: Vec<Message> = vec![];
301                let mut text_content = String::new();
302                let mut reasoning_content = String::new();
303
304                content.iter().for_each(|content| match content {
305                    message::AssistantContent::Text(text) => {
306                        text_content.push_str(text.text());
307                    }
308                    message::AssistantContent::Reasoning(reasoning) => {
309                        reasoning_content.push_str(&reasoning.reasoning.join("\n"));
310                    }
311                    _ => {}
312                });
313
314                messages.push(Message::Assistant {
315                    content: text_content,
316                    name: None,
317                    tool_calls: vec![],
318                    reasoning_content: if reasoning_content.is_empty() {
319                        None
320                    } else {
321                        Some(reasoning_content)
322                    },
323                });
324
325                // extract tool calls
326                let tool_calls = content
327                    .clone()
328                    .into_iter()
329                    .filter_map(|content| match content {
330                        message::AssistantContent::ToolCall(tool_call) => {
331                            Some(ToolCall::from(tool_call))
332                        }
333                        _ => None,
334                    })
335                    .collect::<Vec<_>>();
336
337                // if we have tool calls, we add a new Assistant message with them
338                if !tool_calls.is_empty() {
339                    messages.push(Message::Assistant {
340                        content: "".to_string(),
341                        name: None,
342                        tool_calls,
343                        reasoning_content: None,
344                    });
345                }
346
347                Ok(messages)
348            }
349        }
350    }
351}
352
353#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
354pub struct ToolCall {
355    pub id: String,
356    pub index: usize,
357    #[serde(default)]
358    pub r#type: ToolType,
359    pub function: Function,
360}
361
362#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
363pub struct Function {
364    pub name: String,
365    #[serde(with = "json_utils::stringified_json")]
366    pub arguments: serde_json::Value,
367}
368
369#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
370#[serde(rename_all = "lowercase")]
371pub enum ToolType {
372    #[default]
373    Function,
374}
375
376#[derive(Clone, Debug, Deserialize, Serialize)]
377pub struct ToolDefinition {
378    pub r#type: String,
379    pub function: completion::ToolDefinition,
380}
381
382impl From<crate::completion::ToolDefinition> for ToolDefinition {
383    fn from(tool: crate::completion::ToolDefinition) -> Self {
384        Self {
385            r#type: "function".into(),
386            function: tool,
387        }
388    }
389}
390
391impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
392    type Error = CompletionError;
393
394    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
395        let choice = response.choices.first().ok_or_else(|| {
396            CompletionError::ResponseError("Response contained no choices".to_owned())
397        })?;
398        let content = match &choice.message {
399            Message::Assistant {
400                content,
401                tool_calls,
402                reasoning_content,
403                ..
404            } => {
405                let mut content = if content.trim().is_empty() {
406                    vec![]
407                } else {
408                    vec![completion::AssistantContent::text(content)]
409                };
410
411                content.extend(
412                    tool_calls
413                        .iter()
414                        .map(|call| {
415                            completion::AssistantContent::tool_call(
416                                &call.id,
417                                &call.function.name,
418                                call.function.arguments.clone(),
419                            )
420                        })
421                        .collect::<Vec<_>>(),
422                );
423
424                if let Some(reasoning_content) = reasoning_content {
425                    content.push(completion::AssistantContent::reasoning(reasoning_content));
426                }
427
428                Ok(content)
429            }
430            _ => Err(CompletionError::ResponseError(
431                "Response did not contain a valid message or tool call".into(),
432            )),
433        }?;
434
435        let choice = OneOrMany::many(content).map_err(|_| {
436            CompletionError::ResponseError(
437                "Response contained no message or tool call (empty)".to_owned(),
438            )
439        })?;
440
441        let usage = completion::Usage {
442            input_tokens: response.usage.prompt_tokens as u64,
443            output_tokens: response.usage.completion_tokens as u64,
444            total_tokens: response.usage.total_tokens as u64,
445            cached_input_tokens: response
446                .usage
447                .prompt_tokens_details
448                .as_ref()
449                .and_then(|d| d.cached_tokens)
450                .map(|c| c as u64)
451                .unwrap_or(0),
452        };
453
454        Ok(completion::CompletionResponse {
455            choice,
456            usage,
457            raw_response: response,
458        })
459    }
460}
461
462#[derive(Debug, Serialize, Deserialize)]
463pub(super) struct DeepseekCompletionRequest {
464    model: String,
465    pub messages: Vec<Message>,
466    #[serde(skip_serializing_if = "Option::is_none")]
467    temperature: Option<f64>,
468    #[serde(skip_serializing_if = "Vec::is_empty")]
469    tools: Vec<ToolDefinition>,
470    #[serde(skip_serializing_if = "Option::is_none")]
471    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
472    #[serde(flatten, skip_serializing_if = "Option::is_none")]
473    pub additional_params: Option<serde_json::Value>,
474}
475
476impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
477    type Error = CompletionError;
478
479    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
480        let mut full_history: Vec<Message> = match &req.preamble {
481            Some(preamble) => vec![Message::system(preamble)],
482            None => vec![],
483        };
484
485        if let Some(docs) = req.normalized_documents() {
486            let docs: Vec<Message> = docs.try_into()?;
487            full_history.extend(docs);
488        }
489
490        let chat_history: Vec<Message> = req
491            .chat_history
492            .clone()
493            .into_iter()
494            .map(|message| message.try_into())
495            .collect::<Result<Vec<Vec<Message>>, _>>()?
496            .into_iter()
497            .flatten()
498            .collect();
499
500        let mut last_reasoning_content = None;
501        let mut last_assistant_idx = None;
502
503        for message in chat_history {
504            if let Message::Assistant {
505                reasoning_content, ..
506            } = &message
507            {
508                if let Some(content) = reasoning_content {
509                    last_reasoning_content = Some(content.clone());
510                } else {
511                    last_assistant_idx = Some(full_history.len());
512                    full_history.push(message);
513                }
514            } else {
515                full_history.push(message);
516            }
517        }
518
519        // Merge last reasoning content into the last assistant message.
520        // Note that we only need to preserve the last reasoning content.
521        if let (Some(idx), Some(reasoning)) = (last_assistant_idx, last_reasoning_content)
522            && let Message::Assistant {
523                ref mut reasoning_content,
524                ..
525            } = full_history[idx]
526        {
527            *reasoning_content = Some(reasoning);
528        }
529
530        let tool_choice = req
531            .tool_choice
532            .clone()
533            .map(crate::providers::openrouter::ToolChoice::try_from)
534            .transpose()?;
535
536        Ok(Self {
537            model: model.to_string(),
538            messages: full_history,
539            temperature: req.temperature,
540            tools: req
541                .tools
542                .clone()
543                .into_iter()
544                .map(ToolDefinition::from)
545                .collect::<Vec<_>>(),
546            tool_choice,
547            additional_params: req.additional_params,
548        })
549    }
550}
551
552/// The struct implementing the `CompletionModel` trait
553#[derive(Clone)]
554pub struct CompletionModel<T = reqwest::Client> {
555    pub client: Client<T>,
556    pub model: String,
557}
558
559impl<T> completion::CompletionModel for CompletionModel<T>
560where
561    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
562{
563    type Response = CompletionResponse;
564    type StreamingResponse = StreamingCompletionResponse;
565
566    type Client = Client<T>;
567
568    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
569        Self {
570            client: client.clone(),
571            model: model.into().to_string(),
572        }
573    }
574
575    async fn completion(
576        &self,
577        completion_request: CompletionRequest,
578    ) -> Result<
579        completion::CompletionResponse<CompletionResponse>,
580        crate::completion::CompletionError,
581    > {
582        let span = if tracing::Span::current().is_disabled() {
583            info_span!(
584                target: "rig::completions",
585                "chat",
586                gen_ai.operation.name = "chat",
587                gen_ai.provider.name = "deepseek",
588                gen_ai.request.model = self.model,
589                gen_ai.system_instructions = tracing::field::Empty,
590                gen_ai.response.id = tracing::field::Empty,
591                gen_ai.response.model = tracing::field::Empty,
592                gen_ai.usage.output_tokens = tracing::field::Empty,
593                gen_ai.usage.input_tokens = tracing::field::Empty,
594            )
595        } else {
596            tracing::Span::current()
597        };
598
599        span.record("gen_ai.system_instructions", &completion_request.preamble);
600
601        let request =
602            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
603
604        if enabled!(Level::TRACE) {
605            tracing::trace!(target: "rig::completions",
606                "DeepSeek completion request: {}",
607                serde_json::to_string_pretty(&request)?
608            );
609        }
610
611        let body = serde_json::to_vec(&request)?;
612        let req = self
613            .client
614            .post("/chat/completions")?
615            .body(body)
616            .map_err(|e| CompletionError::HttpError(e.into()))?;
617
618        async move {
619            let response = self.client.send::<_, Bytes>(req).await?;
620            let status = response.status();
621            let response_body = response.into_body().into_future().await?.to_vec();
622
623            if status.is_success() {
624                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
625                    ApiResponse::Ok(response) => {
626                        let span = tracing::Span::current();
627                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
628                        span.record(
629                            "gen_ai.usage.output_tokens",
630                            response.usage.completion_tokens,
631                        );
632                        if enabled!(Level::TRACE) {
633                            tracing::trace!(target: "rig::completions",
634                                "DeepSeek completion response: {}",
635                                serde_json::to_string_pretty(&response)?
636                            );
637                        }
638                        response.try_into()
639                    }
640                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
641                }
642            } else {
643                Err(CompletionError::ProviderError(
644                    String::from_utf8_lossy(&response_body).to_string(),
645                ))
646            }
647        }
648        .instrument(span)
649        .await
650    }
651
652    async fn stream(
653        &self,
654        completion_request: CompletionRequest,
655    ) -> Result<
656        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
657        CompletionError,
658    > {
659        let preamble = completion_request.preamble.clone();
660        let mut request =
661            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
662
663        let params = json_utils::merge(
664            request.additional_params.unwrap_or(serde_json::json!({})),
665            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
666        );
667
668        request.additional_params = Some(params);
669
670        if enabled!(Level::TRACE) {
671            tracing::trace!(target: "rig::completions",
672                "DeepSeek streaming completion request: {}",
673                serde_json::to_string_pretty(&request)?
674            );
675        }
676
677        let body = serde_json::to_vec(&request)?;
678
679        let req = self
680            .client
681            .post("/chat/completions")?
682            .body(body)
683            .map_err(|e| CompletionError::HttpError(e.into()))?;
684
685        let span = if tracing::Span::current().is_disabled() {
686            info_span!(
687                target: "rig::completions",
688                "chat_streaming",
689                gen_ai.operation.name = "chat_streaming",
690                gen_ai.provider.name = "deepseek",
691                gen_ai.request.model = self.model,
692                gen_ai.system_instructions = preamble,
693                gen_ai.response.id = tracing::field::Empty,
694                gen_ai.response.model = tracing::field::Empty,
695                gen_ai.usage.output_tokens = tracing::field::Empty,
696                gen_ai.usage.input_tokens = tracing::field::Empty,
697            )
698        } else {
699            tracing::Span::current()
700        };
701
702        tracing::Instrument::instrument(
703            send_compatible_streaming_request(self.client.clone(), req),
704            span,
705        )
706        .await
707    }
708}
709
710#[derive(Deserialize, Debug)]
711pub struct StreamingDelta {
712    #[serde(default)]
713    content: Option<String>,
714    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
715    tool_calls: Vec<StreamingToolCall>,
716    reasoning_content: Option<String>,
717}
718
719#[derive(Deserialize, Debug)]
720struct StreamingChoice {
721    delta: StreamingDelta,
722}
723
724#[derive(Deserialize, Debug)]
725struct StreamingCompletionChunk {
726    choices: Vec<StreamingChoice>,
727    usage: Option<Usage>,
728}
729
730#[derive(Clone, Deserialize, Serialize, Debug)]
731pub struct StreamingCompletionResponse {
732    pub usage: Usage,
733}
734
735impl GetTokenUsage for StreamingCompletionResponse {
736    fn token_usage(&self) -> Option<crate::completion::Usage> {
737        let mut usage = crate::completion::Usage::new();
738        usage.input_tokens = self.usage.prompt_tokens as u64;
739        usage.output_tokens = self.usage.completion_tokens as u64;
740        usage.total_tokens = self.usage.total_tokens as u64;
741        usage.cached_input_tokens = self
742            .usage
743            .prompt_tokens_details
744            .as_ref()
745            .and_then(|d| d.cached_tokens)
746            .map(|c| c as u64)
747            .unwrap_or(0);
748
749        Some(usage)
750    }
751}
752
753pub async fn send_compatible_streaming_request<T>(
754    http_client: T,
755    req: Request<Vec<u8>>,
756) -> Result<
757    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
758    CompletionError,
759>
760where
761    T: HttpClientExt + Clone + 'static,
762{
763    let mut event_source = GenericEventSource::new(http_client, req);
764
765    let stream = stream! {
766        let mut final_usage = Usage::new();
767        let mut text_response = String::new();
768        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
769
770        while let Some(event_result) = event_source.next().await {
771            match event_result {
772                Ok(Event::Open) => {
773                    tracing::trace!("SSE connection opened");
774                    continue;
775                }
776                Ok(Event::Message(message)) => {
777                    if message.data.trim().is_empty() || message.data == "[DONE]" {
778                        continue;
779                    }
780
781                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
782                    let Ok(data) = parsed else {
783                        let err = parsed.unwrap_err();
784                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
785                        continue;
786                    };
787
788                    if let Some(choice) = data.choices.first() {
789                        let delta = &choice.delta;
790
791                        if !delta.tool_calls.is_empty() {
792                            for tool_call in &delta.tool_calls {
793                                let function = &tool_call.function;
794
795                                // Start of tool call
796                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
797                                    && empty_or_none(&function.arguments)
798                                {
799                                    let id = tool_call.id.clone().unwrap_or_default();
800                                    let name = function.name.clone().unwrap();
801                                    calls.insert(tool_call.index, (id, name, String::new()));
802                                }
803                                // Continuation of tool call
804                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
805                                    && let Some(arguments) = &function.arguments
806                                    && !arguments.is_empty()
807                                {
808                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
809                                        let combined = format!("{}{}", existing_args, arguments);
810                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
811                                    } else {
812                                        tracing::debug!("Partial tool call received but tool call was never started.");
813                                    }
814                                }
815                                // Complete tool call
816                                else {
817                                    let id = tool_call.id.clone().unwrap_or_default();
818                                    let name = function.name.clone().unwrap_or_default();
819                                    let arguments_str = function.arguments.clone().unwrap_or_default();
820
821                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
822                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
823                                        continue;
824                                    };
825
826                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
827                                        crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
828                                    ));
829                                }
830                            }
831                        }
832
833                        // DeepSeek-specific reasoning stream
834                        if let Some(content) = &delta.reasoning_content {
835                            yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
836                                id: None,
837                                reasoning: content.to_string()
838                            });
839                        }
840
841                        if let Some(content) = &delta.content {
842                            text_response += content;
843                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
844                        }
845                    }
846
847                    if let Some(usage) = data.usage {
848                        final_usage = usage.clone();
849                    }
850                }
851                Err(crate::http_client::Error::StreamEnded) => {
852                    break;
853                }
854                Err(err) => {
855                    tracing::error!(?err, "SSE error");
856                    yield Err(CompletionError::ResponseError(err.to_string()));
857                    break;
858                }
859            }
860        }
861
862        event_source.close();
863
864        let mut tool_calls = Vec::new();
865        // Flush accumulated tool calls
866        for (index, (id, name, arguments)) in calls {
867            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
868                continue;
869            };
870
871            tool_calls.push(ToolCall {
872                id: id.clone(),
873                index,
874                r#type: ToolType::Function,
875                function: Function {
876                    name: name.clone(),
877                    arguments: arguments_json.clone()
878                }
879            });
880            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
881                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
882            ));
883        }
884
885        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
886            StreamingCompletionResponse { usage: final_usage.clone() }
887        ));
888    };
889
890    Ok(crate::streaming::StreamingCompletionResponse::stream(
891        Box::pin(stream),
892    ))
893}
894
895// ================================================================
896// DeepSeek Completion API
897// ================================================================
/// Model identifier for DeepSeek's chat model (`deepseek-chat`).
///
/// Pass it to the client's `completion_model` method, e.g.
/// `client.completion_model(deepseek::DEEPSEEK_CHAT)`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";

/// Model identifier for DeepSeek's reasoner model (`deepseek-reasoner`).
///
/// NOTE(review): this provider surfaces `reasoning_content` from streamed deltas as
/// reasoning output; presumably that field is produced by this model — confirm
/// against DeepSeek's API documentation.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
900
// Unit tests for deserializing DeepSeek chat-completion payloads.
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of an assistant message; any other role fails the test.
    fn assistant_text(message: &Message) -> &str {
        if let Message::Assistant { content, .. } = message {
            content.as_str()
        } else {
            panic!("Expected assistant message")
        }
    }

    /// Deserializes a full completion response, reporting the JSON path on failure.
    fn parse_completion_response(data: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(data);
        serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);

        let first = choices.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_completion_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_completion_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(
            assistant_text(&first.message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        // The arguments arrive as a JSON-encoded string but are stored as a parsed value.
        let expected_arguments: serde_json::Value =
            serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap();

        let expected_tool_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: expected_arguments,
            },
        };

        let expected_choice = Choice {
            index: 0,
            finish_reason: "tool_calls".to_string(),
            logprobs: None,
            message: Message::Assistant {
                content: String::new(),
                name: None,
                reasoning_content: None,
                tool_calls: vec![expected_tool_call],
            },
        };

        let parsed: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
        assert_eq!(parsed, expected_choice);
    }
}