rig/providers/
deepseek.rs

//! DeepSeek API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::deepseek;
//!
//! let client = deepseek::Client::new("DEEPSEEK_API_KEY").expect("failed to build DeepSeek client");
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```
11
12use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
// ================================================================
// Main DeepSeek Client
// ================================================================

/// Default root URL of the DeepSeek HTTP API. It is used as the provider's
/// `BASE_URL`; request paths such as `/chat/completions` are relative to it.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
41
/// Zero-sized marker that plugs DeepSeek into Rig's generic `client::Client`.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExt;

/// Zero-sized marker selecting DeepSeek in Rig's generic `client::ClientBuilder`.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExtBuilder;
46
/// API-key representation for DeepSeek: the key is wrapped in `BearerAuth`
/// (presumably sent as an `Authorization: Bearer …` header — confirm in `client`).
type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50    type Builder = DeepSeekExtBuilder;
51
52    const VERIFY_PATH: &'static str = "/user/balance";
53
54    fn build<H>(
55        _: &crate::client::ClientBuilder<
56            Self::Builder,
57            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58            H,
59        >,
60    ) -> http_client::Result<Self> {
61        Ok(Self)
62    }
63}
64
65impl<H> Capabilities<H> for DeepSeekExt {
66    type Completion = Capable<CompletionModel<H>>;
67    type Embeddings = Nothing;
68    type Transcription = Nothing;
69    #[cfg(feature = "image")]
70    type ImageGeneration = Nothing;
71    #[cfg(feature = "audio")]
72    type AudioGeneration = Nothing;
73}
74
75impl DebugExt for DeepSeekExt {}
76
77impl ProviderBuilder for DeepSeekExtBuilder {
78    type Output = DeepSeekExt;
79    type ApiKey = DeepSeekApiKey;
80
81    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
82}
83
/// DeepSeek client, generic over the HTTP backend (`reqwest::Client` by default).
pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
/// Builder for [`Client`], generic over the HTTP backend (`reqwest::Client` by default).
// NOTE(review): the API-key type parameter here is `String`, while
// `DeepSeekExtBuilder::ApiKey` is `DeepSeekApiKey` — confirm this is intended.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88    type Input = DeepSeekApiKey;
89
90    // If you prefer the environment variable approach:
91    fn from_env() -> Self {
92        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93        Self::new(&api_key).unwrap()
94    }
95
96    fn from_val(input: Self::Input) -> Self {
97        Self::new(input).unwrap()
98    }
99}
100
101#[derive(Debug, Deserialize)]
102struct ApiErrorResponse {
103    message: String,
104}
105
106#[derive(Debug, Deserialize)]
107#[serde(untagged)]
108enum ApiResponse<T> {
109    Ok(T),
110    Err(ApiErrorResponse),
111}
112
113impl From<ApiErrorResponse> for CompletionError {
114    fn from(err: ApiErrorResponse) -> Self {
115        CompletionError::ProviderError(err.message)
116    }
117}
118
/// The response shape from the DeepSeek API
///
/// Only the fields Rig uses are modelled; other top-level JSON fields (such as
/// `id`, `model` or `system_fingerprint`) are ignored during deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Completion candidates; only the first one is converted into Rig's response.
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
127
128#[derive(Clone, Debug, Serialize, Deserialize, Default)]
129pub struct Usage {
130    pub completion_tokens: u32,
131    pub prompt_tokens: u32,
132    pub prompt_cache_hit_tokens: u32,
133    pub prompt_cache_miss_tokens: u32,
134    pub total_tokens: u32,
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub completion_tokens_details: Option<CompletionTokensDetails>,
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub prompt_tokens_details: Option<PromptTokensDetails>,
139}
140
141impl Usage {
142    fn new() -> Self {
143        Self {
144            completion_tokens: 0,
145            prompt_tokens: 0,
146            prompt_cache_hit_tokens: 0,
147            prompt_cache_miss_tokens: 0,
148            total_tokens: 0,
149            completion_tokens_details: None,
150            prompt_tokens_details: None,
151        }
152    }
153}
154
/// Optional breakdown of completion tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Tokens spent on reasoning, when the API reports them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}

/// Optional breakdown of prompt tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// Prompt tokens served from cache, when the API reports them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
166
/// One completion candidate in a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Position of this candidate in the `choices` array.
    pub index: usize,
    /// The generated message (an assistant message is expected).
    pub message: Message,
    /// Raw log-probability data, kept untyped; `null` when not requested.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped (e.g. `"stop"`).
    pub finish_reason: String,
}
174
/// A chat message in DeepSeek's wire format, discriminated by the `role` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `"role": "system"`: instructions such as the preamble.
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `"role": "user"`: user text, including inlined documents.
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `"role": "assistant"`: model output.
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // A missing field defaults to empty; `null_or_vec` presumably maps
        // `null` to an empty list. Omitted from JSON when empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// `"role": "tool"`: result of a tool call, linked back via `tool_call_id`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
205
206impl Message {
207    pub fn system(content: &str) -> Self {
208        Message::System {
209            content: content.to_owned(),
210            name: None,
211        }
212    }
213}
214
215impl From<message::ToolResult> for Message {
216    fn from(tool_result: message::ToolResult) -> Self {
217        let content = match tool_result.content.first() {
218            message::ToolResultContent::Text(text) => text.text,
219            message::ToolResultContent::Image(_) => String::from("[Image]"),
220        };
221
222        Message::ToolResult {
223            tool_call_id: tool_result.id,
224            content,
225        }
226    }
227}
228
229impl From<message::ToolCall> for ToolCall {
230    fn from(tool_call: message::ToolCall) -> Self {
231        Self {
232            id: tool_call.id,
233            // TODO: update index when we have it
234            index: 0,
235            r#type: ToolType::Function,
236            function: Function {
237                name: tool_call.function.name,
238                arguments: tool_call.function.arguments,
239            },
240        }
241    }
242}
243
244impl TryFrom<message::Message> for Vec<Message> {
245    type Error = message::MessageError;
246
247    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
248        match message {
249            message::Message::User { content } => {
250                // extract tool results
251                let mut messages = vec![];
252
253                let tool_results = content
254                    .clone()
255                    .into_iter()
256                    .filter_map(|content| match content {
257                        message::UserContent::ToolResult(tool_result) => {
258                            Some(Message::from(tool_result))
259                        }
260                        _ => None,
261                    })
262                    .collect::<Vec<_>>();
263
264                messages.extend(tool_results);
265
266                // extract text results
267                let text_messages = content
268                    .into_iter()
269                    .filter_map(|content| match content {
270                        message::UserContent::Text(text) => Some(Message::User {
271                            content: text.text,
272                            name: None,
273                        }),
274                        message::UserContent::Document(Document {
275                            data:
276                                DocumentSourceKind::Base64(content)
277                                | DocumentSourceKind::String(content),
278                            ..
279                        }) => Some(Message::User {
280                            content,
281                            name: None,
282                        }),
283                        _ => None,
284                    })
285                    .collect::<Vec<_>>();
286                messages.extend(text_messages);
287
288                Ok(messages)
289            }
290            message::Message::Assistant { content, .. } => {
291                let mut messages: Vec<Message> = vec![];
292
293                // extract text
294                let text_content = content
295                    .clone()
296                    .into_iter()
297                    .filter_map(|content| match content {
298                        message::AssistantContent::Text(text) => Some(Message::Assistant {
299                            content: text.text,
300                            name: None,
301                            tool_calls: vec![],
302                        }),
303                        _ => None,
304                    })
305                    .collect::<Vec<_>>();
306
307                messages.extend(text_content);
308
309                // extract tool calls
310                let tool_calls = content
311                    .clone()
312                    .into_iter()
313                    .filter_map(|content| match content {
314                        message::AssistantContent::ToolCall(tool_call) => {
315                            Some(ToolCall::from(tool_call))
316                        }
317                        _ => None,
318                    })
319                    .collect::<Vec<_>>();
320
321                // if we have tool calls, we add a new Assistant message with them
322                if !tool_calls.is_empty() {
323                    messages.push(Message::Assistant {
324                        content: "".to_string(),
325                        name: None,
326                        tool_calls,
327                    });
328                }
329
330                Ok(messages)
331            }
332        }
333    }
334}
335
336#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
337pub struct ToolCall {
338    pub id: String,
339    pub index: usize,
340    #[serde(default)]
341    pub r#type: ToolType,
342    pub function: Function,
343}
344
/// Name and arguments of a called function.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// Arguments as a JSON value; (de)serialized through
    /// `json_utils::stringified_json` (presumably a JSON-encoded string on the wire).
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
351
/// Kind of tool call. `"function"` is the only kind modelled, and the default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
358
359#[derive(Clone, Debug, Deserialize, Serialize)]
360pub struct ToolDefinition {
361    pub r#type: String,
362    pub function: completion::ToolDefinition,
363}
364
365impl From<crate::completion::ToolDefinition> for ToolDefinition {
366    fn from(tool: crate::completion::ToolDefinition) -> Self {
367        Self {
368            r#type: "function".into(),
369            function: tool,
370        }
371    }
372}
373
374impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
375    type Error = CompletionError;
376
377    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
378        let choice = response.choices.first().ok_or_else(|| {
379            CompletionError::ResponseError("Response contained no choices".to_owned())
380        })?;
381        let content = match &choice.message {
382            Message::Assistant {
383                content,
384                tool_calls,
385                ..
386            } => {
387                let mut content = if content.trim().is_empty() {
388                    vec![]
389                } else {
390                    vec![completion::AssistantContent::text(content)]
391                };
392
393                content.extend(
394                    tool_calls
395                        .iter()
396                        .map(|call| {
397                            completion::AssistantContent::tool_call(
398                                &call.id,
399                                &call.function.name,
400                                call.function.arguments.clone(),
401                            )
402                        })
403                        .collect::<Vec<_>>(),
404                );
405                Ok(content)
406            }
407            _ => Err(CompletionError::ResponseError(
408                "Response did not contain a valid message or tool call".into(),
409            )),
410        }?;
411
412        let choice = OneOrMany::many(content).map_err(|_| {
413            CompletionError::ResponseError(
414                "Response contained no message or tool call (empty)".to_owned(),
415            )
416        })?;
417
418        let usage = completion::Usage {
419            input_tokens: response.usage.prompt_tokens as u64,
420            output_tokens: response.usage.completion_tokens as u64,
421            total_tokens: response.usage.total_tokens as u64,
422        };
423
424        Ok(completion::CompletionResponse {
425            choice,
426            usage,
427            raw_response: response,
428        })
429    }
430}
431
432#[derive(Debug, Serialize, Deserialize)]
433pub(super) struct DeepseekCompletionRequest {
434    model: String,
435    pub messages: Vec<Message>,
436    #[serde(skip_serializing_if = "Option::is_none")]
437    temperature: Option<f64>,
438    #[serde(skip_serializing_if = "Vec::is_empty")]
439    tools: Vec<ToolDefinition>,
440    #[serde(skip_serializing_if = "Option::is_none")]
441    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
442    #[serde(flatten, skip_serializing_if = "Option::is_none")]
443    pub additional_params: Option<serde_json::Value>,
444}
445
446impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
447    type Error = CompletionError;
448
449    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
450        let mut full_history: Vec<Message> = match &req.preamble {
451            Some(preamble) => vec![Message::system(preamble)],
452            None => vec![],
453        };
454
455        if let Some(docs) = req.normalized_documents() {
456            let docs: Vec<Message> = docs.try_into()?;
457            full_history.extend(docs);
458        }
459
460        let chat_history: Vec<Message> = req
461            .chat_history
462            .clone()
463            .into_iter()
464            .map(|message| message.try_into())
465            .collect::<Result<Vec<Vec<Message>>, _>>()?
466            .into_iter()
467            .flatten()
468            .collect();
469
470        full_history.extend(chat_history);
471
472        let tool_choice = req
473            .tool_choice
474            .clone()
475            .map(crate::providers::openrouter::ToolChoice::try_from)
476            .transpose()?;
477
478        Ok(Self {
479            model: model.to_string(),
480            messages: full_history,
481            temperature: req.temperature,
482            tools: req
483                .tools
484                .clone()
485                .into_iter()
486                .map(ToolDefinition::from)
487                .collect::<Vec<_>>(),
488            tool_choice,
489            additional_params: req.additional_params,
490        })
491    }
492}
493
/// The struct implementing the `CompletionModel` trait
///
/// Pairs a DeepSeek [`Client`] with the model id (e.g. [`DEEPSEEK_CHAT`])
/// that is sent as `model` in every request.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub client: Client<T>,
    pub model: String,
}
500
501impl<T> completion::CompletionModel for CompletionModel<T>
502where
503    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
504{
505    type Response = CompletionResponse;
506    type StreamingResponse = StreamingCompletionResponse;
507
508    type Client = Client<T>;
509
510    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
511        Self {
512            client: client.clone(),
513            model: model.into().to_string(),
514        }
515    }
516
517    async fn completion(
518        &self,
519        completion_request: CompletionRequest,
520    ) -> Result<
521        completion::CompletionResponse<CompletionResponse>,
522        crate::completion::CompletionError,
523    > {
524        let span = if tracing::Span::current().is_disabled() {
525            info_span!(
526                target: "rig::completions",
527                "chat",
528                gen_ai.operation.name = "chat",
529                gen_ai.provider.name = "deepseek",
530                gen_ai.request.model = self.model,
531                gen_ai.system_instructions = tracing::field::Empty,
532                gen_ai.response.id = tracing::field::Empty,
533                gen_ai.response.model = tracing::field::Empty,
534                gen_ai.usage.output_tokens = tracing::field::Empty,
535                gen_ai.usage.input_tokens = tracing::field::Empty,
536            )
537        } else {
538            tracing::Span::current()
539        };
540
541        span.record("gen_ai.system_instructions", &completion_request.preamble);
542
543        let request =
544            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
545
546        if enabled!(Level::TRACE) {
547            tracing::trace!(target: "rig::completions",
548                "DeepSeek completion request: {}",
549                serde_json::to_string_pretty(&request)?
550            );
551        }
552
553        let body = serde_json::to_vec(&request)?;
554        let req = self
555            .client
556            .post("/chat/completions")?
557            .body(body)
558            .map_err(|e| CompletionError::HttpError(e.into()))?;
559
560        async move {
561            let response = self.client.send::<_, Bytes>(req).await?;
562            let status = response.status();
563            let response_body = response.into_body().into_future().await?.to_vec();
564
565            if status.is_success() {
566                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
567                    ApiResponse::Ok(response) => {
568                        let span = tracing::Span::current();
569                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
570                        span.record(
571                            "gen_ai.usage.output_tokens",
572                            response.usage.completion_tokens,
573                        );
574                        if enabled!(Level::TRACE) {
575                            tracing::trace!(target: "rig::completions",
576                                "DeepSeek completion response: {}",
577                                serde_json::to_string_pretty(&response)?
578                            );
579                        }
580                        response.try_into()
581                    }
582                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
583                }
584            } else {
585                Err(CompletionError::ProviderError(
586                    String::from_utf8_lossy(&response_body).to_string(),
587                ))
588            }
589        }
590        .instrument(span)
591        .await
592    }
593
594    async fn stream(
595        &self,
596        completion_request: CompletionRequest,
597    ) -> Result<
598        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
599        CompletionError,
600    > {
601        let preamble = completion_request.preamble.clone();
602        let mut request =
603            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
604
605        let params = json_utils::merge(
606            request.additional_params.unwrap_or(serde_json::json!({})),
607            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
608        );
609
610        request.additional_params = Some(params);
611
612        if enabled!(Level::TRACE) {
613            tracing::trace!(target: "rig::completions",
614                "DeepSeek streaming completion request: {}",
615                serde_json::to_string_pretty(&request)?
616            );
617        }
618
619        let body = serde_json::to_vec(&request)?;
620
621        let req = self
622            .client
623            .post("/chat/completions")?
624            .body(body)
625            .map_err(|e| CompletionError::HttpError(e.into()))?;
626
627        let span = if tracing::Span::current().is_disabled() {
628            info_span!(
629                target: "rig::completions",
630                "chat_streaming",
631                gen_ai.operation.name = "chat_streaming",
632                gen_ai.provider.name = "deepseek",
633                gen_ai.request.model = self.model,
634                gen_ai.system_instructions = preamble,
635                gen_ai.response.id = tracing::field::Empty,
636                gen_ai.response.model = tracing::field::Empty,
637                gen_ai.usage.output_tokens = tracing::field::Empty,
638                gen_ai.usage.input_tokens = tracing::field::Empty,
639            )
640        } else {
641            tracing::Span::current()
642        };
643
644        tracing::Instrument::instrument(
645            send_compatible_streaming_request(self.client.clone(), req),
646            span,
647        )
648        .await
649    }
650}
651
/// Incremental content carried by one streamed choice.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    /// Text fragment, if any.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments; a missing field yields an empty list.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    /// DeepSeek-specific reasoning fragment (presumably sent only by reasoning models).
    reasoning_content: Option<String>,
}

/// One entry of a chunk's `choices` array; only the delta is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One SSE `data:` payload. `usage` is present only on chunks that report it.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
671
/// Final payload of a DeepSeek stream: the token usage reported by the API
/// (all zeros if no chunk reported usage).
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
676
677impl GetTokenUsage for StreamingCompletionResponse {
678    fn token_usage(&self) -> Option<crate::completion::Usage> {
679        let mut usage = crate::completion::Usage::new();
680        usage.input_tokens = self.usage.prompt_tokens as u64;
681        usage.output_tokens = self.usage.completion_tokens as u64;
682        usage.total_tokens = self.usage.total_tokens as u64;
683
684        Some(usage)
685    }
686}
687
688pub async fn send_compatible_streaming_request<T>(
689    http_client: T,
690    req: Request<Vec<u8>>,
691) -> Result<
692    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
693    CompletionError,
694>
695where
696    T: HttpClientExt + Clone + 'static,
697{
698    let span = tracing::Span::current();
699    let mut event_source = GenericEventSource::new(http_client, req);
700
701    let stream = stream! {
702        let mut final_usage = Usage::new();
703        let mut text_response = String::new();
704        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
705
706        while let Some(event_result) = event_source.next().await {
707            match event_result {
708                Ok(Event::Open) => {
709                    tracing::trace!("SSE connection opened");
710                    continue;
711                }
712                Ok(Event::Message(message)) => {
713                    if message.data.trim().is_empty() || message.data == "[DONE]" {
714                        continue;
715                    }
716
717                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
718                    let Ok(data) = parsed else {
719                        let err = parsed.unwrap_err();
720                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
721                        continue;
722                    };
723
724                    if let Some(choice) = data.choices.first() {
725                        let delta = &choice.delta;
726
727                        if !delta.tool_calls.is_empty() {
728                            for tool_call in &delta.tool_calls {
729                                let function = &tool_call.function;
730
731                                // Start of tool call
732                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
733                                    && empty_or_none(&function.arguments)
734                                {
735                                    let id = tool_call.id.clone().unwrap_or_default();
736                                    let name = function.name.clone().unwrap();
737                                    calls.insert(tool_call.index, (id, name, String::new()));
738                                }
739                                // Continuation of tool call
740                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
741                                    && let Some(arguments) = &function.arguments
742                                    && !arguments.is_empty()
743                                {
744                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
745                                        let combined = format!("{}{}", existing_args, arguments);
746                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
747                                    } else {
748                                        tracing::debug!("Partial tool call received but tool call was never started.");
749                                    }
750                                }
751                                // Complete tool call
752                                else {
753                                    let id = tool_call.id.clone().unwrap_or_default();
754                                    let name = function.name.clone().unwrap_or_default();
755                                    let arguments_str = function.arguments.clone().unwrap_or_default();
756
757                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
758                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
759                                        continue;
760                                    };
761
762                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
763                                        crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
764                                    ));
765                                }
766                            }
767                        }
768
769                        // DeepSeek-specific reasoning stream
770                        if let Some(content) = &delta.reasoning_content {
771                            yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
772                                id: None,
773                                reasoning: content.to_string()
774                            });
775                        }
776
777                        if let Some(content) = &delta.content {
778                            text_response += content;
779                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
780                        }
781                    }
782
783                    if let Some(usage) = data.usage {
784                        final_usage = usage.clone();
785                    }
786                }
787                Err(crate::http_client::Error::StreamEnded) => {
788                    break;
789                }
790                Err(err) => {
791                    tracing::error!(?err, "SSE error");
792                    yield Err(CompletionError::ResponseError(err.to_string()));
793                    break;
794                }
795            }
796        }
797
798        event_source.close();
799
800        let mut tool_calls = Vec::new();
801        // Flush accumulated tool calls
802        for (index, (id, name, arguments)) in calls {
803            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
804                continue;
805            };
806
807            tool_calls.push(ToolCall {
808                id: id.clone(),
809                index,
810                r#type: ToolType::Function,
811                function: Function {
812                    name: name.clone(),
813                    arguments: arguments_json.clone()
814                }
815            });
816            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
817                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
818            ));
819        }
820
821        let message = Message::Assistant {
822            content: text_response,
823            name: None,
824            tool_calls
825        };
826
827        span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
828
829        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
830            StreamingCompletionResponse { usage: final_usage.clone() }
831        ));
832    };
833
834    Ok(crate::streaming::StreamingCompletionResponse::stream(
835        Box::pin(stream),
836    ))
837}
838
// ================================================================
// DeepSeek Completion API
// ================================================================

/// Model identifier `deepseek-chat`, to pass to `completion_model`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier `deepseek-reasoner`, to pass to `completion_model`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
844
845// Tests
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` as a [`CompletionResponse`] through `serde_path_to_error`, so a
    /// failure reports the JSON path of the field that did not match.
    ///
    /// # Panics
    /// Panics with the failing path and the underlying serde error if parsing fails.
    fn parse_completion_response(json: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(json);
        serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    /// Returns the text content carried by `choice`'s message.
    ///
    /// # Panics
    /// Panics if the message is not a `Message::Assistant`.
    fn assistant_content(choice: &Choice) -> &str {
        let Message::Assistant { content, .. } = &choice.message else {
            panic!("Expected assistant message");
        };
        content
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();

        assert_eq!(choices.len(), 1);
        assert_eq!(assistant_content(&choices[0]), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let response = parse_completion_response(
            r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#,
        );

        let first_choice = response.choices.first().unwrap();
        assert_eq!(assistant_content(first_choice), "Hello, world!");
    }

    /// Mirrors a full example response body, including extra top-level fields
    /// (`id`, `object`, `system_fingerprint`, `prompt_tokens_details`) and
    /// non-ASCII text in the message content.
    #[test]
    fn test_deserialize_example_response() {
        let response = parse_completion_response(
            r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#,
        );

        let first_choice = response.choices.first().unwrap();
        assert_eq!(
            assistant_content(first_choice),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    /// Checks that an assistant choice carrying a function tool call deserializes
    /// into the expected structured value. Despite the name, only the
    /// deserialization direction is exercised here.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let expected_tool_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };
        let expected = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: String::new(),
                name: None,
                tool_calls: vec![expected_tool_call],
            },
        };

        let parsed: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        assert_eq!(parsed, expected);
    }
}