rig/providers/
deepseek.rs

//! DeepSeek API client and Rig integration.
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::deepseek;
//!
//! let client: deepseek::Client = deepseek::Client::new("your-deepseek-api-key");
//!
//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
//! ```
11
12use async_stream::stream;
13use bytes::Bytes;
14use futures::StreamExt;
15use reqwest_eventsource::{Event, RequestBuilderExt};
16use std::collections::HashMap;
17use tracing::{Instrument, info_span};
18
19use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
20use crate::completion::GetTokenUsage;
21use crate::http_client::{self, HttpClientExt};
22use crate::json_utils::merge;
23use crate::message::{Document, DocumentSourceKind};
24use crate::{
25    OneOrMany,
26    completion::{self, CompletionError, CompletionRequest},
27    impl_conversion_traits, json_utils, message,
28};
29use serde::{Deserialize, Serialize};
30use serde_json::json;
31
32use super::openai::StreamingToolCall;
33
34// ================================================================
35// Main DeepSeek Client
36// ================================================================
37const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
38
39pub struct ClientBuilder<'a, T = reqwest::Client> {
40    api_key: &'a str,
41    base_url: &'a str,
42    http_client: T,
43}
44
45impl<'a, T> ClientBuilder<'a, T>
46where
47    T: Default,
48{
49    pub fn new(api_key: &'a str) -> Self {
50        Self {
51            api_key,
52            base_url: DEEPSEEK_API_BASE_URL,
53            http_client: Default::default(),
54        }
55    }
56}
57
58impl<'a, T> ClientBuilder<'a, T> {
59    pub fn base_url(mut self, base_url: &'a str) -> Self {
60        self.base_url = base_url;
61        self
62    }
63
64    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
65        ClientBuilder {
66            api_key: self.api_key,
67            base_url: self.base_url,
68            http_client,
69        }
70    }
71
72    pub fn build(self) -> Client<T> {
73        Client {
74            base_url: self.base_url.to_string(),
75            api_key: self.api_key.to_string(),
76            http_client: self.http_client,
77        }
78    }
79}
80
81#[derive(Clone)]
82pub struct Client<T = reqwest::Client> {
83    pub base_url: String,
84    api_key: String,
85    http_client: T,
86}
87
88impl<T> std::fmt::Debug for Client<T>
89where
90    T: std::fmt::Debug,
91{
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("Client")
94            .field("base_url", &self.base_url)
95            .field("http_client", &self.http_client)
96            .field("api_key", &"<REDACTED>")
97            .finish()
98    }
99}
100
101impl<T> Client<T>
102where
103    T: Default,
104{
105    /// Create a new DeepSeek client builder.
106    ///
107    /// # Example
108    /// ```
109    /// use rig::providers::deepseek::{ClientBuilder, self};
110    ///
111    /// // Initialize the DeepSeek client
112    /// let deepseek = Client::builder("your-deepseek-api-key")
113    ///    .build()
114    /// ```
115    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
116        ClientBuilder::new(api_key)
117    }
118
119    /// Create a new DeepSeek client. For more control, use the `builder` method.
120    ///
121    /// # Panics
122    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
123    pub fn new(api_key: &str) -> Self {
124        Self::builder(api_key).build()
125    }
126}
127
128impl<T> Client<T>
129where
130    T: HttpClientExt,
131{
132    fn req(
133        &self,
134        method: http_client::Method,
135        path: &str,
136    ) -> http_client::Result<http_client::Builder> {
137        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
138
139        http_client::with_bearer_auth(
140            http_client::Request::builder().method(method).uri(url),
141            &self.api_key,
142        )
143    }
144
145    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
146        self.req(http_client::Method::GET, path)
147    }
148
149    async fn send<U, R>(
150        &self,
151        req: http_client::Request<U>,
152    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
153    where
154        U: Into<Bytes> + Send,
155        R: From<Bytes> + Send + 'static,
156    {
157        self.http_client.send(req).await
158    }
159}
160
161impl Client<reqwest::Client> {
162    fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
163        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
164
165        self.http_client.post(url).bearer_auth(&self.api_key)
166    }
167}
168
169impl ProviderClient for Client<reqwest::Client> {
170    // If you prefer the environment variable approach:
171    fn from_env() -> Self {
172        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
173        Self::new(&api_key)
174    }
175
176    fn from_val(input: crate::client::ProviderValue) -> Self {
177        let crate::client::ProviderValue::Simple(api_key) = input else {
178            panic!("Incorrect provider value type")
179        };
180        Self::new(&api_key)
181    }
182}
183
184impl CompletionClient for Client<reqwest::Client> {
185    type CompletionModel = CompletionModel<reqwest::Client>;
186
187    /// Creates a DeepSeek completion model with the given `model_name`.
188    fn completion_model(&self, model_name: &str) -> CompletionModel<reqwest::Client> {
189        CompletionModel {
190            client: self.clone(),
191            model: model_name.to_string(),
192        }
193    }
194}
195
196impl VerifyClient for Client<reqwest::Client> {
197    #[cfg_attr(feature = "worker", worker::send)]
198    async fn verify(&self) -> Result<(), VerifyError> {
199        let req = self
200            .get("/user/balance")?
201            .body(http_client::NoBody)
202            .map_err(http_client::Error::from)?;
203
204        let response = self.send(req).await?;
205
206        match response.status() {
207            reqwest::StatusCode::OK => Ok(()),
208            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
209            reqwest::StatusCode::INTERNAL_SERVER_ERROR
210            | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
211                let text = http_client::text(response).await?;
212                Err(VerifyError::ProviderError(text))
213            }
214            _ => {
215                // TODO: `HttpClientExt` equivalent
216                //response.error_for_status()?;
217                Ok(())
218            }
219        }
220    }
221}
222
// NOTE(review): the macro definition is not visible here. Presumably it generates the
// `AsEmbeddings` / `AsTranscription` / `AsImageGeneration` / `AsAudioGeneration`
// conversion impls for `Client<T>`. In the visible code this provider only implements
// completions, so these are likely "unsupported" stubs — confirm against the macro.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
229
230#[derive(Debug, Deserialize)]
231struct ApiErrorResponse {
232    message: String,
233}
234
235#[derive(Debug, Deserialize)]
236#[serde(untagged)]
237enum ApiResponse<T> {
238    Ok(T),
239    Err(ApiErrorResponse),
240}
241
242impl From<ApiErrorResponse> for CompletionError {
243    fn from(err: ApiErrorResponse) -> Self {
244        CompletionError::ProviderError(err.message)
245    }
246}
247
/// The response shape from the DeepSeek API for a non-streaming
/// `/chat/completions` request.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Generated choices. Only the first one is used when converting into Rig's
    /// `completion::CompletionResponse`.
    pub choices: Vec<Choice>,
    /// Token accounting for the request.
    pub usage: Usage,
    // Other top-level fields in the response body (if any) are not modelled here.
    // serde ignores them because this struct does not deny unknown fields.
}
256
257#[derive(Clone, Debug, Serialize, Deserialize, Default)]
258pub struct Usage {
259    pub completion_tokens: u32,
260    pub prompt_tokens: u32,
261    pub prompt_cache_hit_tokens: u32,
262    pub prompt_cache_miss_tokens: u32,
263    pub total_tokens: u32,
264    #[serde(skip_serializing_if = "Option::is_none")]
265    pub completion_tokens_details: Option<CompletionTokensDetails>,
266    #[serde(skip_serializing_if = "Option::is_none")]
267    pub prompt_tokens_details: Option<PromptTokensDetails>,
268}
269
270impl Usage {
271    fn new() -> Self {
272        Self {
273            completion_tokens: 0,
274            prompt_tokens: 0,
275            prompt_cache_hit_tokens: 0,
276            prompt_cache_miss_tokens: 0,
277            total_tokens: 0,
278            completion_tokens_details: None,
279            prompt_tokens_details: None,
280        }
281    }
282}
283
/// Optional breakdown of completion tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    /// Tokens spent on reasoning output (presumably only reported for reasoning
    /// models — confirm against DeepSeek docs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}

