1use async_stream::stream;
13use bytes::Bytes;
14use futures::StreamExt;
15use http::{Method, Request};
16use std::collections::HashMap;
17use tracing::{Instrument, info_span};
18
19use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
20use crate::completion::GetTokenUsage;
21use crate::http_client::sse::{Event, GenericEventSource};
22use crate::http_client::{self, HttpClientExt};
23use crate::json_utils::merge;
24use crate::message::{Document, DocumentSourceKind};
25use crate::{
26 OneOrMany,
27 completion::{self, CompletionError, CompletionRequest},
28 impl_conversion_traits, json_utils, message,
29};
30use serde::{Deserialize, Serialize};
31use serde_json::json;
32
33use super::openai::StreamingToolCall;
34
/// Base URL used by [`ClientBuilder::new`] / [`ClientBuilder::new_with_client`] unless
/// overridden with [`ClientBuilder::base_url`].
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
39
40pub struct ClientBuilder<'a, T = reqwest::Client> {
41 api_key: &'a str,
42 base_url: &'a str,
43 http_client: T,
44}
45
46impl<'a, T> ClientBuilder<'a, T>
47where
48 T: Default,
49{
50 pub fn new(api_key: &'a str) -> Self {
51 Self {
52 api_key,
53 base_url: DEEPSEEK_API_BASE_URL,
54 http_client: Default::default(),
55 }
56 }
57}
58
59impl<'a, T> ClientBuilder<'a, T> {
60 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
61 Self {
62 api_key,
63 base_url: DEEPSEEK_API_BASE_URL,
64 http_client,
65 }
66 }
67 pub fn base_url(mut self, base_url: &'a str) -> Self {
68 self.base_url = base_url;
69 self
70 }
71
72 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
73 ClientBuilder {
74 api_key: self.api_key,
75 base_url: self.base_url,
76 http_client,
77 }
78 }
79
80 pub fn build(self) -> Client<T> {
81 Client {
82 base_url: self.base_url.to_string(),
83 api_key: self.api_key.to_string(),
84 http_client: self.http_client,
85 }
86 }
87}
88
89#[derive(Clone)]
90pub struct Client<T = reqwest::Client> {
91 pub base_url: String,
92 api_key: String,
93 http_client: T,
94}
95
96impl<T> std::fmt::Debug for Client<T>
97where
98 T: std::fmt::Debug,
99{
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("Client")
102 .field("base_url", &self.base_url)
103 .field("http_client", &self.http_client)
104 .field("api_key", &"<REDACTED>")
105 .finish()
106 }
107}
108
109impl<T> Client<T>
110where
111 T: HttpClientExt,
112{
113 fn req(
114 &self,
115 method: http_client::Method,
116 path: &str,
117 ) -> http_client::Result<http_client::Builder> {
118 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
119
120 http_client::with_bearer_auth(
121 http_client::Request::builder().method(method).uri(url),
122 &self.api_key,
123 )
124 }
125
126 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
127 self.req(http_client::Method::GET, path)
128 }
129
130 async fn send<U, R>(
131 &self,
132 req: http_client::Request<U>,
133 ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
134 where
135 U: Into<Bytes> + Send,
136 R: From<Bytes> + Send + 'static,
137 {
138 self.http_client.send(req).await
139 }
140}
141
142impl Client<reqwest::Client> {
143 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
144 ClientBuilder::new(api_key)
145 }
146
147 pub fn new(api_key: &str) -> Self {
148 ClientBuilder::new(api_key).build()
149 }
150
151 pub fn from_env() -> Self {
152 <Self as ProviderClient>::from_env()
153 }
154}
155
156impl<T> ProviderClient for Client<T>
157where
158 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
159{
160 fn from_env() -> Self {
162 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
163 ClientBuilder::<T>::new(&api_key).build()
164 }
165
166 fn from_val(input: crate::client::ProviderValue) -> Self {
167 let crate::client::ProviderValue::Simple(api_key) = input else {
168 panic!("Incorrect provider value type")
169 };
170 ClientBuilder::<T>::new(&api_key).build()
171 }
172}
173
174impl<T> CompletionClient for Client<T>
175where
176 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
177{
178 type CompletionModel = CompletionModel<T>;
179
180 fn completion_model(&self, model_name: &str) -> Self::CompletionModel {
182 CompletionModel {
183 client: self.clone(),
184 model: model_name.to_string(),
185 }
186 }
187}
188
189impl<T> VerifyClient for Client<T>
190where
191 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
192{
193 #[cfg_attr(feature = "worker", worker::send)]
194 async fn verify(&self) -> Result<(), VerifyError> {
195 let req = self
196 .get("/user/balance")?
197 .body(http_client::NoBody)
198 .map_err(http_client::Error::from)?;
199
200 let response = self.send(req).await?;
201
202 match response.status() {
203 reqwest::StatusCode::OK => Ok(()),
204 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
205 reqwest::StatusCode::INTERNAL_SERVER_ERROR
206 | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
207 let text = http_client::text(response).await?;
208 Err(VerifyError::ProviderError(text))
209 }
210 _ => {
211 Ok(())
214 }
215 }
216 }
217}
218
// Generates the `AsEmbeddings`, `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` conversion-trait impls for `Client<T>`.
// NOTE(review): presumably these mark the capabilities as unsupported for DeepSeek —
// confirm against the `impl_conversion_traits!` definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
225
/// Error body shape: a JSON object with a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
230
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// `#[serde(untagged)]` tries the variants in declaration order, so `Ok(T)` is attempted
/// first and the error shape is only used when `T` does not match. Keep `Ok` first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
237
238impl From<ApiErrorResponse> for CompletionError {
239 fn from(err: ApiErrorResponse) -> Self {
240 CompletionError::ProviderError(err.message)
241 }
242}
243
/// Body of a non-streaming chat completion response.
///
/// Only `choices` and `usage` are modelled; other top-level fields (such as `id`, `model`
/// or `system_fingerprint`, see the tests) are ignored when deserializing.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
252
253#[derive(Clone, Debug, Serialize, Deserialize, Default)]
254pub struct Usage {
255 pub completion_tokens: u32,
256 pub prompt_tokens: u32,
257 pub prompt_cache_hit_tokens: u32,
258 pub prompt_cache_miss_tokens: u32,
259 pub total_tokens: u32,
260 #[serde(skip_serializing_if = "Option::is_none")]
261 pub completion_tokens_details: Option<CompletionTokensDetails>,
262 #[serde(skip_serializing_if = "Option::is_none")]
263 pub prompt_tokens_details: Option<PromptTokensDetails>,
264}
265
266impl Usage {
267 fn new() -> Self {
268 Self {
269 completion_tokens: 0,
270 prompt_tokens: 0,
271 prompt_cache_hit_tokens: 0,
272 prompt_cache_miss_tokens: 0,
273 total_tokens: 0,
274 completion_tokens_details: None,
275 prompt_tokens_details: None,
276 }
277 }
278}
279
/// Optional breakdown of completion tokens; `reasoning_tokens` is omitted from serialized
/// output when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct CompletionTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
}
285
/// Optional breakdown of prompt tokens (e.g. `{"cached_tokens": 0}` in the example
/// response); `cached_tokens` is omitted from serialized output when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_tokens: Option<u32>,
}
291
/// One entry of `choices` in a non-streaming response.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    /// Log-probability payload, kept as untyped JSON (`null` in the tests).
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped, as the raw string (e.g. `"stop"`, `"tool_calls"`).
    pub finish_reason: String,
}
299
/// A chat message in DeepSeek's wire format, internally tagged by `role`.
///
/// Variants serialize with `role` set to the lowercase variant name (`system`, `user`,
/// `assistant`), except `ToolResult`, which is renamed to `tool`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Missing -> empty (via `default`); a JSON `null` is handled by
        // `json_utils::null_or_vec` (presumably mapped to empty — confirm). Empty lists are
        // not serialized.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
