1use async_stream::stream;
13use bytes::Bytes;
14use futures::StreamExt;
15use reqwest_eventsource::{Event, RequestBuilderExt};
16use std::collections::HashMap;
17use tracing::{Instrument, info_span};
18
19use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
20use crate::completion::GetTokenUsage;
21use crate::http_client::{self, HttpClientExt};
22use crate::json_utils::merge;
23use crate::message::{Document, DocumentSourceKind};
24use crate::{
25 OneOrMany,
26 completion::{self, CompletionError, CompletionRequest},
27 impl_conversion_traits, json_utils, message,
28};
29use serde::{Deserialize, Serialize};
30use serde_json::json;
31
32use super::openai::StreamingToolCall;
33
/// Base URL used by [`ClientBuilder::new`] when no other base URL is supplied.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
38
39pub struct ClientBuilder<'a, T = reqwest::Client> {
40 api_key: &'a str,
41 base_url: &'a str,
42 http_client: T,
43}
44
45impl<'a, T> ClientBuilder<'a, T>
46where
47 T: Default,
48{
49 pub fn new(api_key: &'a str) -> Self {
50 Self {
51 api_key,
52 base_url: DEEPSEEK_API_BASE_URL,
53 http_client: Default::default(),
54 }
55 }
56}
57
58impl<'a, T> ClientBuilder<'a, T> {
59 pub fn base_url(mut self, base_url: &'a str) -> Self {
60 self.base_url = base_url;
61 self
62 }
63
64 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
65 ClientBuilder {
66 api_key: self.api_key,
67 base_url: self.base_url,
68 http_client,
69 }
70 }
71
72 pub fn build(self) -> Client<T> {
73 Client {
74 base_url: self.base_url.to_string(),
75 api_key: self.api_key.to_string(),
76 http_client: self.http_client,
77 }
78 }
79}
80
81#[derive(Clone)]
82pub struct Client<T = reqwest::Client> {
83 pub base_url: String,
84 api_key: String,
85 http_client: T,
86}
87
88impl<T> std::fmt::Debug for Client<T>
89where
90 T: std::fmt::Debug,
91{
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("Client")
94 .field("base_url", &self.base_url)
95 .field("http_client", &self.http_client)
96 .field("api_key", &"<REDACTED>")
97 .finish()
98 }
99}
100
101impl<T> Client<T>
102where
103 T: Default,
104{
105 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
116 ClientBuilder::new(api_key)
117 }
118
119 pub fn new(api_key: &str) -> Self {
124 Self::builder(api_key).build()
125 }
126}
127
128impl<T> Client<T>
129where
130 T: HttpClientExt,
131{
132 fn req(
133 &self,
134 method: http_client::Method,
135 path: &str,
136 ) -> http_client::Result<http_client::Builder> {
137 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
138
139 http_client::with_bearer_auth(
140 http_client::Request::builder().method(method).uri(url),
141 &self.api_key,
142 )
143 }
144
145 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
146 self.req(http_client::Method::GET, path)
147 }
148
149 async fn send<U, R>(
150 &self,
151 req: http_client::Request<U>,
152 ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
153 where
154 U: Into<Bytes> + Send,
155 R: From<Bytes> + Send + 'static,
156 {
157 self.http_client.send(req).await
158 }
159}
160
161impl Client<reqwest::Client> {
162 fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
163 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
164
165 self.http_client.post(url).bearer_auth(&self.api_key)
166 }
167}
168
169impl ProviderClient for Client<reqwest::Client> {
170 fn from_env() -> Self {
172 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
173 Self::new(&api_key)
174 }
175
176 fn from_val(input: crate::client::ProviderValue) -> Self {
177 let crate::client::ProviderValue::Simple(api_key) = input else {
178 panic!("Incorrect provider value type")
179 };
180 Self::new(&api_key)
181 }
182}
183
184impl CompletionClient for Client<reqwest::Client> {
185 type CompletionModel = CompletionModel<reqwest::Client>;
186
187 fn completion_model(&self, model_name: &str) -> CompletionModel<reqwest::Client> {
189 CompletionModel {
190 client: self.clone(),
191 model: model_name.to_string(),
192 }
193 }
194}
195
196impl VerifyClient for Client<reqwest::Client> {
197 #[cfg_attr(feature = "worker", worker::send)]
198 async fn verify(&self) -> Result<(), VerifyError> {
199 let req = self
200 .get("/user/balance")?
201 .body(http_client::NoBody)
202 .map_err(http_client::Error::from)?;
203
204 let response = self.send(req).await?;
205
206 match response.status() {
207 reqwest::StatusCode::OK => Ok(()),
208 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
209 reqwest::StatusCode::INTERNAL_SERVER_ERROR
210 | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
211 let text = http_client::text(response).await?;
212 Err(VerifyError::ProviderError(text))
213 }
214 _ => {
215 Ok(())
218 }
219 }
220 }
221}
222
223impl_conversion_traits!(
224 AsEmbeddings,
225 AsTranscription,
226 AsImageGeneration,
227 AsAudioGeneration for Client<T>
228);
229
230#[derive(Debug, Deserialize)]
231struct ApiErrorResponse {
232 message: String,
233}
234
235#[derive(Debug, Deserialize)]
236#[serde(untagged)]
237enum ApiResponse<T> {
238 Ok(T),
239 Err(ApiErrorResponse),
240}
241
242impl From<ApiErrorResponse> for CompletionError {
243 fn from(err: ApiErrorResponse) -> Self {
244 CompletionError::ProviderError(err.message)
245 }
246}
247
248#[derive(Clone, Debug, Serialize, Deserialize)]
250pub struct CompletionResponse {
251 pub choices: Vec<Choice>,
253 pub usage: Usage,
254 }
256
257#[derive(Clone, Debug, Serialize, Deserialize, Default)]
258pub struct Usage {
259 pub completion_tokens: u32,
260 pub prompt_tokens: u32,
261 pub prompt_cache_hit_tokens: u32,
262 pub prompt_cache_miss_tokens: u32,
263 pub total_tokens: u32,
264 #[serde(skip_serializing_if = "Option::is_none")]
265 pub completion_tokens_details: Option<CompletionTokensDetails>,
266 #[serde(skip_serializing_if = "Option::is_none")]
267 pub prompt_tokens_details: Option<PromptTokensDetails>,
268}
269
270impl Usage {
271 fn new() -> Self {
272 Self {
273 completion_tokens: 0,
274 prompt_tokens: 0,
275 prompt_cache_hit_tokens: 0,
276 prompt_cache_miss_tokens: 0,
277 total_tokens: 0,
278 completion_tokens_details: None,
279 prompt_tokens_details: None,
280 }
281 }
282}
283
284#[derive(Clone, Debug, Serialize, Deserialize, Default)]
285pub struct CompletionTokensDetails {
286 #[serde(skip_serializing_if = "Option::is_none")]
287 pub reasoning_tokens: Option<u32>,
288}
289
290#[derive(Clone, Debug, Serialize, Deserialize, Default)]
291pub struct PromptTokensDetails {
292 #[serde(skip_serializing_if = "Option::is_none")]
293 pub cached_tokens: Option<u32>,
294}
295
296#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
297pub struct Choice {
298 pub index: usize,
299 pub message: Message,
300 pub logprobs: Option<serde_json::Value>,
301 pub finish_reason: String,
302}
303
304#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
305#[serde(tag = "role", rename_all = "lowercase")]
306pub enum Message {
307 System {
308 content: String,
309 #[serde(skip_serializing_if = "Option::is_none")]
310 name: Option<String>,
311 },
312 User {
313 content: String,
314 #[serde(skip_serializing_if = "Option::is_none")]
315 name: Option<String>,
316 },
317 Assistant {
318 content: String,
319 #[serde(skip_serializing_if = "Option::is_none")]
320 name: Option<String>,
321 #[serde(
322 default,
323 deserialize_with = "json_utils::null_or_vec",
324 skip_serializing_if = "Vec::is_empty"
325 )]
326 tool_calls: Vec<ToolCall>,
327 },
328 #[serde(rename = "tool")]
329 ToolResult {
330 tool_call_id: String,
331 content: String,
332 },
333}
334
335impl Message {
336 pub fn system(content: &str) -> Self {
337 Message::System {
338 content: content.to_owned(),
339 name: None,
340 }
341 }
342}
343
344impl From<message::ToolResult> for Message {
345 fn from(tool_result: message::ToolResult) -> Self {
346 let content = match tool_result.content.first() {
347 message::ToolResultContent::Text(text) => text.text,
348 message::ToolResultContent::Image(_) => String::from("[Image]"),
349 };
350
351 Message::ToolResult {
352 tool_call_id: tool_result.id,
353 content,
354 }
355 }
356}
357
358impl From<message::ToolCall> for ToolCall {
359 fn from(tool_call: message::ToolCall) -> Self {
360 Self {
361 id: tool_call.id,
362 index: 0,
364 r#type: ToolType::Function,
365 function: Function {
366 name: tool_call.function.name,
367 arguments: tool_call.function.arguments,
368 },
369 }
370 }
371}
372
373impl TryFrom<message::Message> for Vec<Message> {
374 type Error = message::MessageError;
375
376 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
377 match message {
378 message::Message::User { content } => {
379 let mut messages = vec![];
381
382 let tool_results = content
383 .clone()
384 .into_iter()
385 .filter_map(|content| match content {
386 message::UserContent::ToolResult(tool_result) => {
387 Some(Message::from(tool_result))
388 }
389 _ => None,
390 })
391 .collect::<Vec<_>>();
392
393 messages.extend(tool_results);
394
395 let text_messages = content
397 .into_iter()
398 .filter_map(|content| match content {
399 message::UserContent::Text(text) => Some(Message::User {
400 content: text.text,
401 name: None,
402 }),
403 message::UserContent::Document(Document {
404 data:
405 DocumentSourceKind::Base64(content)
406 | DocumentSourceKind::String(content),
407 ..
408 }) => Some(Message::User {
409 content,
410 name: None,
411 }),
412 _ => None,
413 })
414 .collect::<Vec<_>>();
415 messages.extend(text_messages);
416
417 Ok(messages)
418 }
419 message::Message::Assistant { content, .. } => {
420 let mut messages: Vec<Message> = vec![];
421
422 let text_content = content
424 .clone()
425 .into_iter()
426 .filter_map(|content| match content {
427 message::AssistantContent::Text(text) => Some(Message::Assistant {
428 content: text.text,
429 name: None,
430 tool_calls: vec![],
431 }),
432 _ => None,
433 })
434 .collect::<Vec<_>>();
435
436 messages.extend(text_content);
437
438 let tool_calls = content
440 .clone()
441 .into_iter()
442 .filter_map(|content| match content {
443 message::AssistantContent::ToolCall(tool_call) => {
444 Some(ToolCall::from(tool_call))
445 }
446 _ => None,
447 })
448 .collect::<Vec<_>>();
449
450 if !tool_calls.is_empty() {
452 messages.push(Message::Assistant {
453 content: "".to_string(),
454 name: None,
455 tool_calls,
456 });
457 }
458
459 Ok(messages)
460 }
461 }
462 }
463}
464
465#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
466pub struct ToolCall {
467 pub id: String,
468 pub index: usize,
469 #[serde(default)]
470 pub r#type: ToolType,
471 pub function: Function,
472}
473
474#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
475pub struct Function {
476 pub name: String,
477 #[serde(with = "json_utils::stringified_json")]
478 pub arguments: serde_json::Value,
479}
480
481#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
482#[serde(rename_all = "lowercase")]
483pub enum ToolType {
484 #[default]
485 Function,
486}
487
488#[derive(Clone, Debug, Deserialize, Serialize)]
489pub struct ToolDefinition {
490 pub r#type: String,
491 pub function: completion::ToolDefinition,
492}
493
494impl From<crate::completion::ToolDefinition> for ToolDefinition {
495 fn from(tool: crate::completion::ToolDefinition) -> Self {
496 Self {
497 r#type: "function".into(),
498 function: tool,
499 }
500 }
501}
502
503impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
504 type Error = CompletionError;
505
506 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
507 let choice = response.choices.first().ok_or_else(|| {
508 CompletionError::ResponseError("Response contained no choices".to_owned())
509 })?;
510 let content = match &choice.message {
511 Message::Assistant {
512 content,
513 tool_calls,
514 ..
515 } => {
516 let mut content = if content.trim().is_empty() {
517 vec![]
518 } else {
519 vec![completion::AssistantContent::text(content)]
520 };
521
522 content.extend(
523 tool_calls
524 .iter()
525 .map(|call| {
526 completion::AssistantContent::tool_call(
527 &call.id,
528 &call.function.name,
529 call.function.arguments.clone(),
530 )
531 })
532 .collect::<Vec<_>>(),
533 );
534 Ok(content)
535 }
536 _ => Err(CompletionError::ResponseError(
537 "Response did not contain a valid message or tool call".into(),
538 )),
539 }?;
540
541 let choice = OneOrMany::many(content).map_err(|_| {
542 CompletionError::ResponseError(
543 "Response contained no message or tool call (empty)".to_owned(),
544 )
545 })?;
546
547 let usage = completion::Usage {
548 input_tokens: response.usage.prompt_tokens as u64,
549 output_tokens: response.usage.completion_tokens as u64,
550 total_tokens: response.usage.total_tokens as u64,
551 };
552
553 Ok(completion::CompletionResponse {
554 choice,
555 usage,
556 raw_response: response,
557 })
558 }
559}
560
561#[derive(Clone)]
563pub struct CompletionModel<T = reqwest::Client> {
564 pub client: Client<T>,
565 pub model: String,
566}
567
568impl<T> CompletionModel<T> {
569 fn create_completion_request(
570 &self,
571 completion_request: CompletionRequest,
572 ) -> Result<serde_json::Value, CompletionError> {
573 let mut partial_history = vec![];
575
576 if let Some(docs) = completion_request.normalized_documents() {
577 partial_history.push(docs);
578 }
579
580 partial_history.extend(completion_request.chat_history);
581
582 let mut full_history: Vec<Message> = completion_request
584 .preamble
585 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
586
587 full_history.extend(
589 partial_history
590 .into_iter()
591 .map(message::Message::try_into)
592 .collect::<Result<Vec<Vec<Message>>, _>>()?
593 .into_iter()
594 .flatten()
595 .collect::<Vec<_>>(),
596 );
597
598 let tool_choice = completion_request
599 .tool_choice
600 .map(crate::providers::openrouter::ToolChoice::try_from)
601 .transpose()?;
602
603 let request = if completion_request.tools.is_empty() {
604 json!({
605 "model": self.model,
606 "messages": full_history,
607 "temperature": completion_request.temperature,
608 })
609 } else {
610 json!({
611 "model": self.model,
612 "messages": full_history,
613 "temperature": completion_request.temperature,
614 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
615 "tool_choice": tool_choice,
616 })
617 };
618
619 let request = if let Some(params) = completion_request.additional_params {
620 json_utils::merge(request, params)
621 } else {
622 request
623 };
624
625 Ok(request)
626 }
627}
628
629impl completion::CompletionModel for CompletionModel<reqwest::Client> {
630 type Response = CompletionResponse;
631 type StreamingResponse = StreamingCompletionResponse;
632
633 #[cfg_attr(feature = "worker", worker::send)]
634 async fn completion(
635 &self,
636 completion_request: CompletionRequest,
637 ) -> Result<
638 completion::CompletionResponse<CompletionResponse>,
639 crate::completion::CompletionError,
640 > {
641 let preamble = completion_request.preamble.clone();
642 let request = self.create_completion_request(completion_request)?;
643
644 let span = if tracing::Span::current().is_disabled() {
645 info_span!(
646 target: "rig::completions",
647 "chat",
648 gen_ai.operation.name = "chat",
649 gen_ai.provider.name = "deepseek",
650 gen_ai.request.model = self.model,
651 gen_ai.system_instructions = preamble,
652 gen_ai.response.id = tracing::field::Empty,
653 gen_ai.response.model = tracing::field::Empty,
654 gen_ai.usage.output_tokens = tracing::field::Empty,
655 gen_ai.usage.input_tokens = tracing::field::Empty,
656 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
657 gen_ai.output.messages = tracing::field::Empty,
658 )
659 } else {
660 tracing::Span::current()
661 };
662
663 tracing::debug!("DeepSeek completion request: {request:?}");
664
665 async move {
666 let response = self
667 .client
668 .reqwest_post("/chat/completions")
669 .json(&request)
670 .send()
671 .await
672 .map_err(|e| http_client::Error::Instance(e.into()))?;
673
674 if response.status().is_success() {
675 let t = response
676 .text()
677 .await
678 .map_err(|e| http_client::Error::Instance(e.into()))?;
679
680 tracing::debug!(target: "rig", "DeepSeek completion: {t}");
681
682 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
683 ApiResponse::Ok(response) => {
684 let span = tracing::Span::current();
685 span.record(
686 "gen_ai.output.messages",
687 serde_json::to_string(&response.choices).unwrap(),
688 );
689 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
690 span.record(
691 "gen_ai.usage.output_tokens",
692 response.usage.completion_tokens,
693 );
694 response.try_into()
695 }
696 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
697 }
698 } else {
699 Err(CompletionError::ProviderError(
700 response
701 .text()
702 .await
703 .map_err(|e| http_client::Error::Instance(e.into()))?,
704 ))
705 }
706 }
707 .instrument(span)
708 .await
709 }
710
711 #[cfg_attr(feature = "worker", worker::send)]
712 async fn stream(
713 &self,
714 completion_request: CompletionRequest,
715 ) -> Result<
716 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
717 CompletionError,
718 > {
719 let preamble = completion_request.preamble.clone();
720 let mut request = self.create_completion_request(completion_request)?;
721
722 request = merge(
723 request,
724 json!({"stream": true, "stream_options": {"include_usage": true}}),
725 );
726
727 let builder = self.client.reqwest_post("/chat/completions").json(&request);
728
729 let span = if tracing::Span::current().is_disabled() {
730 info_span!(
731 target: "rig::completions",
732 "chat_streaming",
733 gen_ai.operation.name = "chat_streaming",
734 gen_ai.provider.name = "deepseek",
735 gen_ai.request.model = self.model,
736 gen_ai.system_instructions = preamble,
737 gen_ai.response.id = tracing::field::Empty,
738 gen_ai.response.model = tracing::field::Empty,
739 gen_ai.usage.output_tokens = tracing::field::Empty,
740 gen_ai.usage.input_tokens = tracing::field::Empty,
741 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
742 gen_ai.output.messages = tracing::field::Empty,
743 )
744 } else {
745 tracing::Span::current()
746 };
747
748 tracing::Instrument::instrument(send_compatible_streaming_request(builder), span).await
749 }
750}
751
752#[derive(Deserialize, Debug)]
753pub struct StreamingDelta {
754 #[serde(default)]
755 content: Option<String>,
756 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
757 tool_calls: Vec<StreamingToolCall>,
758 reasoning_content: Option<String>,
759}
760
761#[derive(Deserialize, Debug)]
762struct StreamingChoice {
763 delta: StreamingDelta,
764}
765
766#[derive(Deserialize, Debug)]
767struct StreamingCompletionChunk {
768 choices: Vec<StreamingChoice>,
769 usage: Option<Usage>,
770}
771
772#[derive(Clone, Deserialize, Serialize, Debug)]
773pub struct StreamingCompletionResponse {
774 pub usage: Usage,
775}
776
777impl GetTokenUsage for StreamingCompletionResponse {
778 fn token_usage(&self) -> Option<crate::completion::Usage> {
779 let mut usage = crate::completion::Usage::new();
780 usage.input_tokens = self.usage.prompt_tokens as u64;
781 usage.output_tokens = self.usage.completion_tokens as u64;
782 usage.total_tokens = self.usage.total_tokens as u64;
783
784 Some(usage)
785 }
786}
787
788pub async fn send_compatible_streaming_request(
789 request_builder: reqwest::RequestBuilder,
790) -> Result<
791 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
792 CompletionError,
793> {
794 let span = tracing::Span::current();
795 let mut event_source = request_builder
796 .eventsource()
797 .expect("Cloning request must succeed");
798
799 let stream = stream! {
800 let mut final_usage = Usage::new();
801 let mut text_response = String::new();
802 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
803
804 while let Some(event_result) = event_source.next().await {
805 match event_result {
806 Ok(Event::Open) => {
807 tracing::trace!("SSE connection opened");
808 continue;
809 }
810 Ok(Event::Message(message)) => {
811 if message.data.trim().is_empty() || message.data == "[DONE]" {
812 continue;
813 }
814
815 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
816 let Ok(data) = parsed else {
817 let err = parsed.unwrap_err();
818 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
819 continue;
820 };
821
822 if let Some(choice) = data.choices.first() {
823 let delta = &choice.delta;
824
825 if !delta.tool_calls.is_empty() {
826 for tool_call in &delta.tool_calls {
827 let function = &tool_call.function;
828
829 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
831 && function.arguments.is_empty()
832 {
833 let id = tool_call.id.clone().unwrap_or_default();
834 let name = function.name.clone().unwrap();
835 calls.insert(tool_call.index, (id, name, String::new()));
836 }
837 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
839 && !function.arguments.is_empty()
840 {
841 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
842 let combined = format!("{}{}", existing_args, function.arguments);
843 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
844 } else {
845 tracing::debug!("Partial tool call received but tool call was never started.");
846 }
847 }
848 else {
850 let id = tool_call.id.clone().unwrap_or_default();
851 let name = function.name.clone().unwrap_or_default();
852 let arguments_str = function.arguments.clone();
853
854 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
855 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
856 continue;
857 };
858
859 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
860 id,
861 name,
862 arguments: arguments_json,
863 call_id: None,
864 });
865 }
866 }
867 }
868
869 if let Some(content) = &delta.reasoning_content {
871 yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
872 reasoning: content.to_string(),
873 id: None,
874 });
875 }
876
877 if let Some(content) = &delta.content {
878 text_response += content;
879 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
880 }
881 }
882
883 if let Some(usage) = data.usage {
884 final_usage = usage.clone();
885 }
886 }
887 Err(reqwest_eventsource::Error::StreamEnded) => {
888 break;
889 }
890 Err(err) => {
891 tracing::error!(?err, "SSE error");
892 yield Err(CompletionError::ResponseError(err.to_string()));
893 break;
894 }
895 }
896 }
897
898 let mut tool_calls = Vec::new();
899 for (index, (id, name, arguments)) in calls {
901 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
902 continue;
903 };
904
905 tool_calls.push(ToolCall {
906 id: id.clone(),
907 index,
908 r#type: ToolType::Function,
909 function: Function {
910 name: name.clone(),
911 arguments: arguments_json.clone()
912 }
913 });
914 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
915 id,
916 name,
917 arguments: arguments_json,
918 call_id: None,
919 });
920 }
921
922 let message = Message::Assistant {
923 content: text_response,
924 name: None,
925 tool_calls
926 };
927
928 span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
929
930 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
931 StreamingCompletionResponse { usage: final_usage.clone() }
932 ));
933 };
934
935 Ok(crate::streaming::StreamingCompletionResponse::stream(
936 Box::pin(stream),
937 ))
938}
939
/// Model name `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
948
#[cfg(test)]
mod tests {

    use super::*;

    /// Returns the text of an assistant message; any other role fails the test.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content.as_str(),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Deserializes a completion response, failing the test with the JSON
    /// path of the offending field if deserialization does not succeed.
    fn parse_completion_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => response,
            Err(err) => panic!("Deserialization error at {}: {}", err.path(), err),
        }
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);

        let first = choices.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_completion_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(assistant_text(&first.message), "Hello, world!");
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_completion_response(data);
        let first = response.choices.first().unwrap();
        assert_eq!(
            assistant_text(&first.message),
            "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_tool_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_tool_call],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}