/// Optional breakdown of prompt tokens.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    /// By name, prompt tokens served from cache — confirm against DeepSeek docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
295
/// One choice in a non-streaming completion response.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Position of this choice within `choices`.
    pub index: usize,
    /// The generated message. Only the `Assistant` variant is accepted when
    /// converting to Rig's response type.
    pub message: Message,
    /// Log-probability information, kept as raw JSON.
    pub logprobs: Option<serde_json::Value>,
    /// Reason generation stopped, as reported by the API (e.g. `"stop"` or
    /// `"tool_calls"` — verify the exact set against DeepSeek docs).
    pub finish_reason: String,
}
303
304#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
305#[serde(tag = "role", rename_all = "lowercase")]
306pub enum Message {
307    System {
308        content: String,
309        #[serde(skip_serializing_if = "Option::is_none")]
310        name: Option<String>,
311    },
312    User {
313        content: String,
314        #[serde(skip_serializing_if = "Option::is_none")]
315        name: Option<String>,
316    },
317    Assistant {
318        content: String,
319        #[serde(skip_serializing_if = "Option::is_none")]
320        name: Option<String>,
321        #[serde(
322            default,
323            deserialize_with = "json_utils::null_or_vec",
324            skip_serializing_if = "Vec::is_empty"
325        )]
326        tool_calls: Vec<ToolCall>,
327    },
328    #[serde(rename = "tool")]
329    ToolResult {
330        tool_call_id: String,
331        content: String,
332    },
333}
334
335impl Message {
336    pub fn system(content: &str) -> Self {
337        Message::System {
338            content: content.to_owned(),
339            name: None,
340        }
341    }
342}
343
344impl From<message::ToolResult> for Message {
345    fn from(tool_result: message::ToolResult) -> Self {
346        let content = match tool_result.content.first() {
347            message::ToolResultContent::Text(text) => text.text,
348            message::ToolResultContent::Image(_) => String::from("[Image]"),
349        };
350
351        Message::ToolResult {
352            tool_call_id: tool_result.id,
353            content,
354        }
355    }
356}
357
358impl From<message::ToolCall> for ToolCall {
359    fn from(tool_call: message::ToolCall) -> Self {
360        Self {
361            id: tool_call.id,
362            // TODO: update index when we have it
363            index: 0,
364            r#type: ToolType::Function,
365            function: Function {
366                name: tool_call.function.name,
367                arguments: tool_call.function.arguments,
368            },
369        }
370    }
371}
372
373impl TryFrom<message::Message> for Vec<Message> {
374    type Error = message::MessageError;
375
376    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
377        match message {
378            message::Message::User { content } => {
379                // extract tool results
380                let mut messages = vec![];
381
382                let tool_results = content
383                    .clone()
384                    .into_iter()
385                    .filter_map(|content| match content {
386                        message::UserContent::ToolResult(tool_result) => {
387                            Some(Message::from(tool_result))
388                        }
389                        _ => None,
390                    })
391                    .collect::<Vec<_>>();
392
393                messages.extend(tool_results);
394
395                // extract text results
396                let text_messages = content
397                    .into_iter()
398                    .filter_map(|content| match content {
399                        message::UserContent::Text(text) => Some(Message::User {
400                            content: text.text,
401                            name: None,
402                        }),
403                        message::UserContent::Document(Document {
404                            data:
405                                DocumentSourceKind::Base64(content)
406                                | DocumentSourceKind::String(content),
407                            ..
408                        }) => Some(Message::User {
409                            content,
410                            name: None,
411                        }),
412                        _ => None,
413                    })
414                    .collect::<Vec<_>>();
415                messages.extend(text_messages);
416
417                Ok(messages)
418            }
419            message::Message::Assistant { content, .. } => {
420                let mut messages: Vec<Message> = vec![];
421
422                // extract text
423                let text_content = content
424                    .clone()
425                    .into_iter()
426                    .filter_map(|content| match content {
427                        message::AssistantContent::Text(text) => Some(Message::Assistant {
428                            content: text.text,
429                            name: None,
430                            tool_calls: vec![],
431                        }),
432                        _ => None,
433                    })
434                    .collect::<Vec<_>>();
435
436                messages.extend(text_content);
437
438                // extract tool calls
439                let tool_calls = content
440                    .clone()
441                    .into_iter()
442                    .filter_map(|content| match content {
443                        message::AssistantContent::ToolCall(tool_call) => {
444                            Some(ToolCall::from(tool_call))
445                        }
446                        _ => None,
447                    })
448                    .collect::<Vec<_>>();
449
450                // if we have tool calls, we add a new Assistant message with them
451                if !tool_calls.is_empty() {
452                    messages.push(Message::Assistant {
453                        content: "".to_string(),
454                        name: None,
455                        tool_calls,
456                    });
457                }
458
459                Ok(messages)
460            }
461        }
462    }
463}
464
465#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
466pub struct ToolCall {
467    pub id: String,
468    pub index: usize,
469    #[serde(default)]
470    pub r#type: ToolType,
471    pub function: Function,
472}
473
474#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
475pub struct Function {
476    pub name: String,
477    #[serde(with = "json_utils::stringified_json")]
478    pub arguments: serde_json::Value,
479}
480
481#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
482#[serde(rename_all = "lowercase")]
483pub enum ToolType {
484    #[default]
485    Function,
486}
487
488#[derive(Clone, Debug, Deserialize, Serialize)]
489pub struct ToolDefinition {
490    pub r#type: String,
491    pub function: completion::ToolDefinition,
492}
493
494impl From<crate::completion::ToolDefinition> for ToolDefinition {
495    fn from(tool: crate::completion::ToolDefinition) -> Self {
496        Self {
497            r#type: "function".into(),
498            function: tool,
499        }
500    }
501}
502
503impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
504    type Error = CompletionError;
505
506    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
507        let choice = response.choices.first().ok_or_else(|| {
508            CompletionError::ResponseError("Response contained no choices".to_owned())
509        })?;
510        let content = match &choice.message {
511            Message::Assistant {
512                content,
513                tool_calls,
514                ..
515            } => {
516                let mut content = if content.trim().is_empty() {
517                    vec![]
518                } else {
519                    vec![completion::AssistantContent::text(content)]
520                };
521
522                content.extend(
523                    tool_calls
524                        .iter()
525                        .map(|call| {
526                            completion::AssistantContent::tool_call(
527                                &call.id,
528                                &call.function.name,
529                                call.function.arguments.clone(),
530                            )
531                        })
532                        .collect::<Vec<_>>(),
533                );
534                Ok(content)
535            }
536            _ => Err(CompletionError::ResponseError(
537                "Response did not contain a valid message or tool call".into(),
538            )),
539        }?;
540
541        let choice = OneOrMany::many(content).map_err(|_| {
542            CompletionError::ResponseError(
543                "Response contained no message or tool call (empty)".to_owned(),
544            )
545        })?;
546
547        let usage = completion::Usage {
548            input_tokens: response.usage.prompt_tokens as u64,
549            output_tokens: response.usage.completion_tokens as u64,
550            total_tokens: response.usage.total_tokens as u64,
551        };
552
553        Ok(completion::CompletionResponse {
554            choice,
555            usage,
556            raw_response: response,
557        })
558    }
559}
560
/// The struct implementing the `CompletionModel` trait for DeepSeek chat models.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    pub client: Client<T>,
    /// Model identifier, sent as the request body's `model` field.
    pub model: String,
}
567
568impl<T> CompletionModel<T> {
569    fn create_completion_request(
570        &self,
571        completion_request: CompletionRequest,
572    ) -> Result<serde_json::Value, CompletionError> {
573        // Build up the order of messages (context, chat_history, prompt)
574        let mut partial_history = vec![];
575
576        if let Some(docs) = completion_request.normalized_documents() {
577            partial_history.push(docs);
578        }
579
580        partial_history.extend(completion_request.chat_history);
581
582        // Initialize full history with preamble (or empty if non-existent)
583        let mut full_history: Vec<Message> = completion_request
584            .preamble
585            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
586
587        // Convert and extend the rest of the history
588        full_history.extend(
589            partial_history
590                .into_iter()
591                .map(message::Message::try_into)
592                .collect::<Result<Vec<Vec<Message>>, _>>()?
593                .into_iter()
594                .flatten()
595                .collect::<Vec<_>>(),
596        );
597
598        let tool_choice = completion_request
599            .tool_choice
600            .map(crate::providers::openrouter::ToolChoice::try_from)
601            .transpose()?;
602
603        let request = if completion_request.tools.is_empty() {
604            json!({
605                "model": self.model,
606                "messages": full_history,
607                "temperature": completion_request.temperature,
608            })
609        } else {
610            json!({
611                "model": self.model,
612                "messages": full_history,
613                "temperature": completion_request.temperature,
614                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
615                "tool_choice": tool_choice,
616            })
617        };
618
619        let request = if let Some(params) = completion_request.additional_params {
620            json_utils::merge(request, params)
621        } else {
622            request
623        };
624
625        Ok(request)
626    }
627}
628
/// Rig completion-model integration for DeepSeek over the reqwest-backed client.
impl completion::CompletionModel for CompletionModel<reqwest::Client> {
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    /// Sends a non-streaming `/chat/completions` request and converts the reply into
    /// Rig's completion response.
    ///
    /// # Errors
    /// - Conversion errors from `create_completion_request`.
    /// - Transport or body-read failures from reqwest (wrapped in `http_client::Error`).
    /// - `CompletionError::ProviderError` with the raw body for a non-2xx status, or
    ///   with the API's `message` when a 2xx body parses as an error object.
    /// - JSON errors when a 2xx body is neither a completion nor an error object.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<
        completion::CompletionResponse<CompletionResponse>,
        crate::completion::CompletionError,
    > {
        // Keep the preamble for the span before `completion_request` is consumed.
        let preamble = completion_request.preamble.clone();
        let request = self.create_completion_request(completion_request)?;

        // With no enabled current span, open a dedicated `chat` span. Otherwise the
        // caller's span is reused and the fields below are recorded onto it. The
        // `messages` unwrap relies on `create_completion_request` always setting it.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "deepseek",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::debug!("DeepSeek completion request: {request:?}");

        async move {
            let response = self
                .client
                .reqwest_post("/chat/completions")
                .json(&request)
                .send()
                .await
                .map_err(|e| http_client::Error::Instance(e.into()))?;

            if response.status().is_success() {
                let t = response
                    .text()
                    .await
                    .map_err(|e| http_client::Error::Instance(e.into()))?;

                tracing::debug!(target: "rig", "DeepSeek completion: {t}");

                // A 2xx body may still be an error object, hence the untagged
                // `ApiResponse` wrapper.
                match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
                    ApiResponse::Ok(response) => {
                        // Inside `.instrument(span)`, the current span is `span`.
                        let span = tracing::Span::current();
                        span.record(
                            "gen_ai.output.messages",
                            serde_json::to_string(&response.choices).unwrap(),
                        );
                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
                        span.record(
                            "gen_ai.usage.output_tokens",
                            response.usage.completion_tokens,
                        );
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw response body as the provider error.
                Err(CompletionError::ProviderError(
                    response
                        .text()
                        .await
                        .map_err(|e| http_client::Error::Instance(e.into()))?,
                ))
            }
        }
        .instrument(span)
        .await
    }

    /// Sends a streaming `/chat/completions` request and returns Rig's streaming
    /// response. Chunk handling lives in `send_compatible_streaming_request`.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let preamble = completion_request.preamble.clone();
        let mut request = self.create_completion_request(completion_request)?;

        // Enable SSE streaming and ask for a usage object in the stream so the final
        // response can report token counts.
        request = merge(
            request,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        let builder = self.client.reqwest_post("/chat/completions").json(&request);

        // Same span selection as `completion`, under the `chat_streaming` name.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "deepseek",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(send_compatible_streaming_request(builder), span).await
    }
}
751
/// Incremental `delta` of one streamed choice.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    /// Fragment of the assistant's reply text, if this chunk has one.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments. `null` is handled by `json_utils::null_or_vec`
    /// (presumably mapped to an empty list), and a missing field defaults to empty.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    /// DeepSeek-specific reasoning text fragment (presumably only sent by
    /// reasoning models — confirm).
    reasoning_content: Option<String>,
}

