rig/providers/
deepseek.rs

1//! DeepSeek API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::deepseek;
6//!
7//! let client = deepseek::Client::new("DEEPSEEK_API_KEY");
8//!
9//! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
10//! ```
11
12use async_stream::stream;
13use bytes::Bytes;
14use futures::StreamExt;
15use http::{Method, Request};
16use std::collections::HashMap;
17use tracing::{Instrument, info_span};
18
19use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
20use crate::completion::GetTokenUsage;
21use crate::http_client::sse::{Event, GenericEventSource};
22use crate::http_client::{self, HttpClientExt};
23use crate::json_utils::merge;
24use crate::message::{Document, DocumentSourceKind};
25use crate::{
26    OneOrMany,
27    completion::{self, CompletionError, CompletionRequest},
28    impl_conversion_traits, json_utils, message,
29};
30use serde::{Deserialize, Serialize};
31use serde_json::json;
32
33use super::openai::StreamingToolCall;
34
35// ================================================================
36// Main DeepSeek Client
37// ================================================================
/// Base URL that `ClientBuilder` constructors use unless `base_url` overrides it.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
39
/// Builder for a DeepSeek [`Client`].
///
/// The API key and base URL are borrowed for `'a`; `build` copies both into
/// owned `String`s. `T` is the HTTP client type and defaults to
/// `reqwest::Client`.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    // Key copied into the built client.
    api_key: &'a str,
    // Starts as `DEEPSEEK_API_BASE_URL`; replaced by `base_url()`.
    base_url: &'a str,
    // HTTP client moved into the built client.
    http_client: T,
}
45
46impl<'a, T> ClientBuilder<'a, T>
47where
48    T: Default,
49{
50    pub fn new(api_key: &'a str) -> Self {
51        Self {
52            api_key,
53            base_url: DEEPSEEK_API_BASE_URL,
54            http_client: Default::default(),
55        }
56    }
57}
58
59impl<'a, T> ClientBuilder<'a, T> {
60    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
61        Self {
62            api_key,
63            base_url: DEEPSEEK_API_BASE_URL,
64            http_client,
65        }
66    }
67    pub fn base_url(mut self, base_url: &'a str) -> Self {
68        self.base_url = base_url;
69        self
70    }
71
72    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
73        ClientBuilder {
74            api_key: self.api_key,
75            base_url: self.base_url,
76            http_client,
77        }
78    }
79
80    pub fn build(self) -> Client<T> {
81        Client {
82            base_url: self.base_url.to_string(),
83            api_key: self.api_key.to_string(),
84            http_client: self.http_client,
85        }
86    }
87}
88
/// DeepSeek API client, generic over the HTTP client type `T`
/// (`reqwest::Client` unless specified).
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    /// Base URL that request paths are appended to.
    pub base_url: String,
    // Passed to `http_client::with_bearer_auth`; redacted in the `Debug` output.
    api_key: String,
    // HTTP client that requests are sent through.
    http_client: T,
}
95
96impl<T> std::fmt::Debug for Client<T>
97where
98    T: std::fmt::Debug,
99{
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("Client")
102            .field("base_url", &self.base_url)
103            .field("http_client", &self.http_client)
104            .field("api_key", &"<REDACTED>")
105            .finish()
106    }
107}
108
109impl<T> Client<T>
110where
111    T: HttpClientExt,
112{
113    fn req(
114        &self,
115        method: http_client::Method,
116        path: &str,
117    ) -> http_client::Result<http_client::Builder> {
118        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
119
120        http_client::with_bearer_auth(
121            http_client::Request::builder().method(method).uri(url),
122            &self.api_key,
123        )
124    }
125
126    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
127        self.req(http_client::Method::GET, path)
128    }
129
130    async fn send<U, R>(
131        &self,
132        req: http_client::Request<U>,
133    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
134    where
135        U: Into<Bytes> + Send,
136        R: From<Bytes> + Send + 'static,
137    {
138        self.http_client.send(req).await
139    }
140}
141
142impl Client<reqwest::Client> {
143    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
144        ClientBuilder::new(api_key)
145    }
146
147    pub fn new(api_key: &str) -> Self {
148        ClientBuilder::new(api_key).build()
149    }
150
151    pub fn from_env() -> Self {
152        <Self as ProviderClient>::from_env()
153    }
154}
155
156impl<T> ProviderClient for Client<T>
157where
158    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
159{
160    // If you prefer the environment variable approach:
161    fn from_env() -> Self {
162        let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
163        ClientBuilder::<T>::new(&api_key).build()
164    }
165
166    fn from_val(input: crate::client::ProviderValue) -> Self {
167        let crate::client::ProviderValue::Simple(api_key) = input else {
168            panic!("Incorrect provider value type")
169        };
170        ClientBuilder::<T>::new(&api_key).build()
171    }
172}
173
174impl<T> CompletionClient for Client<T>
175where
176    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
177{
178    type CompletionModel = CompletionModel<T>;
179
180    /// Creates a DeepSeek completion model with the given `model_name`.
181    fn completion_model(&self, model_name: &str) -> Self::CompletionModel {
182        CompletionModel {
183            client: self.clone(),
184            model: model_name.to_string(),
185        }
186    }
187}
188
189impl<T> VerifyClient for Client<T>
190where
191    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
192{
193    #[cfg_attr(feature = "worker", worker::send)]
194    async fn verify(&self) -> Result<(), VerifyError> {
195        let req = self
196            .get("/user/balance")?
197            .body(http_client::NoBody)
198            .map_err(http_client::Error::from)?;
199
200        let response = self.send(req).await?;
201
202        match response.status() {
203            reqwest::StatusCode::OK => Ok(()),
204            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
205            reqwest::StatusCode::INTERNAL_SERVER_ERROR
206            | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
207                let text = http_client::text(response).await?;
208                Err(VerifyError::ProviderError(text))
209            }
210            _ => {
211                // TODO: `HttpClientExt` equivalent
212                //response.error_for_status()?;
213                Ok(())
214            }
215        }
216    }
217}
218
// NOTE(review): presumably generates the embeddings / transcription / image /
// audio conversion-trait impls for `Client<T>`. The macro is defined elsewhere,
// so confirm there what the generated impls return.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
225
226#[derive(Debug, Deserialize)]
227struct ApiErrorResponse {
228    message: String,
229}
230
231#[derive(Debug, Deserialize)]
232#[serde(untagged)]
233enum ApiResponse<T> {
234    Ok(T),
235    Err(ApiErrorResponse),
236}
237
238impl From<ApiErrorResponse> for CompletionError {
239    fn from(err: ApiErrorResponse) -> Self {
240        CompletionError::ProviderError(err.message)
241    }
242}
243
244/// The response shape from the DeepSeek API
245#[derive(Clone, Debug, Serialize, Deserialize)]
246pub struct CompletionResponse {
247    // We'll match the JSON:
248    pub choices: Vec<Choice>,
249    pub usage: Usage,
250    // you may want other fields
251}
252
253#[derive(Clone, Debug, Serialize, Deserialize, Default)]
254pub struct Usage {
255    pub completion_tokens: u32,
256    pub prompt_tokens: u32,
257    pub prompt_cache_hit_tokens: u32,
258    pub prompt_cache_miss_tokens: u32,
259    pub total_tokens: u32,
260    #[serde(skip_serializing_if = "Option::is_none")]
261    pub completion_tokens_details: Option<CompletionTokensDetails>,
262    #[serde(skip_serializing_if = "Option::is_none")]
263    pub prompt_tokens_details: Option<PromptTokensDetails>,
264}
265
266impl Usage {
267    fn new() -> Self {
268        Self {
269            completion_tokens: 0,
270            prompt_tokens: 0,
271            prompt_cache_hit_tokens: 0,
272            prompt_cache_miss_tokens: 0,
273            total_tokens: 0,
274            completion_tokens_details: None,
275            prompt_tokens_details: None,
276        }
277    }
278}
279
280#[derive(Clone, Debug, Serialize, Deserialize, Default)]
281pub struct CompletionTokensDetails {
282    #[serde(skip_serializing_if = "Option::is_none")]
283    pub reasoning_tokens: Option<u32>,
284}
285
286#[derive(Clone, Debug, Serialize, Deserialize, Default)]
287pub struct PromptTokensDetails {
288    #[serde(skip_serializing_if = "Option::is_none")]
289    pub cached_tokens: Option<u32>,
290}
291
292#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
293pub struct Choice {
294    pub index: usize,
295    pub message: Message,
296    pub logprobs: Option<serde_json::Value>,
297    pub finish_reason: String,
298}
299
300#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
301#[serde(tag = "role", rename_all = "lowercase")]
302pub enum Message {
303    System {
304        content: String,
305        #[serde(skip_serializing_if = "Option::is_none")]
306        name: Option<String>,
307    },
308    User {
309        content: String,
310        #[serde(skip_serializing_if = "Option::is_none")]
311        name: Option<String>,
312    },
313    Assistant {
314        content: String,
315        #[serde(skip_serializing_if = "Option::is_none")]
316        name: Option<String>,
317        #[serde(
318            default,
319            deserialize_with = "json_utils::null_or_vec",
320            skip_serializing_if = "Vec::is_empty"
321        )]
322        tool_calls: Vec<ToolCall>,
323    },
324    #[serde(rename = "tool")]
325    ToolResult {
326        tool_call_id: String,
327        content: String,
328    },
329}
330
331impl Message {
332    pub fn system(content: &str) -> Self {
333        Message::System {
334            content: content.to_owned(),
335            name: None,
336        }
337    }
338}
339
340impl From<message::ToolResult> for Message {
341    fn from(tool_result: message::ToolResult) -> Self {
342        let content = match tool_result.content.first() {
343            message::ToolResultContent::Text(text) => text.text,
344            message::ToolResultContent::Image(_) => String::from("[Image]"),
345        };
346
347        Message::ToolResult {
348            tool_call_id: tool_result.id,
349            content,
350        }
351    }
352}
353
354impl From<message::ToolCall> for ToolCall {
355    fn from(tool_call: message::ToolCall) -> Self {
356        Self {
357            id: tool_call.id,
358            // TODO: update index when we have it
359            index: 0,
360            r#type: ToolType::Function,
361            function: Function {
362                name: tool_call.function.name,
363                arguments: tool_call.function.arguments,
364            },
365        }
366    }
367}
368
369impl TryFrom<message::Message> for Vec<Message> {
370    type Error = message::MessageError;
371
372    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
373        match message {
374            message::Message::User { content } => {
375                // extract tool results
376                let mut messages = vec![];
377
378                let tool_results = content
379                    .clone()
380                    .into_iter()
381                    .filter_map(|content| match content {
382                        message::UserContent::ToolResult(tool_result) => {
383                            Some(Message::from(tool_result))
384                        }
385                        _ => None,
386                    })
387                    .collect::<Vec<_>>();
388
389                messages.extend(tool_results);
390
391                // extract text results
392                let text_messages = content
393                    .into_iter()
394                    .filter_map(|content| match content {
395                        message::UserContent::Text(text) => Some(Message::User {
396                            content: text.text,
397                            name: None,
398                        }),
399                        message::UserContent::Document(Document {
400                            data:
401                                DocumentSourceKind::Base64(content)
402                                | DocumentSourceKind::String(content),
403                            ..
404                        }) => Some(Message::User {
405                            content,
406                            name: None,
407                        }),
408                        _ => None,
409                    })
410                    .collect::<Vec<_>>();
411                messages.extend(text_messages);
412
413                Ok(messages)
414            }
415            message::Message::Assistant { content, .. } => {
416                let mut messages: Vec<Message> = vec![];
417
418                // extract text
419                let text_content = content
420                    .clone()
421                    .into_iter()
422                    .filter_map(|content| match content {
423                        message::AssistantContent::Text(text) => Some(Message::Assistant {
424                            content: text.text,
425                            name: None,
426                            tool_calls: vec![],
427                        }),
428                        _ => None,
429                    })
430                    .collect::<Vec<_>>();
431
432                messages.extend(text_content);
433
434                // extract tool calls
435                let tool_calls = content
436                    .clone()
437                    .into_iter()
438                    .filter_map(|content| match content {
439                        message::AssistantContent::ToolCall(tool_call) => {
440                            Some(ToolCall::from(tool_call))
441                        }
442                        _ => None,
443                    })
444                    .collect::<Vec<_>>();
445
446                // if we have tool calls, we add a new Assistant message with them
447                if !tool_calls.is_empty() {
448                    messages.push(Message::Assistant {
449                        content: "".to_string(),
450                        name: None,
451                        tool_calls,
452                    });
453                }
454
455                Ok(messages)
456            }
457        }
458    }
459}
460
461#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
462pub struct ToolCall {
463    pub id: String,
464    pub index: usize,
465    #[serde(default)]
466    pub r#type: ToolType,
467    pub function: Function,
468}
469
470#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
471pub struct Function {
472    pub name: String,
473    #[serde(with = "json_utils::stringified_json")]
474    pub arguments: serde_json::Value,
475}
476
477#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
478#[serde(rename_all = "lowercase")]
479pub enum ToolType {
480    #[default]
481    Function,
482}
483
484#[derive(Clone, Debug, Deserialize, Serialize)]
485pub struct ToolDefinition {
486    pub r#type: String,
487    pub function: completion::ToolDefinition,
488}
489
490impl From<crate::completion::ToolDefinition> for ToolDefinition {
491    fn from(tool: crate::completion::ToolDefinition) -> Self {
492        Self {
493            r#type: "function".into(),
494            function: tool,
495        }
496    }
497}
498
499impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
500    type Error = CompletionError;
501
502    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
503        let choice = response.choices.first().ok_or_else(|| {
504            CompletionError::ResponseError("Response contained no choices".to_owned())
505        })?;
506        let content = match &choice.message {
507            Message::Assistant {
508                content,
509                tool_calls,
510                ..
511            } => {
512                let mut content = if content.trim().is_empty() {
513                    vec![]
514                } else {
515                    vec![completion::AssistantContent::text(content)]
516                };
517
518                content.extend(
519                    tool_calls
520                        .iter()
521                        .map(|call| {
522                            completion::AssistantContent::tool_call(
523                                &call.id,
524                                &call.function.name,
525                                call.function.arguments.clone(),
526                            )
527                        })
528                        .collect::<Vec<_>>(),
529                );
530                Ok(content)
531            }
532            _ => Err(CompletionError::ResponseError(
533                "Response did not contain a valid message or tool call".into(),
534            )),
535        }?;
536
537        let choice = OneOrMany::many(content).map_err(|_| {
538            CompletionError::ResponseError(
539                "Response contained no message or tool call (empty)".to_owned(),
540            )
541        })?;
542
543        let usage = completion::Usage {
544            input_tokens: response.usage.prompt_tokens as u64,
545            output_tokens: response.usage.completion_tokens as u64,
546            total_tokens: response.usage.total_tokens as u64,
547        };
548
549        Ok(completion::CompletionResponse {
550            choice,
551            usage,
552            raw_response: response,
553        })
554    }
555}
556
/// The struct implementing the `CompletionModel` trait.
///
/// Pairs a [`Client`] with the model name that is sent as the `model` field
/// of each request.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client that requests are sent through.
    pub client: Client<T>,
    /// Model name placed in the request body.
    pub model: String,
}
563
564impl<T> CompletionModel<T> {
565    fn create_completion_request(
566        &self,
567        completion_request: CompletionRequest,
568    ) -> Result<serde_json::Value, CompletionError> {
569        // Build up the order of messages (context, chat_history, prompt)
570        let mut partial_history = vec![];
571
572        if let Some(docs) = completion_request.normalized_documents() {
573            partial_history.push(docs);
574        }
575
576        partial_history.extend(completion_request.chat_history);
577
578        // Initialize full history with preamble (or empty if non-existent)
579        let mut full_history: Vec<Message> = completion_request
580            .preamble
581            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
582
583        // Convert and extend the rest of the history
584        full_history.extend(
585            partial_history
586                .into_iter()
587                .map(message::Message::try_into)
588                .collect::<Result<Vec<Vec<Message>>, _>>()?
589                .into_iter()
590                .flatten()
591                .collect::<Vec<_>>(),
592        );
593
594        let tool_choice = completion_request
595            .tool_choice
596            .map(crate::providers::openrouter::ToolChoice::try_from)
597            .transpose()?;
598
599        let request = if completion_request.tools.is_empty() {
600            json!({
601                "model": self.model,
602                "messages": full_history,
603                "temperature": completion_request.temperature,
604            })
605        } else {
606            json!({
607                "model": self.model,
608                "messages": full_history,
609                "temperature": completion_request.temperature,
610                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
611                "tool_choice": tool_choice,
612            })
613        };
614
615        let request = if let Some(params) = completion_request.additional_params {
616            json_utils::merge(request, params)
617        } else {
618            request
619        };
620
621        Ok(request)
622    }
623}
624
625impl<T> completion::CompletionModel for CompletionModel<T>
626where
627    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
628{
629    type Response = CompletionResponse;
630    type StreamingResponse = StreamingCompletionResponse;
631
632    #[cfg_attr(feature = "worker", worker::send)]
633    async fn completion(
634        &self,
635        completion_request: CompletionRequest,
636    ) -> Result<
637        completion::CompletionResponse<CompletionResponse>,
638        crate::completion::CompletionError,
639    > {
640        let preamble = completion_request.preamble.clone();
641        let request = self.create_completion_request(completion_request)?;
642
643        let span = if tracing::Span::current().is_disabled() {
644            info_span!(
645                target: "rig::completions",
646                "chat",
647                gen_ai.operation.name = "chat",
648                gen_ai.provider.name = "deepseek",
649                gen_ai.request.model = self.model,
650                gen_ai.system_instructions = preamble,
651                gen_ai.response.id = tracing::field::Empty,
652                gen_ai.response.model = tracing::field::Empty,
653                gen_ai.usage.output_tokens = tracing::field::Empty,
654                gen_ai.usage.input_tokens = tracing::field::Empty,
655                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
656                gen_ai.output.messages = tracing::field::Empty,
657            )
658        } else {
659            tracing::Span::current()
660        };
661
662        tracing::debug!("DeepSeek completion request: {request:?}");
663
664        let body = serde_json::to_vec(&request)?;
665        let req = self
666            .client
667            .req(Method::POST, "/chat/completions")?
668            .header("Content-Type", "application/json")
669            .body(body)
670            .map_err(|e| CompletionError::HttpError(e.into()))?;
671
672        async move {
673            let response = self.client.http_client.send::<_, Bytes>(req).await?;
674            let status = response.status();
675            let response_body = response.into_body().into_future().await?.to_vec();
676
677            if status.is_success() {
678                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
679                    ApiResponse::Ok(response) => {
680                        let span = tracing::Span::current();
681                        span.record(
682                            "gen_ai.output.messages",
683                            serde_json::to_string(&response.choices).unwrap(),
684                        );
685                        span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
686                        span.record(
687                            "gen_ai.usage.output_tokens",
688                            response.usage.completion_tokens,
689                        );
690                        tracing::debug!(target: "rig", "DeepSeek completion output: {}", serde_json::to_string_pretty(&response_body)?);
691
692                        response.try_into()
693                    }
694                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
695                }
696            } else {
697                Err(CompletionError::ProviderError(
698                    String::from_utf8_lossy(&response_body).to_string()
699                ))
700            }
701        }
702        .instrument(span)
703        .await
704    }
705
706    #[cfg_attr(feature = "worker", worker::send)]
707    async fn stream(
708        &self,
709        completion_request: CompletionRequest,
710    ) -> Result<
711        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
712        CompletionError,
713    > {
714        let preamble = completion_request.preamble.clone();
715        let mut request = self.create_completion_request(completion_request)?;
716
717        request = merge(
718            request,
719            json!({"stream": true, "stream_options": {"include_usage": true}}),
720        );
721
722        let body = serde_json::to_vec(&request)?;
723
724        let req = self
725            .client
726            .req(Method::POST, "/chat/completions")?
727            .header("Content-Type", "application/json")
728            .body(body)
729            .map_err(|e| CompletionError::HttpError(e.into()))?;
730
731        let span = if tracing::Span::current().is_disabled() {
732            info_span!(
733                target: "rig::completions",
734                "chat_streaming",
735                gen_ai.operation.name = "chat_streaming",
736                gen_ai.provider.name = "deepseek",
737                gen_ai.request.model = self.model,
738                gen_ai.system_instructions = preamble,
739                gen_ai.response.id = tracing::field::Empty,
740                gen_ai.response.model = tracing::field::Empty,
741                gen_ai.usage.output_tokens = tracing::field::Empty,
742                gen_ai.usage.input_tokens = tracing::field::Empty,
743                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
744                gen_ai.output.messages = tracing::field::Empty,
745            )
746        } else {
747            tracing::Span::current()
748        };
749
750        tracing::Instrument::instrument(
751            send_compatible_streaming_request(self.client.http_client.clone(), req),
752            span,
753        )
754        .await
755    }
756}
757
758#[derive(Deserialize, Debug)]
759pub struct StreamingDelta {
760    #[serde(default)]
761    content: Option<String>,
762    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
763    tool_calls: Vec<StreamingToolCall>,
764    reasoning_content: Option<String>,
765}
766
767#[derive(Deserialize, Debug)]
768struct StreamingChoice {
769    delta: StreamingDelta,
770}
771
772#[derive(Deserialize, Debug)]
773struct StreamingCompletionChunk {
774    choices: Vec<StreamingChoice>,
775    usage: Option<Usage>,
776}
777
778#[derive(Clone, Deserialize, Serialize, Debug)]
779pub struct StreamingCompletionResponse {
780    pub usage: Usage,
781}
782
783impl GetTokenUsage for StreamingCompletionResponse {
784    fn token_usage(&self) -> Option<crate::completion::Usage> {
785        let mut usage = crate::completion::Usage::new();
786        usage.input_tokens = self.usage.prompt_tokens as u64;
787        usage.output_tokens = self.usage.completion_tokens as u64;
788        usage.total_tokens = self.usage.total_tokens as u64;
789
790        Some(usage)
791    }
792}
793
794pub async fn send_compatible_streaming_request<T>(
795    http_client: T,
796    req: Request<Vec<u8>>,
797) -> Result<
798    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
799    CompletionError,
800>
801where
802    T: HttpClientExt + Clone + 'static,
803{
804    let span = tracing::Span::current();
805    let mut event_source = GenericEventSource::new(http_client, req);
806
807    let stream = stream! {
808        let mut final_usage = Usage::new();
809        let mut text_response = String::new();
810        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
811
812        while let Some(event_result) = event_source.next().await {
813            match event_result {
814                Ok(Event::Open) => {
815                    tracing::trace!("SSE connection opened");
816                    continue;
817                }
818                Ok(Event::Message(message)) => {
819                    if message.data.trim().is_empty() || message.data == "[DONE]" {
820                        continue;
821                    }
822
823                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
824                    let Ok(data) = parsed else {
825                        let err = parsed.unwrap_err();
826                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
827                        continue;
828                    };
829
830                    if let Some(choice) = data.choices.first() {
831                        let delta = &choice.delta;
832
833                        if !delta.tool_calls.is_empty() {
834                            for tool_call in &delta.tool_calls {
835                                let function = &tool_call.function;
836
837                                // Start of tool call
838                                if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
839                                    && function.arguments.is_empty()
840                                {
841                                    let id = tool_call.id.clone().unwrap_or_default();
842                                    let name = function.name.clone().unwrap();
843                                    calls.insert(tool_call.index, (id, name, String::new()));
844                                }
845                                // Continuation of tool call
846                                else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
847                                    && !function.arguments.is_empty()
848                                {
849                                    if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
850                                        let combined = format!("{}{}", existing_args, function.arguments);
851                                        calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
852                                    } else {
853                                        tracing::debug!("Partial tool call received but tool call was never started.");
854                                    }
855                                }
856                                // Complete tool call
857                                else {
858                                    let id = tool_call.id.clone().unwrap_or_default();
859                                    let name = function.name.clone().unwrap_or_default();
860                                    let arguments_str = function.arguments.clone();
861
862                                    let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
863                                        tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
864                                        continue;
865                                    };
866
867                                    yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
868                                        id,
869                                        name,
870                                        arguments: arguments_json,
871                                        call_id: None,
872                                    });
873                                }
874                            }
875                        }
876
877                        // DeepSeek-specific reasoning stream
878                        if let Some(content) = &delta.reasoning_content {
879                            yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
880                                reasoning: content.to_string(),
881                                id: None,
882                                signature: None,
883                            });
884                        }
885
886                        if let Some(content) = &delta.content {
887                            text_response += content;
888                            yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
889                        }
890                    }
891
892                    if let Some(usage) = data.usage {
893                        final_usage = usage.clone();
894                    }
895                }
896                Err(crate::http_client::Error::StreamEnded) => {
897                    break;
898                }
899                Err(err) => {
900                    tracing::error!(?err, "SSE error");
901                    yield Err(CompletionError::ResponseError(err.to_string()));
902                    break;
903                }
904            }
905        }
906
907        event_source.close();
908
909        let mut tool_calls = Vec::new();
910        // Flush accumulated tool calls
911        for (index, (id, name, arguments)) in calls {
912            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
913                continue;
914            };
915
916            tool_calls.push(ToolCall {
917                id: id.clone(),
918                index,
919                r#type: ToolType::Function,
920                function: Function {
921                    name: name.clone(),
922                    arguments: arguments_json.clone()
923                }
924            });
925            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
926                id,
927                name,
928                arguments: arguments_json,
929                call_id: None,
930            });
931        }
932
933        let message = Message::Assistant {
934            content: text_response,
935            name: None,
936            tool_calls
937        };
938
939        span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
940
941        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
942            StreamingCompletionResponse { usage: final_usage.clone() }
943        ));
944    };
945
946    Ok(crate::streaming::StreamingCompletionResponse::stream(
947        Box::pin(stream),
948    ))
949}
950
951// ================================================================
952// DeepSeek Completion API
953// ================================================================
954
/// Model identifier for DeepSeek's `deepseek-chat` completion model.
///
/// Pass this to `Client::completion_model` to select the model.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier for DeepSeek's `deepseek-reasoner` completion model.
///
/// Pass this to `Client::completion_model` to select the model.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
959
960// Tests
#[cfg(test)]
mod tests {
    //! Deserialization tests for the DeepSeek response types.

