rig/providers/
deepseek.rs

//! DeepSeek API client and Rig integration
//!
//! # Example
//! ```no_run
//! use rig::client::CompletionClient;
//! use rig::providers::deepseek;
//!
//! let client = deepseek::Client::new("DEEPSEEK_API_KEY")
//!     .expect("failed to build DeepSeek client");
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```
11
12use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
37// ================================================================
38// Main DeepSeek Client
39// ================================================================
/// Public DeepSeek endpoint, used as the default base URL for every request.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";

/// Zero-sized provider marker that identifies DeepSeek to the generic client.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExt;

/// Zero-sized builder marker whose output is [`DeepSeekExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExtBuilder;
46
47type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50    type Builder = DeepSeekExtBuilder;
51
52    const VERIFY_PATH: &'static str = "/user/balance";
53
54    fn build<H>(
55        _: &crate::client::ClientBuilder<
56            Self::Builder,
57            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58            H,
59        >,
60    ) -> http_client::Result<Self> {
61        Ok(Self)
62    }
63}
64
65impl<H> Capabilities<H> for DeepSeekExt {
66    type Completion = Capable<CompletionModel<H>>;
67    type Embeddings = Nothing;
68    type Transcription = Nothing;
69    #[cfg(feature = "image")]
70    type ImageGeneration = Nothing;
71    #[cfg(feature = "audio")]
72    type AudioGeneration = Nothing;
73}
74
75impl DebugExt for DeepSeekExt {}
76
77impl ProviderBuilder for DeepSeekExtBuilder {
78    type Output = DeepSeekExt;
79    type ApiKey = DeepSeekApiKey;
80
81    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
82}
83
84pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
85pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88    type Input = DeepSeekApiKey;
89
90    // If you prefer the environment variable approach:
91    fn from_env() -> Self {
92        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93        Self::new(&api_key).unwrap()
94    }
95
96    fn from_val(input: Self::Input) -> Self {
97        Self::new(input).unwrap()
98    }
99}
100
101#[derive(Debug, Deserialize)]
102struct ApiErrorResponse {
103    message: String,
104}
105
106#[derive(Debug, Deserialize)]
107#[serde(untagged)]
108enum ApiResponse<T> {
109    Ok(T),
110    Err(ApiErrorResponse),
111}
112
113impl From<ApiErrorResponse> for CompletionError {
114    fn from(err: ApiErrorResponse) -> Self {
115        CompletionError::ProviderError(err.message)
116    }
117}
118
119/// The response shape from the DeepSeek API
120#[derive(Clone, Debug, Serialize, Deserialize)]
121pub struct CompletionResponse {
122    // We'll match the JSON:
123    pub choices: Vec<Choice>,
124    pub usage: Usage,
125    // you may want other fields
126}
127
128#[derive(Clone, Debug, Serialize, Deserialize, Default)]
129pub struct Usage {
130    pub completion_tokens: u32,
131    pub prompt_tokens: u32,
132    pub prompt_cache_hit_tokens: u32,
133    pub prompt_cache_miss_tokens: u32,
134    pub total_tokens: u32,
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub completion_tokens_details: Option<CompletionTokensDetails>,
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub prompt_tokens_details: Option<PromptTokensDetails>,
139}
140
141impl Usage {
142    fn new() -> Self {
143        Self {
144            completion_tokens: 0,
145            prompt_tokens: 0,
146            prompt_cache_hit_tokens: 0,
147            prompt_cache_miss_tokens: 0,
148            total_tokens: 0,
149            completion_tokens_details: None,
150            prompt_tokens_details: None,
151        }
152    }
153}
154
155#[derive(Clone, Debug, Serialize, Deserialize, Default)]
156pub struct CompletionTokensDetails {
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub reasoning_tokens: Option<u32>,
159}
160
161#[derive(Clone, Debug, Serialize, Deserialize, Default)]
162pub struct PromptTokensDetails {
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub cached_tokens: Option<u32>,
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
168pub struct Choice {
169    pub index: usize,
170    pub message: Message,
171    pub logprobs: Option<serde_json::Value>,
172    pub finish_reason: String,
173}
174
175#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
176#[serde(tag = "role", rename_all = "lowercase")]
177pub enum Message {
178    System {
179        content: String,
180        #[serde(skip_serializing_if = "Option::is_none")]
181        name: Option<String>,
182    },
183    User {
184        content: String,
185        #[serde(skip_serializing_if = "Option::is_none")]
186        name: Option<String>,
187    },
188    Assistant {
189        content: String,
190        #[serde(skip_serializing_if = "Option::is_none")]
191        name: Option<String>,
192        #[serde(
193            default,
194            deserialize_with = "json_utils::null_or_vec",
195            skip_serializing_if = "Vec::is_empty"
196        )]
197        tool_calls: Vec<ToolCall>,
198    },
199    #[serde(rename = "tool")]
200    ToolResult {
201        tool_call_id: String,
202        content: String,
203    },
204}
205
206impl Message {
207    pub fn system(content: &str) -> Self {
208        Message::System {
209            content: content.to_owned(),
210            name: None,
211        }
212    }
213}
214
215impl From<message::ToolResult> for Message {
216    fn from(tool_result: message::ToolResult) -> Self {
217        let content = match tool_result.content.first() {
218            message::ToolResultContent::Text(text) => text.text,
219            message::ToolResultContent::Image(_) => String::from("[Image]"),
220        };
221
222        Message::ToolResult {
223            tool_call_id: tool_result.id,
224            content,
225        }
226    }
227}
228
229impl From<message::ToolCall> for ToolCall {
230    fn from(tool_call: message::ToolCall) -> Self {
231        Self {
232            id: tool_call.id,
233            // TODO: update index when we have it
234            index: 0,
235            r#type: ToolType::Function,
236            function: Function {
237                name: tool_call.function.name,
238                arguments: tool_call.function.arguments,
239            },
240        }
241    }
242}
243
244impl TryFrom<message::Message> for Vec<Message> {
245    type Error = message::MessageError;
246
247    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
248        match message {
249            message::Message::User { content } => {
250                // extract tool results
251                let mut messages = vec![];
252
253                let tool_results = content
254                    .clone()
255                    .into_iter()
256                    .filter_map(|content| match content {
257                        message::UserContent::ToolResult(tool_result) => {
258                            Some(Message::from(tool_result))
259                        }
260                        _ => None,
261                    })
262                    .collect::<Vec<_>>();
263
264                messages.extend(tool_results);
265
266                // extract text results
267                let text_messages = content
268                    .into_iter()
269                    .filter_map(|content| match content {
270                        message::UserContent::Text(text) => Some(Message::User {
271                            content: text.text,
272                            name: None,
273                        }),
274                        message::UserContent::Document(Document {
275                            data:
276                                DocumentSourceKind::Base64(content)
277                                | DocumentSourceKind::String(content),
278                            ..
279                        }) => Some(Message::User {
280                            content,
281                            name: None,
282                        }),
283                        _ => None,
284                    })
285                    .collect::<Vec<_>>();
286                messages.extend(text_messages);
287
288                Ok(messages)
289            }
290            message::Message::Assistant { content, .. } => {
291                let mut messages: Vec<Message> = vec![];
292
293                // extract text
294                let text_content = content
295                    .clone()
296                    .into_iter()
297                    .filter_map(|content| match content {
298                        message::AssistantContent::Text(text) => Some(Message::Assistant {
299                            content: text.text,
300                            name: None,
301                            tool_calls: vec![],
302                        }),
303                        _ => None,
304                    })
305                    .collect::<Vec<_>>();
306
307                messages.extend(text_content);
308
309                // extract tool calls
310                let tool_calls = content
311                    .clone()
312                    .into_iter()
313                    .filter_map(|content| match content {
314                        message::AssistantContent::ToolCall(tool_call) => {
315                            Some(ToolCall::from(tool_call))
316                        }
317                        _ => None,
318                    })
319                    .collect::<Vec<_>>();
320
321                // if we have tool calls, we add a new Assistant message with them
322                if !tool_calls.is_empty() {
323                    messages.push(Message::Assistant {
324                        content: "".to_string(),
325                        name: None,
326                        tool_calls,
327                    });
328                }
329
330                Ok(messages)
331            }
332        }
333    }
334}
335
336#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
337pub struct ToolCall {
338    pub id: String,
339    pub index: usize,
340    #[serde(default)]
341    pub r#type: ToolType,
342    pub function: Function,
343}
344
345#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
346pub struct Function {
347    pub name: String,
348    #[serde(with = "json_utils::stringified_json")]
349    pub arguments: serde_json::Value,
350}
351
352#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
353#[serde(rename_all = "lowercase")]
354pub enum ToolType {
355    #[default]
356    Function,
357}
358
359#[derive(Clone, Debug, Deserialize, Serialize)]
360pub struct ToolDefinition {
361    pub r#type: String,
362    pub function: completion::ToolDefinition,
363}
364
365impl From<crate::completion::ToolDefinition> for ToolDefinition {
366    fn from(tool: crate::completion::ToolDefinition) -> Self {
367        Self {
368            r#type: "function".into(),
369            function: tool,
370        }
371    }
372}
373
374impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
375    type Error = CompletionError;
376
377    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
378        let choice = response.choices.first().ok_or_else(|| {
379            CompletionError::ResponseError("Response contained no choices".to_owned())
380        })?;
381        let content = match &choice.message {
382            Message::Assistant {
383                content,
384                tool_calls,
385                ..
386            } => {
387                let mut content = if content.trim().is_empty() {
388                    vec![]
389                } else {
390                    vec![completion::AssistantContent::text(content)]
391                };
392
393                content.extend(
394                    tool_calls
395                        .iter()
396                        .map(|call| {
397                            completion::AssistantContent::tool_call(
398                                &call.id,
399                                &call.function.name,
400                                call.function.arguments.clone(),
401                            )
402                        })
403                        .collect::<Vec<_>>(),
404                );
405                Ok(content)
406            }
407            _ => Err(CompletionError::ResponseError(
408                "Response did not contain a valid message or tool call".into(),
409            )),
410        }?;
411
412        let choice = OneOrMany::many(content).map_err(|_| {
413            CompletionError::ResponseError(
414                "Response contained no message or tool call (empty)".to_owned(),
415            )
416        })?;
417
418        let usage = completion::Usage {
419            input_tokens: response.usage.prompt_tokens as u64,
420            output_tokens: response.usage.completion_tokens as u64,
421            total_tokens: response.usage.total_tokens as u64,
422        };
423
424        Ok(completion::CompletionResponse {
425            choice,
426            usage,
427            raw_response: response,
428        })
429    }
430}
431
432#[derive(Debug, Serialize, Deserialize)]
433pub(super) struct DeepseekCompletionRequest {
434    model: String,
435    pub messages: Vec<Message>,
436    #[serde(flatten, skip_serializing_if = "Option::is_none")]
437    temperature: Option<f64>,
438    #[serde(skip_serializing_if = "Vec::is_empty")]
439    tools: Vec<ToolDefinition>,
440    #[serde(flatten, skip_serializing_if = "Option::is_none")]
441    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
442    #[serde(flatten, skip_serializing_if = "Option::is_none")]
443    pub additional_params: Option<serde_json::Value>,
444}
445
446impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
447    type Error = CompletionError;
448
449    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
450        let mut full_history: Vec<Message> = match &req.preamble {
451            Some(preamble) => vec![Message::system(preamble)],
452            None => vec![],
453        };
454
455        if let Some(docs) = req.normalized_documents() {
456            let docs: Vec<Message> = docs.try_into()?;
457            full_history.extend(docs);
458        }
459
460        let chat_history: Vec<Message> = req
461            .chat_history
462            .clone()
463            .into_iter()
464            .map(|message| message.try_into())
465            .collect::<Result<Vec<Vec<Message>>, _>>()?
466            .into_iter()
467            .flatten()
468            .collect();
469
470        full_history.extend(chat_history);
471
472        let tool_choice = req
473            .tool_choice
474            .clone()
475            .map(crate::providers::openrouter::ToolChoice::try_from)
476            .transpose()?;
477
478        Ok(Self {
479            model: model.to_string(),
480            messages: full_history,
481            temperature: req.temperature,
482            tools: req
483                .tools
484                .clone()
485                .into_iter()
486                .map(ToolDefinition::from)
487                .collect::<Vec<_>>(),
488            tool_choice,
489            additional_params: req.additional_params,
490        })
491    }
492}
493
494/// The struct implementing the `CompletionModel` trait
495#[derive(Clone)]
496pub struct CompletionModel<T = reqwest::Client> {
497    pub client: Client<T>,
498    pub model: String,
499}
500
501impl<T> completion::CompletionModel for CompletionModel<T>
502where
503    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
504{
505    type Response = CompletionResponse;
506    type StreamingResponse = StreamingCompletionResponse;
507
508    type Client = Client<T>;
509
510    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
511        Self {
512            client: client.clone(),
513            model: model.into().to_string(),
514        }
515    }
516
517    #[cfg_attr(feature = "worker", worker::send)]
518    async fn completion(
519        &self,
520        completion_request: CompletionRequest,
521    ) -> Result<
522        completion::CompletionResponse<CompletionResponse>,
523        crate::completion::CompletionError,
524    > {
525        let span = if tracing::Span::current().is_disabled() {
526            info_span!(
527                target: "rig::completions",
528                "chat",
529                gen_ai.operation.name = "chat",
530                gen_ai.provider.name = "deepseek",
531                gen_ai.request.model = self.model,
532                gen_ai.system_instructions = tracing::field::Empty,
533                gen_ai.response.id = tracing::field::Empty,
534                gen_ai.response.model = tracing::field::Empty,
535                gen_ai.usage.output_tokens = tracing::field::Empty,
536                gen_ai.usage.input_tokens = tracing::field::Empty,
537            )
538        } else {
539            tracing::Span::current()
540        };
541
542        span.record("gen_ai.system_instructions", &completion_request.preamble);
543
544        let request =
545            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
546
547        if enabled!(Level::TRACE) {
548            tracing::trace!(target: "rig::completions",
549                "DeepSeek completion request: {}",
550                serde_json::to_string_pretty(&request)?
551            );
552        }
553
554        let body = serde_json::to_vec(&request)?;
555        let req = self
556            .client
557            .post("/chat/completions")?
558            .body(body)
559            .map_err(|e| CompletionError::HttpError(e.into()))?;
560
561        async move {
562            let response = self.client.send::<_, Bytes>(req).await?;
563            let status = response.status();
564            let response_body = response.into_body().into_future().await?.to_vec();
565
566            if status.is_success() {
567                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
568                    ApiResponse::Ok(response) => {
569                        let span = tracing::Span::current();
570                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
571                        span.record(
572                            "gen_ai.usage.output_tokens",
573                            response.usage.completion_tokens,
574                        );
575                        if enabled!(Level::TRACE) {
576                            tracing::trace!(target: "rig::completions",
577                                "DeepSeek completion response: {}",
578                                serde_json::to_string_pretty(&response)?
579                            );
580                        }
581                        response.try_into()
582                    }
583                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
584                }
585            } else {
586                Err(CompletionError::ProviderError(
587                    String::from_utf8_lossy(&response_body).to_string(),
588                ))
589            }
590        }
591        .instrument(span)
592        .await
593    }
594
595    #[cfg_attr(feature = "worker", worker::send)]
596    async fn stream(
597        &self,
598        completion_request: CompletionRequest,
599    ) -> Result<
600        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
601        CompletionError,
602    > {
603        let preamble = completion_request.preamble.clone();
604        let mut request =
605            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
606
607        let params = json_utils::merge(
608            request.additional_params.unwrap_or(serde_json::json!({})),
609            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
610        );
611
612        request.additional_params = Some(params);
613
614        if enabled!(Level::TRACE) {
615            tracing::trace!(target: "rig::completions",
616                "DeepSeek streaming completion request: {}",
617                serde_json::to_string_pretty(&request)?
618            );
619        }
620
621        let body = serde_json::to_vec(&request)?;
622
623        let req = self
624            .client
625            .post("/chat/completions")?
626            .body(body)
627            .map_err(|e| CompletionError::HttpError(e.into()))?;
628
629        let span = if tracing::Span::current().is_disabled() {
630            info_span!(
631                target: "rig::completions",
632                "chat_streaming",
633                gen_ai.operation.name = "chat_streaming",
634                gen_ai.provider.name = "deepseek",
635                gen_ai.request.model = self.model,
636                gen_ai.system_instructions = preamble,
637                gen_ai.response.id = tracing::field::Empty,
638                gen_ai.response.model = tracing::field::Empty,
639                gen_ai.usage.output_tokens = tracing::field::Empty,
640                gen_ai.usage.input_tokens = tracing::field::Empty,
641            )
642        } else {
643            tracing::Span::current()
644        };
645
646        tracing::Instrument::instrument(
647            send_compatible_streaming_request(self.client.clone(), req),
648            span,
649        )
650        .await
651    }
652}
653
654#[derive(Deserialize, Debug)]
655pub struct StreamingDelta {
656    #[serde(default)]
657    content: Option<String>,
658    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
659    tool_calls: Vec<StreamingToolCall>,
660    reasoning_content: Option<String>,
661}
662
663#[derive(Deserialize, Debug)]
664struct StreamingChoice {
665    delta: StreamingDelta,
666}
667
668#[derive(Deserialize, Debug)]
669struct StreamingCompletionChunk {
670    choices: Vec<StreamingChoice>,
671    usage: Option<Usage>,
672}
673
674#[derive(Clone, Deserialize, Serialize, Debug)]
675pub struct StreamingCompletionResponse {
676    pub usage: Usage,
677}
678
679impl GetTokenUsage for StreamingCompletionResponse {
680    fn token_usage(&self) -> Option<crate::completion::Usage> {
681        let mut usage = crate::completion::Usage::new();
682        usage.input_tokens = self.usage.prompt_tokens as u64;
683        usage.output_tokens = self.usage.completion_tokens as u64;
684        usage.total_tokens = self.usage.total_tokens as u64;
685
686        Some(usage)
687    }
688}
689
690pub async fn send_compatible_streaming_request<T>(
691    http_client: T,
692    req: Request<Vec<u8>>,
693) -> Result<
694    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
695    CompletionError,
696>
697where
698    T: HttpClientExt + Clone + 'static,
699{
700    let span = tracing::Span::current();
701    let mut event_source = GenericEventSource::new(http_client, req);
702
703    let stream = stream! {
704        let mut final_usage = Usage::new();
705        let mut text_response = String::new();
706        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
707
708        while let Some(event_result) = event_source.next().await {
709            match event_result {
710                Ok(Event::Open) => {
711                    tracing::trace!("SSE connection opened");
712                    continue;
713                }
714                Ok(Event::Message(message)) => {
715                    if message.data.trim().is_empty() || message.data == "[DONE]" {
716                        continue;
717                    }
718
719                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
720                    let Ok(data) = parsed else {
721                        let err = parsed.unwrap_err();
722                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
723                        continue;
724                    };
725
726                    if let Some(choice) = data.choices.first() {
727                        let delta = &choice.delta;
728
729                        if !delta.tool_calls.is_empty() {
730                            for tool_call in &delta.tool_calls {
731                                let function = &tool_call.function;
732
733                                // Start of tool call
734                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
735                                    && empty_or_none(&function.arguments)
736                                {
737                                    let id = tool_call.id.clone().unwrap_or_default();
738                                    let name = function.name.clone().unwrap();
739                                    calls.insert(tool_call.index, (id, name, String::new()));
740                                }
741                                // Continuation of tool call
742                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
743                                    && let Some(arguments) = &function.arguments
744                                    && !arguments.is_empty()
745                                {
746                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
747                                        let combined = format!("{}{}", existing_args, arguments);
748                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
749                                    } else {
750                                        tracing::debug!("Partial tool call received but tool call was never started.");
751                                    }
752                                }
753                                // Complete tool call
754                                else {
755                                    let id = tool_call.id.clone().unwrap_or_default();
756                                    let name = function.name.clone().unwrap_or_default();
757                                    let arguments_str = function.arguments.clone().unwrap_or_default();
758
759                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
760                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
761                                        continue;
762                                    };
763
764                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
765                                        id,
766                                        name,
767                                        arguments: arguments_json,
768                                        call_id: None,
769                                    });
770                                }
771                            }
772                        }
773
774                        // DeepSeek-specific reasoning stream
775                        if let Some(content) = &delta.reasoning_content {
776                            yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
777                                reasoning: content.to_string(),
778                                id: None,
779                                signature: None,
780                            });
781                        }
782
783                        if let Some(content) = &delta.content {
784                            text_response += content;
785                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
786                        }
787                    }
788
789                    if let Some(usage) = data.usage {
790                        final_usage = usage.clone();
791                    }
792                }
793                Err(crate::http_client::Error::StreamEnded) => {
794                    break;
795                }
796                Err(err) => {
797                    tracing::error!(?err, "SSE error");
798                    yield Err(CompletionError::ResponseError(err.to_string()));
799                    break;
800                }
801            }
802        }
803
804        event_source.close();
805
806        let mut tool_calls = Vec::new();
807        // Flush accumulated tool calls
808        for (index, (id, name, arguments)) in calls {
809            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
810                continue;
811            };
812
813            tool_calls.push(ToolCall {
814                id: id.clone(),
815                index,
816                r#type: ToolType::Function,
817                function: Function {
818                    name: name.clone(),
819                    arguments: arguments_json.clone()
820                }
821            });
822            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
823                id,
824                name,
825                arguments: arguments_json,
826                call_id: None,
827            });
828        }
829
830        let message = Message::Assistant {
831            content: text_response,
832            name: None,
833            tool_calls
834        };
835
836        span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
837
838        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
839            StreamingCompletionResponse { usage: final_usage.clone() }
840        ));
841    };
842
843    Ok(crate::streaming::StreamingCompletionResponse::stream(
844        Box::pin(stream),
845    ))
846}
847
848// ================================================================
849// DeepSeek Completion API
850// ================================================================
/// Model id `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";