330
331impl Message {
332 pub fn system(content: &str) -> Self {
333 Message::System {
334 content: content.to_owned(),
335 name: None,
336 }
337 }
338}
339
340impl From<message::ToolResult> for Message {
341 fn from(tool_result: message::ToolResult) -> Self {
342 let content = match tool_result.content.first() {
343 message::ToolResultContent::Text(text) => text.text,
344 message::ToolResultContent::Image(_) => String::from("[Image]"),
345 };
346
347 Message::ToolResult {
348 tool_call_id: tool_result.id,
349 content,
350 }
351 }
352}
353
354impl From<message::ToolCall> for ToolCall {
355 fn from(tool_call: message::ToolCall) -> Self {
356 Self {
357 id: tool_call.id,
358 index: 0,
360 r#type: ToolType::Function,
361 function: Function {
362 name: tool_call.function.name,
363 arguments: tool_call.function.arguments,
364 },
365 }
366 }
367}
368
369impl TryFrom<message::Message> for Vec<Message> {
370 type Error = message::MessageError;
371
372 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
373 match message {
374 message::Message::User { content } => {
375 let mut messages = vec![];
377
378 let tool_results = content
379 .clone()
380 .into_iter()
381 .filter_map(|content| match content {
382 message::UserContent::ToolResult(tool_result) => {
383 Some(Message::from(tool_result))
384 }
385 _ => None,
386 })
387 .collect::<Vec<_>>();
388
389 messages.extend(tool_results);
390
391 let text_messages = content
393 .into_iter()
394 .filter_map(|content| match content {
395 message::UserContent::Text(text) => Some(Message::User {
396 content: text.text,
397 name: None,
398 }),
399 message::UserContent::Document(Document {
400 data:
401 DocumentSourceKind::Base64(content)
402 | DocumentSourceKind::String(content),
403 ..
404 }) => Some(Message::User {
405 content,
406 name: None,
407 }),
408 _ => None,
409 })
410 .collect::<Vec<_>>();
411 messages.extend(text_messages);
412
413 Ok(messages)
414 }
415 message::Message::Assistant { content, .. } => {
416 let mut messages: Vec<Message> = vec![];
417
418 let text_content = content
420 .clone()
421 .into_iter()
422 .filter_map(|content| match content {
423 message::AssistantContent::Text(text) => Some(Message::Assistant {
424 content: text.text,
425 name: None,
426 tool_calls: vec![],
427 }),
428 _ => None,
429 })
430 .collect::<Vec<_>>();
431
432 messages.extend(text_content);
433
434 let tool_calls = content
436 .clone()
437 .into_iter()
438 .filter_map(|content| match content {
439 message::AssistantContent::ToolCall(tool_call) => {
440 Some(ToolCall::from(tool_call))
441 }
442 _ => None,
443 })
444 .collect::<Vec<_>>();
445
446 if !tool_calls.is_empty() {
448 messages.push(Message::Assistant {
449 content: "".to_string(),
450 name: None,
451 tool_calls,
452 });
453 }
454
455 Ok(messages)
456 }
457 }
458 }
459}
460
/// A function tool call as it appears in an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    pub id: String,
    pub index: usize,
    // Falls back to `ToolType::Function` when `type` is absent.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
