rig/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::deepseek;
6//!
//! let client = deepseek::Client::new("DEEPSEEK_API_KEY").expect("failed to build DeepSeek client");
8//!
9//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
10//! ```
11
12use crate::json_utils::empty_or_none;
13use async_stream::stream;
14use bytes::Bytes;
15use futures::StreamExt;
16use http::Request;
17use std::collections::HashMap;
18use tracing::{Instrument, Level, enabled, info_span};
19
20use crate::client::{
21    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
22    ProviderClient,
23};
24use crate::completion::GetTokenUsage;
25use crate::http_client::sse::{Event, GenericEventSource};
26use crate::http_client::{self, HttpClientExt};
27use crate::message::{Document, DocumentSourceKind};
28use crate::{
29    OneOrMany,
30    completion::{self, CompletionError, CompletionRequest},
31    json_utils, message,
32};
33use serde::{Deserialize, Serialize};
34
35use super::openai::StreamingToolCall;
36
37// ================================================================
38// Main DeepSeek Client
39// ================================================================
/// Base URL that the DeepSeek provider builder uses for API requests.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";

/// Zero-sized marker type that selects DeepSeek as the provider of the generic client.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExt;

/// Zero-sized marker type that selects DeepSeek as the provider of the generic client builder.
#[derive(Clone, Copy, Debug, Default)]
pub struct DeepSeekExtBuilder;
46
47type DeepSeekApiKey = BearerAuth;
48
49impl Provider for DeepSeekExt {
50    type Builder = DeepSeekExtBuilder;
51
52    const VERIFY_PATH: &'static str = "/user/balance";
53
54    fn build<H>(
55        _: &crate::client::ClientBuilder<
56            Self::Builder,
57            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
58            H,
59        >,
60    ) -> http_client::Result<Self> {
61        Ok(Self)
62    }
63}
64
65impl<H> Capabilities<H> for DeepSeekExt {
66    type Completion = Capable<CompletionModel<H>>;
67    type Embeddings = Nothing;
68    type Transcription = Nothing;
69    #[cfg(feature = "image")]
70    type ImageGeneration = Nothing;
71    #[cfg(feature = "audio")]
72    type AudioGeneration = Nothing;
73}
74
75impl DebugExt for DeepSeekExt {}
76
77impl ProviderBuilder for DeepSeekExtBuilder {
78    type Output = DeepSeekExt;
79    type ApiKey = DeepSeekApiKey;
80
81    const BASE_URL: &'static str = DEEPSEEK_API_BASE_URL;
82}
83
84pub type Client<H = reqwest::Client> = client::Client<DeepSeekExt, H>;
85pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<DeepSeekExtBuilder, String, H>;
86
87impl ProviderClient for Client {
88    type Input = DeepSeekApiKey;
89
90    // If you prefer the environment variable approach:
91    fn from_env() -> Self {
92        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
93        let mut client_builder = Self::builder();
94        client_builder.headers_mut().insert(
95            http::header::CONTENT_TYPE,
96            http::HeaderValue::from_static("application/json"),
97        );
98        let client_builder = client_builder.api_key(&api_key);
99        client_builder.build().unwrap()
100    }
101
102    fn from_val(input: Self::Input) -> Self {
103        Self::new(input).unwrap()
104    }
105}
106
107#[derive(Debug, Deserialize)]
108struct ApiErrorResponse {
109    message: String,
110}
111
112#[derive(Debug, Deserialize)]
113#[serde(untagged)]
114enum ApiResponse<T> {
115    Ok(T),
116    Err(ApiErrorResponse),
117}
118
119impl From<ApiErrorResponse> for CompletionError {
120    fn from(err: ApiErrorResponse) -> Self {
121        CompletionError::ProviderError(err.message)
122    }
123}
124
125/// The response shape from the DeepSeek API
126#[derive(Clone, Debug, Serialize, Deserialize)]
127pub struct CompletionResponse {
128    // We'll match the JSON:
129    pub choices: Vec<Choice>,
130    pub usage: Usage,
131    // you may want other fields
132}
133
134#[derive(Clone, Debug, Serialize, Deserialize, Default)]
135pub struct Usage {
136    pub completion_tokens: u32,
137    pub prompt_tokens: u32,
138    pub prompt_cache_hit_tokens: u32,
139    pub prompt_cache_miss_tokens: u32,
140    pub total_tokens: u32,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub completion_tokens_details: Option<CompletionTokensDetails>,
143    #[serde(skip_serializing_if = "Option::is_none")]
144    pub prompt_tokens_details: Option<PromptTokensDetails>,
145}
146
147impl Usage {
148    fn new() -> Self {
149        Self {
150            completion_tokens: 0,
151            prompt_tokens: 0,
152            prompt_cache_hit_tokens: 0,
153            prompt_cache_miss_tokens: 0,
154            total_tokens: 0,
155            completion_tokens_details: None,
156            prompt_tokens_details: None,
157        }
158    }
159}
160
161#[derive(Clone, Debug, Serialize, Deserialize, Default)]
162pub struct CompletionTokensDetails {
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub reasoning_tokens: Option<u32>,
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize, Default)]
168pub struct PromptTokensDetails {
169    #[serde(skip_serializing_if = "Option::is_none")]
170    pub cached_tokens: Option<u32>,
171}
172
173#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
174pub struct Choice {
175    pub index: usize,
176    pub message: Message,
177    pub logprobs: Option<serde_json::Value>,
178    pub finish_reason: String,
179}
180
181#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
182#[serde(tag = "role", rename_all = "lowercase")]
183pub enum Message {
184    System {
185        content: String,
186        #[serde(skip_serializing_if = "Option::is_none")]
187        name: Option<String>,
188    },
189    User {
190        content: String,
191        #[serde(skip_serializing_if = "Option::is_none")]
192        name: Option<String>,
193    },
194    Assistant {
195        content: String,
196        #[serde(skip_serializing_if = "Option::is_none")]
197        name: Option<String>,
198        #[serde(
199            default,
200            deserialize_with = "json_utils::null_or_vec",
201            skip_serializing_if = "Vec::is_empty"
202        )]
203        tool_calls: Vec<ToolCall>,
204    },
205    #[serde(rename = "tool")]
206    ToolResult {
207        tool_call_id: String,
208        content: String,
209    },
210}
211
212impl Message {
213    pub fn system(content: &str) -> Self {
214        Message::System {
215            content: content.to_owned(),
216            name: None,
217        }
218    }
219}
220
221impl From<message::ToolResult> for Message {
222    fn from(tool_result: message::ToolResult) -> Self {
223        let content = match tool_result.content.first() {
224            message::ToolResultContent::Text(text) => text.text,
225            message::ToolResultContent::Image(_) => String::from("[Image]"),
226        };
227
228        Message::ToolResult {
229            tool_call_id: tool_result.id,
230            content,
231        }
232    }
233}
234
235impl From<message::ToolCall> for ToolCall {
236    fn from(tool_call: message::ToolCall) -> Self {
237        Self {
238            id: tool_call.id,
239            // TODO: update index when we have it
240            index: 0,
241            r#type: ToolType::Function,
242            function: Function {
243                name: tool_call.function.name,
244                arguments: tool_call.function.arguments,
245            },
246        }
247    }
248}
249
250impl TryFrom<message::Message> for Vec<Message> {
251    type Error = message::MessageError;
252
253    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
254        match message {
255            message::Message::User { content } => {
256                // extract tool results
257                let mut messages = vec![];
258
259                let tool_results = content
260                    .clone()
261                    .into_iter()
262                    .filter_map(|content| match content {
263                        message::UserContent::ToolResult(tool_result) => {
264                            Some(Message::from(tool_result))
265                        }
266                        _ => None,
267                    })
268                    .collect::<Vec<_>>();
269
270                messages.extend(tool_results);
271
272                // extract text results
273                let text_messages = content
274                    .into_iter()
275                    .filter_map(|content| match content {
276                        message::UserContent::Text(text) => Some(Message::User {
277                            content: text.text,
278                            name: None,
279                        }),
280                        message::UserContent::Document(Document {
281                            data:
282                                DocumentSourceKind::Base64(content)
283                                | DocumentSourceKind::String(content),
284                            ..
285                        }) => Some(Message::User {
286                            content,
287                            name: None,
288                        }),
289                        _ => None,
290                    })
291                    .collect::<Vec<_>>();
292                messages.extend(text_messages);
293
294                Ok(messages)
295            }
296            message::Message::Assistant { content, .. } => {
297                let mut messages: Vec<Message> = vec![];
298
299                // extract text
300                let text_content = content
301                    .clone()
302                    .into_iter()
303                    .filter_map(|content| match content {
304                        message::AssistantContent::Text(text) => Some(Message::Assistant {
305                            content: text.text,
306                            name: None,
307                            tool_calls: vec![],
308                        }),
309                        _ => None,
310                    })
311                    .collect::<Vec<_>>();
312
313                messages.extend(text_content);
314
315                // extract tool calls
316                let tool_calls = content
317                    .clone()
318                    .into_iter()
319                    .filter_map(|content| match content {
320                        message::AssistantContent::ToolCall(tool_call) => {
321                            Some(ToolCall::from(tool_call))
322                        }
323                        _ => None,
324                    })
325                    .collect::<Vec<_>>();
326
327                // if we have tool calls, we add a new Assistant message with them
328                if !tool_calls.is_empty() {
329                    messages.push(Message::Assistant {
330                        content: "".to_string(),
331                        name: None,
332                        tool_calls,
333                    });
334                }
335
336                Ok(messages)
337            }
338        }
339    }
340}
341
342#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
343pub struct ToolCall {
344    pub id: String,
345    pub index: usize,
346    #[serde(default)]
347    pub r#type: ToolType,
348    pub function: Function,
349}
350
351#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
352pub struct Function {
353    pub name: String,
354    #[serde(with = "json_utils::stringified_json")]
355    pub arguments: serde_json::Value,
356}
357
358#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
359#[serde(rename_all = "lowercase")]
360pub enum ToolType {
361    #[default]
362    Function,
363}
364
365#[derive(Clone, Debug, Deserialize, Serialize)]
366pub struct ToolDefinition {
367    pub r#type: String,
368    pub function: completion::ToolDefinition,
369}
370
371impl From<crate::completion::ToolDefinition> for ToolDefinition {
372    fn from(tool: crate::completion::ToolDefinition) -> Self {
373        Self {
374            r#type: "function".into(),
375            function: tool,
376        }
377    }
378}
379
380impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
381    type Error = CompletionError;
382
383    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
384        let choice = response.choices.first().ok_or_else(|| {
385            CompletionError::ResponseError("Response contained no choices".to_owned())
386        })?;
387        let content = match &choice.message {
388            Message::Assistant {
389                content,
390                tool_calls,
391                ..
392            } => {
393                let mut content = if content.trim().is_empty() {
394                    vec![]
395                } else {
396                    vec![completion::AssistantContent::text(content)]
397                };
398
399                content.extend(
400                    tool_calls
401                        .iter()
402                        .map(|call| {
403                            completion::AssistantContent::tool_call(
404                                &call.id,
405                                &call.function.name,
406                                call.function.arguments.clone(),
407                            )
408                        })
409                        .collect::<Vec<_>>(),
410                );
411                Ok(content)
412            }
413            _ => Err(CompletionError::ResponseError(
414                "Response did not contain a valid message or tool call".into(),
415            )),
416        }?;
417
418        let choice = OneOrMany::many(content).map_err(|_| {
419            CompletionError::ResponseError(
420                "Response contained no message or tool call (empty)".to_owned(),
421            )
422        })?;
423
424        let usage = completion::Usage {
425            input_tokens: response.usage.prompt_tokens as u64,
426            output_tokens: response.usage.completion_tokens as u64,
427            total_tokens: response.usage.total_tokens as u64,
428        };
429
430        Ok(completion::CompletionResponse {
431            choice,
432            usage,
433            raw_response: response,
434        })
435    }
436}
437
438#[derive(Debug, Serialize, Deserialize)]
439pub(super) struct DeepseekCompletionRequest {
440    model: String,
441    pub messages: Vec<Message>,
442    #[serde(skip_serializing_if = "Option::is_none")]
443    temperature: Option<f64>,
444    #[serde(skip_serializing_if = "Vec::is_empty")]
445    tools: Vec<ToolDefinition>,
446    #[serde(skip_serializing_if = "Option::is_none")]
447    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
448    #[serde(flatten, skip_serializing_if = "Option::is_none")]
449    pub additional_params: Option<serde_json::Value>,
450}
451
452impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
453    type Error = CompletionError;
454
455    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
456        let mut full_history: Vec<Message> = match &req.preamble {
457            Some(preamble) => vec![Message::system(preamble)],
458            None => vec![],
459        };
460
461        if let Some(docs) = req.normalized_documents() {
462            let docs: Vec<Message> = docs.try_into()?;
463            full_history.extend(docs);
464        }
465
466        let chat_history: Vec<Message> = req
467            .chat_history
468            .clone()
469            .into_iter()
470            .map(|message| message.try_into())
471            .collect::<Result<Vec<Vec<Message>>, _>>()?
472            .into_iter()
473            .flatten()
474            .collect();
475
476        full_history.extend(chat_history);
477
478        let tool_choice = req
479            .tool_choice
480            .clone()
481            .map(crate::providers::openrouter::ToolChoice::try_from)
482            .transpose()?;
483
484        Ok(Self {
485            model: model.to_string(),
486            messages: full_history,
487            temperature: req.temperature,
488            tools: req
489                .tools
490                .clone()
491                .into_iter()
492                .map(ToolDefinition::from)
493                .collect::<Vec<_>>(),
494            tool_choice,
495            additional_params: req.additional_params,
496        })
497    }
498}
499
500/// The struct implementing the `CompletionModel` trait
501#[derive(Clone)]
502pub struct CompletionModel<T = reqwest::Client> {
503    pub client: Client<T>,
504    pub model: String,
505}
506
507impl<T> completion::CompletionModel for CompletionModel<T>
508where
509    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
510{
511    type Response = CompletionResponse;
512    type StreamingResponse = StreamingCompletionResponse;
513
514    type Client = Client<T>;
515
516    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
517        Self {
518            client: client.clone(),
519            model: model.into().to_string(),
520        }
521    }
522
523    async fn completion(
524        &self,
525        completion_request: CompletionRequest,
526    ) -> Result<
527        completion::CompletionResponse<CompletionResponse>,
528        crate::completion::CompletionError,
529    > {
530        let span = if tracing::Span::current().is_disabled() {
531            info_span!(
532                target: "rig::completions",
533                "chat",
534                gen_ai.operation.name = "chat",
535                gen_ai.provider.name = "deepseek",
536                gen_ai.request.model = self.model,
537                gen_ai.system_instructions = tracing::field::Empty,
538                gen_ai.response.id = tracing::field::Empty,
539                gen_ai.response.model = tracing::field::Empty,
540                gen_ai.usage.output_tokens = tracing::field::Empty,
541                gen_ai.usage.input_tokens = tracing::field::Empty,
542            )
543        } else {
544            tracing::Span::current()
545        };
546
547        span.record("gen_ai.system_instructions", &completion_request.preamble);
548
549        let request =
550            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
551
552        if enabled!(Level::TRACE) {
553            tracing::trace!(target: "rig::completions",
554                "DeepSeek completion request: {}",
555                serde_json::to_string_pretty(&request)?
556            );
557        }
558
559        let body = serde_json::to_vec(&request)?;
560        let req = self
561            .client
562            .post("/chat/completions")?
563            .body(body)
564            .map_err(|e| CompletionError::HttpError(e.into()))?;
565
566        async move {
567            let response = self.client.send::<_, Bytes>(req).await?;
568            let status = response.status();
569            let response_body = response.into_body().into_future().await?.to_vec();
570
571            if status.is_success() {
572                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
573                    ApiResponse::Ok(response) => {
574                        let span = tracing::Span::current();
575                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
576                        span.record(
577                            "gen_ai.usage.output_tokens",
578                            response.usage.completion_tokens,
579                        );
580                        if enabled!(Level::TRACE) {
581                            tracing::trace!(target: "rig::completions",
582                                "DeepSeek completion response: {}",
583                                serde_json::to_string_pretty(&response)?
584                            );
585                        }
586                        response.try_into()
587                    }
588                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
589                }
590            } else {
591                Err(CompletionError::ProviderError(
592                    String::from_utf8_lossy(&response_body).to_string(),
593                ))
594            }
595        }
596        .instrument(span)
597        .await
598    }
599
600    async fn stream(
601        &self,
602        completion_request: CompletionRequest,
603    ) -> Result<
604        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
605        CompletionError,
606    > {
607        let preamble = completion_request.preamble.clone();
608        let mut request =
609            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
610
611        let params = json_utils::merge(
612            request.additional_params.unwrap_or(serde_json::json!({})),
613            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
614        );
615
616        request.additional_params = Some(params);
617
618        if enabled!(Level::TRACE) {
619            tracing::trace!(target: "rig::completions",
620                "DeepSeek streaming completion request: {}",
621                serde_json::to_string_pretty(&request)?
622            );
623        }
624
625        let body = serde_json::to_vec(&request)?;
626
627        let req = self
628            .client
629            .post("/chat/completions")?
630            .body(body)
631            .map_err(|e| CompletionError::HttpError(e.into()))?;
632
633        let span = if tracing::Span::current().is_disabled() {
634            info_span!(
635                target: "rig::completions",
636                "chat_streaming",
637                gen_ai.operation.name = "chat_streaming",
638                gen_ai.provider.name = "deepseek",
639                gen_ai.request.model = self.model,
640                gen_ai.system_instructions = preamble,
641                gen_ai.response.id = tracing::field::Empty,
642                gen_ai.response.model = tracing::field::Empty,
643                gen_ai.usage.output_tokens = tracing::field::Empty,
644                gen_ai.usage.input_tokens = tracing::field::Empty,
645            )
646        } else {
647            tracing::Span::current()
648        };
649
650        tracing::Instrument::instrument(
651            send_compatible_streaming_request(self.client.clone(), req),
652            span,
653        )
654        .await
655    }
656}
657
658#[derive(Deserialize, Debug)]
659pub struct StreamingDelta {
660    #[serde(default)]
661    content: Option<String>,
662    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
663    tool_calls: Vec<StreamingToolCall>,
664    reasoning_content: Option<String>,
665}
666
667#[derive(Deserialize, Debug)]
668struct StreamingChoice {
669    delta: StreamingDelta,
670}
671
672#[derive(Deserialize, Debug)]
673struct StreamingCompletionChunk {
674    choices: Vec<StreamingChoice>,
675    usage: Option<Usage>,
676}
677
678#[derive(Clone, Deserialize, Serialize, Debug)]
679pub struct StreamingCompletionResponse {
680    pub usage: Usage,
681}
682
683impl GetTokenUsage for StreamingCompletionResponse {
684    fn token_usage(&self) -> Option<crate::completion::Usage> {
685        let mut usage = crate::completion::Usage::new();
686        usage.input_tokens = self.usage.prompt_tokens as u64;
687        usage.output_tokens = self.usage.completion_tokens as u64;
688        usage.total_tokens = self.usage.total_tokens as u64;
689
690        Some(usage)
691    }
692}
693
694pub async fn send_compatible_streaming_request<T>(
695    http_client: T,
696    req: Request<Vec<u8>>,
697) -> Result<
698    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
699    CompletionError,
700>
701where
702    T: HttpClientExt + Clone + 'static,
703{
704    let span = tracing::Span::current();
705    let mut event_source = GenericEventSource::new(http_client, req);
706
707    let stream = stream! {
708        let mut final_usage = Usage::new();
709        let mut text_response = String::new();
710        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
711
712        while let Some(event_result) = event_source.next().await {
713            match event_result {
714                Ok(Event::Open) => {
715                    tracing::trace!("SSE connection opened");
716                    continue;
717                }
718                Ok(Event::Message(message)) => {
719                    if message.data.trim().is_empty() || message.data == "[DONE]" {
720                        continue;
721                    }
722
723                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
724                    let Ok(data) = parsed else {
725                        let err = parsed.unwrap_err();
726                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
727                        continue;
728                    };
729
730                    if let Some(choice) = data.choices.first() {
731                        let delta = &choice.delta;
732
733                        if !delta.tool_calls.is_empty() {
734                            for tool_call in &delta.tool_calls {
735                                let function = &tool_call.function;
736
737                                // Start of tool call
738                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
739                                    && empty_or_none(&function.arguments)
740                                {
741                                    let id = tool_call.id.clone().unwrap_or_default();
742                                    let name = function.name.clone().unwrap();
743                                    calls.insert(tool_call.index, (id, name, String::new()));
744                                }
745                                // Continuation of tool call
746                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
747                                    && let Some(arguments) = &function.arguments
748                                    && !arguments.is_empty()
749                                {
750                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
751                                        let combined = format!("{}{}", existing_args, arguments);
752                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
753                                    } else {
754                                        tracing::debug!("Partial tool call received but tool call was never started.");
755                                    }
756                                }
757                                // Complete tool call
758                                else {
759                                    let id = tool_call.id.clone().unwrap_or_default();
760                                    let name = function.name.clone().unwrap_or_default();
761                                    let arguments_str = function.arguments.clone().unwrap_or_default();
762
763                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
764                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
765                                        continue;
766                                    };
767
768                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
769                                        crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
770                                    ));
771                                }
772                            }
773                        }
774
775                        // DeepSeek-specific reasoning stream
776                        if let Some(content) = &delta.reasoning_content {
777                            yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
778                                id: None,
779                                reasoning: content.to_string()
780                            });
781                        }
782
783                        if let Some(content) = &delta.content {
784                            text_response += content;
785                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
786                        }
787                    }
788
789                    if let Some(usage) = data.usage {
790                        final_usage = usage.clone();
791                    }
792                }
793                Err(crate::http_client::Error::StreamEnded) => {
794                    break;
795                }
796                Err(err) => {
797                    tracing::error!(?err, "SSE error");
798                    yield Err(CompletionError::ResponseError(err.to_string()));
799                    break;
800                }
801            }
802        }
803
804        event_source.close();
805
806        let mut tool_calls = Vec::new();
807        // Flush accumulated tool calls
808        for (index, (id, name, arguments)) in calls {
809            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
810                continue;
811            };
812
813            tool_calls.push(ToolCall {
814                id: id.clone(),
815                index,
816                r#type: ToolType::Function,
817                function: Function {
818                    name: name.clone(),
819                    arguments: arguments_json.clone()
820                }
821            });
822            yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
823                crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
824            ));
825        }
826
827        let message = Message::Assistant {
828            content: text_response,
829            name: None,
830            tool_calls
831        };
832
833        span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
834
835        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
836            StreamingCompletionResponse { usage: final_usage.clone() }
837        ));
838    };
839
840    Ok(crate::streaming::StreamingCompletionResponse::stream(
841        Box::pin(stream),
842    ))
843}
844
845// ================================================================
846// DeepSeek Completion API
847// ================================================================
/// Model name `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
850
851// Tests
#[cfg(test)]
mod tests {

    use super::*;

    /// Returns the text content of an assistant message.
    ///
    /// # Panics
    /// Panics with "Expected assistant message" when `message` is any other variant.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content.as_str(),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Deserializes `json` into a [`CompletionResponse`] through `serde_path_to_error`.
    ///
    /// # Panics
    /// Panics on failure, reporting the JSON path at which deserialization stopped.
    fn parse_completion_response(json: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(json);
        let parsed: Result<CompletionResponse, _> =
            serde_path_to_error::deserialize(&mut deserializer);
        parsed.unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    /// A bare JSON array with one choice deserializes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_eq!(
            assistant_text(&choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    /// A minimal response carrying only `choices` and `usage` deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_completion_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    /// A full example payload (extra top-level fields, nested usage details,
    /// non-ASCII text and an escaped newline) deserializes with its content intact.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_completion_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    /// A choice whose assistant message carries a tool call deserializes into
    /// the expected `Choice`, with the stringified arguments parsed as JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        // Built bottom-up so each layer of the expected value reads on its own.
        let expected_tool_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };
        let expected_choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_tool_call],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}