/// Model id `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
853
// Unit tests for decoding DeepSeek chat-completion payloads into the wire
// types declared in this module (`Choice`, `CompletionResponse`, `Message`).
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `json` as a `CompletionResponse`, panicking with the JSON path
    /// of the offending field (via `serde_path_to_error`) if decoding fails.
    fn parse_completion_response(json: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(json);
        serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    /// A bare JSON array of choices decodes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);

        let first = choices.first().unwrap();
        let Message::Assistant { content, .. } = &first.message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content, "Hello, world!");
    }

    /// A minimal response (choices + usage) decodes into `CompletionResponse`.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_completion_response(data);
        let first = response.choices.first().unwrap();
        let Message::Assistant { content, .. } = &first.message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content, "Hello, world!");
    }

    /// A full example response, including extra top-level fields, decodes
    /// into `CompletionResponse` and keeps the assistant text intact.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_completion_response(data);
        let first = response.choices.first().unwrap();
        let Message::Assistant { content, .. } = &first.message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content,
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    /// An assistant choice carrying a tool call decodes into the expected
    /// `Choice`, with the stringified arguments parsed as JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let parsed: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let subtract_call = ToolCall {
            id: String::from("call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b"),
            function: Function {
                name: String::from("subtract"),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };

        let expected = Choice {
            finish_reason: String::from("tool_calls"),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: String::new(),
                name: None,
                tool_calls: vec![subtract_call],
            },
        };

        assert_eq!(parsed, expected);
    }
}