/// One streamed choice; only its `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One SSE `data:` payload of a streamed completion.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    /// Usage totals. Expected in the stream because `stream_options.include_usage`
    /// is requested; absent or `null` on other chunks.
    usage: Option<Usage>,
}
771
772#[derive(Clone, Deserialize, Serialize, Debug)]
773pub struct StreamingCompletionResponse {
774    pub usage: Usage,
775}
776
777impl GetTokenUsage for StreamingCompletionResponse {
778    fn token_usage(&self) -> Option<crate::completion::Usage> {
779        let mut usage = crate::completion::Usage::new();
780        usage.input_tokens = self.usage.prompt_tokens as u64;
781        usage.output_tokens = self.usage.completion_tokens as u64;
782        usage.total_tokens = self.usage.total_tokens as u64;
783
784        Some(usage)
785    }
786}
787
788pub async fn send_compatible_streaming_request(
789    request_builder: reqwest::RequestBuilder,
790) -> Result<
791    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
792    CompletionError,
793> {
794    let span = tracing::Span::current();
795    let mut event_source = request_builder
796        .eventsource()
797        .expect("Cloning request must succeed");
798
799    let stream = stream! {
800        let mut final_usage = Usage::new();
801        let mut text_response = String::new();
802        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
803
804        while let Some(event_result) = event_source.next().await {
805            match event_result {
806                Ok(Event::Open) => {
807                    tracing::trace!("SSE connection opened");
808                    continue;
809                }
810                Ok(Event::Message(message)) => {
811                    if message.data.trim().is_empty() || message.data == "[DONE]" {
812                        continue;
813                    }
814
815                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
816                    let Ok(data) = parsed else {
817                        let err = parsed.unwrap_err();
818                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
819                        continue;
820                    };
821
822                    if let Some(choice) = data.choices.first() {
823                        let delta = &choice.delta;
824
825                        if !delta.tool_calls.is_empty() {
826                            for tool_call in &delta.tool_calls {
827                                let function = &tool_call.function;
828
829                                // Start of tool call
830                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
831                                    && function.arguments.is_empty()
832                                {
833                                    let id = tool_call.id.clone().unwrap_or_default();
834                                    let name = function.name.clone().unwrap();
835                                    calls.insert(tool_call.index, (id, name, String::new()));
836                                }
837                                // Continuation of tool call
838                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
839                                    && !function.arguments.is_empty()
840                                {
841                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
842                                        let combined = format!("{}{}", existing_args, function.arguments);
843                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
844                                    } else {
845                                        tracing::debug!("Partial tool call received but tool call was never started.");
846                                    }
847                                }
848                                // Complete tool call
849                                else {
850                                    let id = tool_call.id.clone().unwrap_or_default();
851                                    let name = function.name.clone().unwrap_or_default();
852                                    let arguments_str = function.arguments.clone();
853
854                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
855                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
856                                        continue;
857                                    };
858
859                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
860                                        id,
861                                        name,
862                                        arguments: arguments_json,
863                                        call_id: None,
864                                    });
865                                }
866                            }
867                        }
868
869                        // DeepSeek-specific reasoning stream
870                        if let Some(content) = &delta.reasoning_content {
871                            yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
872                                reasoning: content.to_string(),
873                                id: None,
874                            });
875                        }
876
877                        if let Some(content) = &delta.content {
878                            text_response += content;
879                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
880                        }
881                    }
882
883                    if let Some(usage) = data.usage {
884                        final_usage = usage.clone();
885                    }
886                }
887                Err(reqwest_eventsource::Error::StreamEnded) => {
888                    break;
889                }
890                Err(err) => {
891                    tracing::error!(?err, "SSE error");
892                    yield Err(CompletionError::ResponseError(err.to_string()));
893                    break;
894                }
895            }
896        }
897
898        let mut tool_calls = Vec::new();
899        // Flush accumulated tool calls
900        for (index, (id, name, arguments)) in calls {
901            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
902                continue;
903            };
904
905            tool_calls.push(ToolCall {
906                id: id.clone(),
907                index,
908                r#type: ToolType::Function,
909                function: Function {
910                    name: name.clone(),
911                    arguments: arguments_json.clone()
912                }
913            });
914            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
915                id,
916                name,
917                arguments: arguments_json,
918                call_id: None,
919            });
920        }
921
922        let message = Message::Assistant {
923            content: text_response,
924            name: None,
925            tool_calls
926        };
927
928        span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
929
930        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
931            StreamingCompletionResponse { usage: final_usage.clone() }
932        ));
933    };
934
935    Ok(crate::streaming::StreamingCompletionResponse::stream(
936        Box::pin(stream),
937    ))
938}
939
940// ================================================================
941// DeepSeek Completion API
942// ================================================================
943
/// Model name for DeepSeek's `deepseek-chat` completion model.
///
/// Pass it to `completion_model` on a DeepSeek `Client` to obtain a chat model handle.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name for DeepSeek's `deepseek-reasoner` completion model.
///
/// Used the same way as `DEEPSEEK_CHAT`, i.e. as the argument to `completion_model`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
948
949// Tests
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of an assistant message; any other message variant fails the test.
    fn assistant_text(message: &Message) -> &str {
        let Message::Assistant { content, .. } = message else {
            panic!("Expected assistant message");
        };
        content.as_str()
    }

    /// Deserializes a full `CompletionResponse`, panicking with the JSON path of the
    /// offending field when deserialization fails.
    fn parse_completion_response(json: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(json);
        serde_path_to_error::deserialize(&mut deserializer)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let choices = serde_json::from_str::<Vec<Choice>>(data).unwrap();

        assert_eq!(choices.len(), 1);
        let only_choice = choices.first().unwrap();
        assert_eq!(assistant_text(&only_choice.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_completion_response(data);
        let first_choice = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first_choice.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let expected_text =
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄";

        let response = parse_completion_response(data);
        let first_choice = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first_choice.message), expected_text);
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        // The single tool call the fixture above should decode to; the stringified
        // `arguments` payload is expected to come back as structured JSON.
        let subtract_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
        };
        let expected_choice = Choice {
            index: 0,
            finish_reason: "tool_calls".to_string(),
            logprobs: None,
            message: Message::Assistant {
                content: String::new(),
                name: None,
                tool_calls: vec![subtract_call],
            },
        };

        let parsed: Choice = serde_json::from_str(tool_call_choice_json).unwrap();
        assert_eq!(parsed, expected_choice);
    }
}