469
/// Name and arguments of a function tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // On the wire `arguments` is a JSON-encoded string (e.g. "{\"x\":2,\"y\":5}" in the
    // tests); in memory it is the parsed `serde_json::Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
476
/// Kind of tool call; only `function` exists and it is the default. Serialized lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
483
/// Entry of the request's `tools` array: `{"type": ..., "function": <rig tool definition>}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
489
490impl From<crate::completion::ToolDefinition> for ToolDefinition {
491 fn from(tool: crate::completion::ToolDefinition) -> Self {
492 Self {
493 r#type: "function".into(),
494 function: tool,
495 }
496 }
497}
498
499impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
500 type Error = CompletionError;
501
502 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
503 let choice = response.choices.first().ok_or_else(|| {
504 CompletionError::ResponseError("Response contained no choices".to_owned())
505 })?;
506 let content = match &choice.message {
507 Message::Assistant {
508 content,
509 tool_calls,
510 ..
511 } => {
512 let mut content = if content.trim().is_empty() {
513 vec![]
514 } else {
515 vec![completion::AssistantContent::text(content)]
516 };
517
518 content.extend(
519 tool_calls
520 .iter()
521 .map(|call| {
522 completion::AssistantContent::tool_call(
523 &call.id,
524 &call.function.name,
525 call.function.arguments.clone(),
526 )
527 })
528 .collect::<Vec<_>>(),
529 );
530 Ok(content)
531 }
532 _ => Err(CompletionError::ResponseError(
533 "Response did not contain a valid message or tool call".into(),
534 )),
535 }?;
536
537 let choice = OneOrMany::many(content).map_err(|_| {
538 CompletionError::ResponseError(
539 "Response contained no message or tool call (empty)".to_owned(),
540 )
541 })?;
542
543 let usage = completion::Usage {
544 input_tokens: response.usage.prompt_tokens as u64,
545 output_tokens: response.usage.completion_tokens as u64,
546 total_tokens: response.usage.total_tokens as u64,
547 };
548
549 Ok(completion::CompletionResponse {
550 choice,
551 usage,
552 raw_response: response,
553 })
554 }
555}
556
/// Handle for one DeepSeek model: the shared client plus the model id sent as `model`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub client: Client<T>,
    pub model: String,
}
563
564impl<T> CompletionModel<T> {
565 fn create_completion_request(
566 &self,
567 completion_request: CompletionRequest,
568 ) -> Result<serde_json::Value, CompletionError> {
569 let mut partial_history = vec![];
571
572 if let Some(docs) = completion_request.normalized_documents() {
573 partial_history.push(docs);
574 }
575
576 partial_history.extend(completion_request.chat_history);
577
578 let mut full_history: Vec<Message> = completion_request
580 .preamble
581 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
582
583 full_history.extend(
585 partial_history
586 .into_iter()
587 .map(message::Message::try_into)
588 .collect::<Result<Vec<Vec<Message>>, _>>()?
589 .into_iter()
590 .flatten()
591 .collect::<Vec<_>>(),
592 );
593
594 let tool_choice = completion_request
595 .tool_choice
596 .map(crate::providers::openrouter::ToolChoice::try_from)
597 .transpose()?;
598
599 let request = if completion_request.tools.is_empty() {
600 json!({
601 "model": self.model,
602 "messages": full_history,
603 "temperature": completion_request.temperature,
604 })
605 } else {
606 json!({
607 "model": self.model,
608 "messages": full_history,
609 "temperature": completion_request.temperature,
610 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
611 "tool_choice": tool_choice,
612 })
613 };
614
615 let request = if let Some(params) = completion_request.additional_params {
616 json_utils::merge(request, params)
617 } else {
618 request
619 };
620
621 Ok(request)
622 }
623}
624
625impl<T> completion::CompletionModel for CompletionModel<T>
626where
627 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
628{
629 type Response = CompletionResponse;
630 type StreamingResponse = StreamingCompletionResponse;
631
632 #[cfg_attr(feature = "worker", worker::send)]
633 async fn completion(
634 &self,
635 completion_request: CompletionRequest,
636 ) -> Result<
637 completion::CompletionResponse<CompletionResponse>,
638 crate::completion::CompletionError,
639 > {
640 let preamble = completion_request.preamble.clone();
641 let request = self.create_completion_request(completion_request)?;
642
643 let span = if tracing::Span::current().is_disabled() {
644 info_span!(
645 target: "rig::completions",
646 "chat",
647 gen_ai.operation.name = "chat",
648 gen_ai.provider.name = "deepseek",
649 gen_ai.request.model = self.model,
650 gen_ai.system_instructions = preamble,
651 gen_ai.response.id = tracing::field::Empty,
652 gen_ai.response.model = tracing::field::Empty,
653 gen_ai.usage.output_tokens = tracing::field::Empty,
654 gen_ai.usage.input_tokens = tracing::field::Empty,
655 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
656 gen_ai.output.messages = tracing::field::Empty,
657 )
658 } else {
659 tracing::Span::current()
660 };
661
662 tracing::debug!("DeepSeek completion request: {request:?}");
663
664 let body = serde_json::to_vec(&request)?;
665 let req = self
666 .client
667 .req(Method::POST, "/chat/completions")?
668 .header("Content-Type", "application/json")
669 .body(body)
670 .map_err(|e| CompletionError::HttpError(e.into()))?;
671
672 async move {
673 let response = self.client.http_client.send::<_, Bytes>(req).await?;
674 let status = response.status();
675 let response_body = response.into_body().into_future().await?.to_vec();
676
677 if status.is_success() {
678 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
679 ApiResponse::Ok(response) => {
680 let span = tracing::Span::current();
681 span.record(
682 "gen_ai.output.messages",
683 serde_json::to_string(&response.choices).unwrap(),
684 );
685 span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens);
686 span.record(
687 "gen_ai.usage.output_tokens",
688 response.usage.completion_tokens,
689 );
690 tracing::debug!(target: "rig", "DeepSeek completion output: {}", serde_json::to_string_pretty(&response_body)?);
691
692 response.try_into()
693 }
694 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
695 }
696 } else {
697 Err(CompletionError::ProviderError(
698 String::from_utf8_lossy(&response_body).to_string()
699 ))
700 }
701 }
702 .instrument(span)
703 .await
704 }
705
706 #[cfg_attr(feature = "worker", worker::send)]
707 async fn stream(
708 &self,
709 completion_request: CompletionRequest,
710 ) -> Result<
711 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
712 CompletionError,
713 > {
714 let preamble = completion_request.preamble.clone();
715 let mut request = self.create_completion_request(completion_request)?;
716
717 request = merge(
718 request,
719 json!({"stream": true, "stream_options": {"include_usage": true}}),
720 );
721
722 let body = serde_json::to_vec(&request)?;
723
724 let req = self
725 .client
726 .req(Method::POST, "/chat/completions")?
727 .header("Content-Type", "application/json")
728 .body(body)
729 .map_err(|e| CompletionError::HttpError(e.into()))?;
730
731 let span = if tracing::Span::current().is_disabled() {
732 info_span!(
733 target: "rig::completions",
734 "chat_streaming",
735 gen_ai.operation.name = "chat_streaming",
736 gen_ai.provider.name = "deepseek",
737 gen_ai.request.model = self.model,
738 gen_ai.system_instructions = preamble,
739 gen_ai.response.id = tracing::field::Empty,
740 gen_ai.response.model = tracing::field::Empty,
741 gen_ai.usage.output_tokens = tracing::field::Empty,
742 gen_ai.usage.input_tokens = tracing::field::Empty,
743 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
744 gen_ai.output.messages = tracing::field::Empty,
745 )
746 } else {
747 tracing::Span::current()
748 };
749
750 tracing::Instrument::instrument(
751 send_compatible_streaming_request(self.client.http_client.clone(), req),
752 span,
753 )
754 .await
755 }
756}
757
/// Incremental fields of one streamed choice.
#[derive(Deserialize, Debug)]
pub struct StreamingDelta {
    #[serde(default)]
    content: Option<String>,
    // Missing -> empty; `null` handled by `json_utils::null_or_vec`.
    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
    tool_calls: Vec<StreamingToolCall>,
    // Reasoning text, streamed separately from `content` and surfaced as
    // `RawStreamingChoice::Reasoning`.
    reasoning_content: Option<String>,
}
766
/// One element of a chunk's `choices`; only `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
771
/// One SSE `data:` payload of a streaming response.
///
/// `usage` is optional; the streaming request asks for it via
/// `stream_options.include_usage`.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
777
/// Final aggregate yielded at the end of a stream: the last usage block received
/// (all zeros if none arrived).
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
782
783impl GetTokenUsage for StreamingCompletionResponse {
784 fn token_usage(&self) -> Option<crate::completion::Usage> {
785 let mut usage = crate::completion::Usage::new();
786 usage.input_tokens = self.usage.prompt_tokens as u64;
787 usage.output_tokens = self.usage.completion_tokens as u64;
788 usage.total_tokens = self.usage.total_tokens as u64;
789
790 Some(usage)
791 }
792}
793
794pub async fn send_compatible_streaming_request<T>(
795 http_client: T,
796 req: Request<Vec<u8>>,
797) -> Result<
798 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
799 CompletionError,
800>
801where
802 T: HttpClientExt + Clone + 'static,
803{
804 let span = tracing::Span::current();
805 let mut event_source = GenericEventSource::new(http_client, req);
806
807 let stream = stream! {
808 let mut final_usage = Usage::new();
809 let mut text_response = String::new();
810 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
811
812 while let Some(event_result) = event_source.next().await {
813 match event_result {
814 Ok(Event::Open) => {
815 tracing::trace!("SSE connection opened");
816 continue;
817 }
818 Ok(Event::Message(message)) => {
819 if message.data.trim().is_empty() || message.data == "[DONE]" {
820 continue;
821 }
822
823 let parsed = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
824 let Ok(data) = parsed else {
825 let err = parsed.unwrap_err();
826 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
827 continue;
828 };
829
830 if let Some(choice) = data.choices.first() {
831 let delta = &choice.delta;
832
833 if !delta.tool_calls.is_empty() {
834 for tool_call in &delta.tool_calls {
835 let function = &tool_call.function;
836
837 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
839 && function.arguments.is_empty()
840 {
841 let id = tool_call.id.clone().unwrap_or_default();
842 let name = function.name.clone().unwrap();
843 calls.insert(tool_call.index, (id, name, String::new()));
844 }
845 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
847 && !function.arguments.is_empty()
848 {
849 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
850 let combined = format!("{}{}", existing_args, function.arguments);
851 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
852 } else {
853 tracing::debug!("Partial tool call received but tool call was never started.");
854 }
855 }
856 else {
858 let id = tool_call.id.clone().unwrap_or_default();
859 let name = function.name.clone().unwrap_or_default();
860 let arguments_str = function.arguments.clone();
861
862 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
863 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
864 continue;
865 };
866
867 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
868 id,
869 name,
870 arguments: arguments_json,
871 call_id: None,
872 });
873 }
874 }
875 }
876
877 if let Some(content) = &delta.reasoning_content {
879 yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
880 reasoning: content.to_string(),
881 id: None,
882 signature: None,
883 });
884 }
885
886 if let Some(content) = &delta.content {
887 text_response += content;
888 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
889 }
890 }
891
892 if let Some(usage) = data.usage {
893 final_usage = usage.clone();
894 }
895 }
896 Err(crate::http_client::Error::StreamEnded) => {
897 break;
898 }
899 Err(err) => {
900 tracing::error!(?err, "SSE error");
901 yield Err(CompletionError::ResponseError(err.to_string()));
902 break;
903 }
904 }
905 }
906
907 event_source.close();
908
909 let mut tool_calls = Vec::new();
910 for (index, (id, name, arguments)) in calls {
912 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
913 continue;
914 };
915
916 tool_calls.push(ToolCall {
917 id: id.clone(),
918 index,
919 r#type: ToolType::Function,
920 function: Function {
921 name: name.clone(),
922 arguments: arguments_json.clone()
923 }
924 });
925 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
926 id,
927 name,
928 arguments: arguments_json,
929 call_id: None,
930 });
931 }
932
933 let message = Message::Assistant {
934 content: text_response,
935 name: None,
936 tool_calls
937 };
938
939 span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap());
940
941 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
942 StreamingCompletionResponse { usage: final_usage.clone() }
943 ));
944 };
945
946 Ok(crate::streaming::StreamingCompletionResponse::stream(
947 Box::pin(stream),
948 ))
949}
950
/// Model identifier `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
959
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the text of an assistant message, panicking for any other role.
    fn assistant_text(message: &Message) -> &str {
        let Message::Assistant { content, .. } = message else {
            panic!("Expected assistant message");
        };
        content
    }

    /// Deserializes a full response, panicking with the failing JSON path on error.
    fn parse_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        serde_path_to_error::deserialize(jd)
            .unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_eq!(
            assistant_text(&choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
        {
            "finish_reason": "tool_calls",
            "index": 0,
            "logprobs": null,
            "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                    {
                        "function": {
                            "arguments": "{\"x\":2,\"y\":5}",
                            "name": "subtract"
                        },
                        "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                        "index": 0,
                        "type": "function"
                    }
                ]
            }
        }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::json!({"x": 2, "y": 5}),
            },
        };
        let expected_choice = Choice {
            index: 0,
            logprobs: None,
            finish_reason: "tool_calls".to_string(),
            message: Message::Assistant {
                content: String::new(),
                name: None,
                tool_calls: vec![expected_call],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}