    use super::*;

    /// Returns the text of an assistant message; any other role fails the test.
    fn assistant_text(message: &Message) -> &str {
        let Message::Assistant { content, .. } = message else {
            panic!("Expected assistant message");
        };
        content.as_str()
    }

    /// Parses a full completion response, panicking with the offending JSON path
    /// when deserialization fails.
    fn parse_completion_response(payload: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(payload);
        serde_path_to_error::deserialize(&mut deserializer).unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err)
        })
    }

    /// A bare JSON array of choices deserializes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let payload = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
            }]"#;

        let parsed: Vec<Choice> = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.len(), 1);

        let only_choice = parsed.first().unwrap();
        assert_eq!(assistant_text(&only_choice.message), "Hello, world!");
    }

    /// A minimal response containing only `choices` and `usage` deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let payload = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_completion_response(payload);
        let first_choice = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first_choice.message), "Hello, world!");
    }

    /// A complete example payload, including fields such as `id`, `model` and
    /// `system_fingerprint`, deserializes and keeps non-ASCII content intact.
    #[test]
    fn test_deserialize_example_response() {
        let payload = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_completion_response(payload);
        let first_choice = response.choices.first().unwrap();
        assert_eq!(
            assistant_text(&first_choice.message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    /// A choice carrying a tool call deserializes into the expected typed value,
    /// with the stringified `arguments` decoded into structured JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let actual: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        // Build the expectation inside-out so each layer stays readable.
        let subtract_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::json!({ "x": 2, "y": 5 }),
            },
        };
        let expected = Choice {
            index: 0,
            message: Message::Assistant {
                content: String::new(),
                name: None,
                tool_calls: vec![subtract_call],
            },
            logprobs: None,
            finish_reason: "tool_calls".to_string(),
        };

        assert_eq!(actual, expected);
    }
}