rig/providers/
deepseek.rs

//! DeepSeek API client and Rig integration.
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::deepseek;
//!
//! let client = deepseek::Client::new("DEEPSEEK_API_KEY").expect("failed to build DeepSeek client");
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```
11
12use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, info_span};
19
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
// ================================================================
// Main DeepSeek Client
// ================================================================

/// Root URL that every DeepSeek API path is resolved against.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";

/// Zero-sized provider extension that plugs DeepSeek into rig's generic client.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExt;

/// Zero-sized builder marker that produces a [`DeepSeekExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExtBuilder;
46
47type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50    type Builder = DeepSeekExtBuilder;
51
52    const VERIFY_PATH: &'static str = "/user/balance";
53
54    fn build<H>(
55        _: &crate::client::ClientBuilder<
56            Self::Builder,
57            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58            H,
59        >,
60    ) -> http_client::Result<Self> {
61        Ok(Self)
62    }
63}
64
65impl<H> Capabilities<H> for DeepSeekExt {
66    type Completion = Capable<CompletionModel<H>>;
67    type Embeddings = Nothing;
68    type Transcription = Nothing;
69    #[cfg(feature = "image")]
70    type ImageGeneration = Nothing;
71    #[cfg(feature = "audio")]
72    type AudioGeneration = Nothing;
73}
74
75impl DebugExt for DeepSeekExt {}
76
77impl ProviderBuilder for DeepSeekExtBuilder {
78    type Output = DeepSeekExt;
79    type ApiKey = DeepSeekApiKey;
80
81    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
82}
83
84pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
85pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88    type Input = DeepSeekApiKey;
89
90    // If you prefer the environment variable approach:
91    fn from_env() -> Self {
92        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93        Self::new(&api_key).unwrap()
94    }
95
96    fn from_val(input: Self::Input) -> Self {
97        Self::new(input).unwrap()
98    }
99}
100
101#[derive(Debug, Deserialize)]
102struct ApiErrorResponse {
103    message: String,
104}
105
106#[derive(Debug, Deserialize)]
107#[serde(untagged)]
108enum ApiResponse<T> {
109    Ok(T),
110    Err(ApiErrorResponse),
111}
112
113impl From<ApiErrorResponse> for CompletionError {
114    fn from(err: ApiErrorResponse) -> Self {
115        CompletionError::ProviderError(err.message)
116    }
117}
118
119/// The response shape from the DeepSeek API
120#[derive(Clone, Debug, Serialize, Deserialize)]
121pub struct CompletionResponse {
122    // We'll match the JSON:
123    pub choices: Vec<Choice>,
124    pub usage: Usage,
125    // you may want other fields
126}
127
128#[derive(Clone, Debug, Serialize, Deserialize, Default)]
129pub struct Usage {
130    pub completion_tokens: u32,
131    pub prompt_tokens: u32,
132    pub prompt_cache_hit_tokens: u32,
133    pub prompt_cache_miss_tokens: u32,
134    pub total_tokens: u32,
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub completion_tokens_details: Option<CompletionTokensDetails>,
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub prompt_tokens_details: Option<PromptTokensDetails>,
139}
140
141impl Usage {
142    fn new() -> Self {
143        Self {
144            completion_tokens: 0,
145            prompt_tokens: 0,
146            prompt_cache_hit_tokens: 0,
147            prompt_cache_miss_tokens: 0,
148            total_tokens: 0,
149            completion_tokens_details: None,
150            prompt_tokens_details: None,
151        }
152    }
153}
154
155#[derive(Clone, Debug, Serialize, Deserialize, Default)]
156pub struct CompletionTokensDetails {
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub reasoning_tokens: Option<u32>,
159}
160
161#[derive(Clone, Debug, Serialize, Deserialize, Default)]
162pub struct PromptTokensDetails {
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub cached_tokens: Option<u32>,
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
168pub struct Choice {
169    pub index: usize,
170    pub message: Message,
171    pub logprobs: Option<serde_json::Value>,
172    pub finish_reason: String,
173}
174
175#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
176#[serde(tag = "role", rename_all = "lowercase")]
177pub enum Message {
178    System {
179        content: String,
180        #[serde(skip_serializing_if = "Option::is_none")]
181        name: Option<String>,
182    },
183    User {
184        content: String,
185        #[serde(skip_serializing_if = "Option::is_none")]
186        name: Option<String>,
187    },
188    Assistant {
189        content: String,
190        #[serde(skip_serializing_if = "Option::is_none")]
191        name: Option<String>,
192        #[serde(
193            default,
194            deserialize_with = "json_utils::null_or_vec",
195            skip_serializing_if = "Vec::is_empty"
196        )]
197        tool_calls: Vec<ToolCall>,
198    },
199    #[serde(rename = "tool")]
200    ToolResult {
201        tool_call_id: String,
202        content: String,
203    },
204}
205
206impl Message {
207    pub fn system(content: &str) -> Self {
208        Message::System {
209            content: content.to_owned(),
210            name: None,
211        }
212    }
213}
214
215impl From<message::ToolResult> for Message {
216    fn from(tool_result: message::ToolResult) -> Self {
217        let content = match tool_result.content.first() {
218            message::ToolResultContent::Text(text) => text.text,
219            message::ToolResultContent::Image(_) => String::from("[Image]"),
220        };
221
222        Message::ToolResult {
223            tool_call_id: tool_result.id,
224            content,
225        }
226    }
227}
228
229impl From<message::ToolCall> for ToolCall {
230    fn from(tool_call: message::ToolCall) -> Self {
231        Self {
232            id: tool_call.id,
233            // TODO: update index when we have it
234            index: 0,
235            r#type: ToolType::Function,
236            function: Function {
237                name: tool_call.function.name,
238                arguments: tool_call.function.arguments,
239            },
240        }
241    }
242}
243
244impl TryFrom<message::Message> for Vec<Message> {
245    type Error = message::MessageError;
246
247    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
248        match message {
249            message::Message::User { content } => {
250                // extract tool results
251                let mut messages = vec![];
252
253                let tool_results = content
254                    .clone()
255                    .into_iter()
256                    .filter_map(|content| match content {
257                        message::UserContent::ToolResult(tool_result) => {
258                            Some(Message::from(tool_result))
259                        }
260                        _ => None,
261                    })
262                    .collect::<Vec<_>>();
263
264                messages.extend(tool_results);
265
266                // extract text results
267                let text_messages = content
268                    .into_iter()
269                    .filter_map(|content| match content {
270                        message::UserContent::Text(text) => Some(Message::User {
271                            content: text.text,
272                            name: None,
273                        }),
274                        message::UserContent::Document(Document {
275                            data:
276                                DocumentSourceKind::Base64(content)
277                                | DocumentSourceKind::String(content),
278                            ..
279                        }) => Some(Message::User {
280                            content,
281                            name: None,
282                        }),
283                        _ => None,
284                    })
285                    .collect::<Vec<_>>();
286                messages.extend(text_messages);
287
288                Ok(messages)
289            }
290            message::Message::Assistant { content, .. } => {
291                let mut messages: Vec<Message> = vec![];
292
293                // extract text
294                let text_content = content
295                    .clone()
296                    .into_iter()
297                    .filter_map(|content| match content {
298                        message::AssistantContent::Text(text) => Some(Message::Assistant {
299                            content: text.text,
300                            name: None,
301                            tool_calls: vec![],
302                        }),
303                        _ => None,
304                    })
305                    .collect::<Vec<_>>();
306
307                messages.extend(text_content);
308
309                // extract tool calls
310                let tool_calls = content
311                    .clone()
312                    .into_iter()
313                    .filter_map(|content| match content {
314                        message::AssistantContent::ToolCall(tool_call) => {
315                            Some(ToolCall::from(tool_call))
316                        }
317                        _ => None,
318                    })
319                    .collect::<Vec<_>>();
320
321                // if we have tool calls, we add a new Assistant message with them
322                if !tool_calls.is_empty() {
323                    messages.push(Message::Assistant {
324                        content: "".to_string(),
325                        name: None,
326                        tool_calls,
327                    });
328                }
329
330                Ok(messages)
331            }
332        }
333    }
334}
335
336#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
337pub struct ToolCall {
338    pub id: String,
339    pub index: usize,
340    #[serde(default)]
341    pub r#type: ToolType,
342    pub function: Function,
343}
344
345#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
346pub struct Function {
347    pub name: String,
348    #[serde(with = "json_utils::stringified_json")]
349    pub arguments: serde_json::Value,
350}
351
352#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
353#[serde(rename_all = "lowercase")]
354pub enum ToolType {
355    #[default]
356    Function,
357}
358
359#[derive(Clone, Debug, Deserialize, Serialize)]
360pub struct ToolDefinition {
361    pub r#type: String,
362    pub function: completion::ToolDefinition,
363}
364
365impl From<crate::completion::ToolDefinition> for ToolDefinition {
366    fn from(tool: crate::completion::ToolDefinition) -> Self {
367        Self {
368            r#type: "function".into(),
369            function: tool,
370        }
371    }
372}
373
374impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
375    type Error = CompletionError;
376
377    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
378        let choice = response.choices.first().ok_or_else(|| {
379            CompletionError::ResponseError("Response contained no choices".to_owned())
380        })?;
381        let content = match &choice.message {
382            Message::Assistant {
383                content,
384                tool_calls,
385                ..
386            } => {
387                let mut content = if content.trim().is_empty() {
388                    vec![]
389                } else {
390                    vec![completion::AssistantContent::text(content)]
391                };
392
393                content.extend(
394                    tool_calls
395                        .iter()
396                        .map(|call| {
397                            completion::AssistantContent::tool_call(
398                                &call.id,
399                                &call.function.name,
400                                call.function.arguments.clone(),
401                            )
402                        })
403                        .collect::<Vec<_>>(),
404                );
405                Ok(content)
406            }
407            _ => Err(CompletionError::ResponseError(
408                "Response did not contain a valid message or tool call".into(),
409            )),
410        }?;
411
412        let choice = OneOrMany::many(content).map_err(|_| {
413            CompletionError::ResponseError(
414                "Response contained no message or tool call (empty)".to_owned(),
415            )
416        })?;
417
418        let usage = completion::Usage {
419            input_tokens: response.usage.prompt_tokens as u64,
420            output_tokens: response.usage.completion_tokens as u64,
421            total_tokens: response.usage.total_tokens as u64,
422        };
423
424        Ok(completion::CompletionResponse {
425            choice,
426            usage,
427            raw_response: response,
428        })
429    }
430}
431
432#[derive(Debug, Serialize, Deserialize)]
433pub(super) struct DeepseekCompletionRequest {
434    model: String,
435    pub messages: Vec<Message>,
436    #[serde(flatten, skip_serializing_if = "Option::is_none")]
437    temperature: Option<f64>,
438    #[serde(skip_serializing_if = "Vec::is_empty")]
439    tools: Vec<ToolDefinition>,
440    #[serde(flatten, skip_serializing_if = "Option::is_none")]
441    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
442    #[serde(flatten, skip_serializing_if = "Option::is_none")]
443    pub additional_params: Option<serde_json::Value>,
444}
445
446impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
447    type Error = CompletionError;
448
449    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
450        let mut full_history: Vec<Message> = match &req.preamble {
451            Some(preamble) => vec![Message::system(preamble)],
452            None => vec![],
453        };
454
455        if let Some(docs) = req.normalized_documents() {
456            let docs: Vec<Message> = docs.try_into()?;
457            full_history.extend(docs);
458        }
459
460        let chat_history: Vec<Message> = req
461            .chat_history
462            .clone()
463            .into_iter()
464            .map(|message| message.try_into())
465            .collect::<Result<Vec<Vec<Message>>, _>>()?
466            .into_iter()
467            .flatten()
468            .collect();
469
470        full_history.extend(chat_history);
471
472        let tool_choice = req
473            .tool_choice
474            .clone()
475            .map(crate::providers::openrouter::ToolChoice::try_from)
476            .transpose()?;
477
478        Ok(Self {
479            model: model.to_string(),
480            messages: full_history,
481            temperature: req.temperature,
482            tools: req
483                .tools
484                .clone()
485                .into_iter()
486                .map(ToolDefinition::from)
487                .collect::<Vec<_>>(),
488            tool_choice,
489            additional_params: req.additional_params,
490        })
491    }
492}
493
494/// The struct implementing the `CompletionModel` trait
495#[derive(Clone)]
496pub struct CompletionModel<T = reqwest::Client> {
497    pub client: Client<T>,
498    pub model: String,
499}
500
501impl<T> completion::CompletionModel for CompletionModel<T>
502where
503    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
504{
505    type Response = CompletionResponse;
506    type StreamingResponse = StreamingCompletionResponse;
507
508    type Client = Client<T>;
509
510    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
511        Self {
512            client: client.clone(),
513            model: model.into().to_string(),
514        }
515    }
516
517    #[cfg_attr(feature = "worker", worker::send)]
518    async fn completion(
519        &self,
520        completion_request: CompletionRequest,
521    ) -> Result<
522        completion::CompletionResponse<CompletionResponse>,
523        crate::completion::CompletionError,
524    > {
525        let preamble = completion_request.preamble.clone();
526        let request =
527            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
528
529        let span = if tracing::Span::current().is_disabled() {
530            info_span!(
531                target: "rig::completions",
532                "chat",
533                gen_ai.operation.name = "chat",
534                gen_ai.provider.name = "deepseek",
535                gen_ai.request.model = self.model,
536                gen_ai.system_instructions = preamble,
537                gen_ai.response.id = tracing::field::Empty,
538                gen_ai.response.model = tracing::field::Empty,
539                gen_ai.usage.output_tokens = tracing::field::Empty,
540                gen_ai.usage.input_tokens = tracing::field::Empty,
541                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
542                gen_ai.output.messages = tracing::field::Empty,
543            )
544        } else {
545            tracing::Span::current()
546        };
547
548        tracing::debug!("DeepSeek completion request: {request:?}");
549
550        let body = serde_json::to_vec(&request)?;
551        let req = self
552            .client
553            .post("/chat/completions")?
554            .body(body)
555            .map_err(|e| CompletionError::HttpError(e.into()))?;
556
557        async move {
558            let response = self.client.send::<_, Bytes>(req).await?;
559            let status = response.status();
560            let response_body = response.into_body().into_future().await?.to_vec();
561
562            if status.is_success() {
563                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
564                    ApiResponse::Ok(response) => {
565                        let span = tracing::Span::current();
566                        span.record(
567                            "gen_ai.output.messages",
568                            serde_json::to_string(&response.choices).unwrap(),
569                        );
570                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
571                        span.record(
572                            "gen_ai.usage.output_tokens",
573                            response.usage.completion_tokens,
574                        );
575                        tracing::trace!(
576                            target: "rig::completions",
577                            "DeepSeek completion output: {}",
578                            serde_json::to_string_pretty(&response_body)?
579                        );
580                        response.try_into()
581                    }
582                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
583                }
584            } else {
585                Err(CompletionError::ProviderError(
586                    String::from_utf8_lossy(&response_body).to_string(),
587                ))
588            }
589        }
590        .instrument(span)
591        .await
592    }
593
594    #[cfg_attr(feature = "worker", worker::send)]
595    async fn stream(
596        &self,
597        completion_request: CompletionRequest,
598    ) -> Result<
599        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
600        CompletionError,
601    > {
602        let preamble = completion_request.preamble.clone();
603        let mut request =
604            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
605
606        let params = json_utils::merge(
607            request.additional_params.unwrap_or(serde_json::json!({})),
608            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
609        );
610
611        request.additional_params = Some(params);
612
613        let body = serde_json::to_vec(&request)?;
614
615        let req = self
616            .client
617            .post("/chat/completions")?
618            .body(body)
619            .map_err(|e| CompletionError::HttpError(e.into()))?;
620
621        let span = if tracing::Span::current().is_disabled() {
622            info_span!(
623                target: "rig::completions",
624                "chat_streaming",
625                gen_ai.operation.name = "chat_streaming",
626                gen_ai.provider.name = "deepseek",
627                gen_ai.request.model = self.model,
628                gen_ai.system_instructions = preamble,
629                gen_ai.response.id = tracing::field::Empty,
630                gen_ai.response.model = tracing::field::Empty,
631                gen_ai.usage.output_tokens = tracing::field::Empty,
632                gen_ai.usage.input_tokens = tracing::field::Empty,
633                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
634                gen_ai.output.messages = tracing::field::Empty,
635            )
636        } else {
637            tracing::Span::current()
638        };
639
640        tracing::Instrument::instrument(
641            send_compatible_streaming_request(self.client.http_client().clone(), req),
642            span,
643        )
644        .await
645    }
646}
647
648#[derive(Deserialize, Debug)]
649pub struct StreamingDelta {
650    #[serde(default)]
651    content: Option<String>,
652    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
653    tool_calls: Vec<StreamingToolCall>,
654    reasoning_content: Option<String>,
655}
656
657#[derive(Deserialize, Debug)]
658struct StreamingChoice {
659    delta: StreamingDelta,
660}
661
662#[derive(Deserialize, Debug)]
663struct StreamingCompletionChunk {
664    choices: Vec<StreamingChoice>,
665    usage: Option<Usage>,
666}
667
668#[derive(Clone, Deserialize, Serialize, Debug)]
669pub struct StreamingCompletionResponse {
670    pub usage: Usage,
671}
672
673impl GetTokenUsage for StreamingCompletionResponse {
674    fn token_usage(&self) -> Option<crate::completion::Usage> {
675        let mut usage = crate::completion::Usage::new();
676        usage.input_tokens = self.usage.prompt_tokens as u64;
677        usage.output_tokens = self.usage.completion_tokens as u64;
678        usage.total_tokens = self.usage.total_tokens as u64;
679
680        Some(usage)
681    }
682}
683
684pub async fn send_compatible_streaming_request<T>(
685    http_client: T,
686    req: Request<Vec<u8>>,
687) -> Result<
688    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
689    CompletionError,
690>
691where
692    T: HttpClientExt + Clone + 'static,
693{
694    let span = tracing::Span::current();
695    let mut event_source = GenericEventSource::new(http_client, req);
696
697    let stream = stream! {
698        let mut final_usage = Usage::new();
699        let mut text_response = String::new();
700        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
701
702        while let Some(event_result) = event_source.next().await {
703            match event_result {
704                Ok(Event::Open) => {
705                    tracing::trace!("SSE connection opened");
706                    continue;
707                }
708                Ok(Event::Message(message)) => {
709                    if message.data.trim().is_empty() || message.data == "[DONE]" {
710                        continue;
711                    }
712
713                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
714                    let Ok(data) = parsed else {
715                        let err = parsed.unwrap_err();
716                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
717                        continue;
718                    };
719
720                    if let Some(choice) = data.choices.first() {
721                        let delta = &choice.delta;
722
723                        if !delta.tool_calls.is_empty() {
724                            for tool_call in &delta.tool_calls {
725                                let function = &tool_call.function;
726
727                                // Start of tool call
728                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
729                                    && empty_or_none(&function.arguments)
730                                {
731                                    let id = tool_call.id.clone().unwrap_or_default();
732                                    let name = function.name.clone().unwrap();
733                                    calls.insert(tool_call.index, (id, name, String::new()));
734                                }
735                                // Continuation of tool call
736                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
737                                    && let Some(arguments) = &function.arguments
738                                    && !arguments.is_empty()
739                                {
740                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
741                                        let combined = format!("{}{}", existing_args, arguments);
742                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
743                                    } else {
744                                        tracing::debug!("Partial tool call received but tool call was never started.");
745                                    }
746                                }
747                                // Complete tool call
748                                else {
749                                    let id = tool_call.id.clone().unwrap_or_default();
750                                    let name = function.name.clone().unwrap_or_default();
751                                    let arguments_str = function.arguments.clone().unwrap_or_default();
752
753                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
754                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
755                                        continue;
756                                    };
757
758                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
759                                        id,
760                                        name,
761                                        arguments: arguments_json,
762                                        call_id: None,
763                                    });
764                                }
765                            }
766                        }
767
768                        // DeepSeek-specific reasoning stream
769                        if let Some(content) = &delta.reasoning_content {
770                            yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
771                                reasoning: content.to_string(),
772                                id: None,
773                                signature: None,
774                            });
775                        }
776
777                        if let Some(content) = &delta.content {
778                            text_response += content;
779                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
780                        }
781                    }
782
783                    if let Some(usage) = data.usage {
784                        final_usage = usage.clone();
785                    }
786                }
787                Err(crate::http_client::Error::StreamEnded) => {
788                    break;
789                }
790                Err(err) => {
791                    tracing::error!(?err, "SSE error");
792                    yield Err(CompletionError::ResponseError(err.to_string()));
793                    break;
794                }
795            }
796        }
797
798        event_source.close();
799
800        let mut tool_calls = Vec::new();
801        // Flush accumulated tool calls
802        for (index, (id, name, arguments)) in calls {
803            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
804                continue;
805            };
806
807            tool_calls.push(ToolCall {
808                id: id.clone(),
809                index,
810                r#type: ToolType::Function,
811                function: Function {
812                    name: name.clone(),
813                    arguments: arguments_json.clone()
814                }
815            });
816            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
817                id,
818                name,
819                arguments: arguments_json,
820                call_id: None,
821            });
822        }
823
824        let message = Message::Assistant {
825            content: text_response,
826            name: None,
827            tool_calls
828        };
829
830        span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
831
832        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
833            StreamingCompletionResponse { usage: final_usage.clone() }
834        ));
835    };
836
837    Ok(crate::streaming::StreamingCompletionResponse::stream(
838        Box::pin(stream),
839    ))
840}
841
// ================================================================
// DeepSeek Completion API
// ================================================================

/// Model identifier `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";

/// Model identifier `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
847
// Unit tests covering deserialization of DeepSeek choices and completion responses.
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text content of an assistant message; any other role fails the test.
    fn assistant_text(message: &Message) -> &str {
        let Message::Assistant { content, .. } = message else {
            panic!("Expected assistant message");
        };
        content
    }

    /// Deserializes a full completion response, failing the test with the JSON path of
    /// the offending field if deserialization does not succeed.
    fn parse_response(data: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(data);
        serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        // Length was checked above, so indexing the single element is safe.
        assert_eq!(assistant_text(&choices[0].message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);
        let first_choice = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first_choice.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        let first_choice = response.choices.first().unwrap();
        assert_eq!(
            assistant_text(&first_choice.message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let parsed: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        // The stringified `arguments` in the fixture should come back as structured JSON.
        let subtract_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };
        let expected = Choice {
            finish_reason: String::from("tool_calls"),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: String::new(),
                name: None,
                tool_calls: vec![subtract_call],
            },
        };

        assert_eq!(parsed, expected);
    }
}