1use crate::client::{CompletionClient, ProviderClient};
13use crate::json_utils::merge;
14use crate::message::Document;
15use crate::providers::openai;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19 OneOrMany,
20 completion::{self, CompletionError, CompletionModel, CompletionRequest},
21 impl_conversion_traits, json_utils, message,
22};
23use reqwest::Client as HttpClient;
24use serde::{Deserialize, Serialize};
25use serde_json::json;
26
/// Base URL used by [`Client::new`] when no explicit URL is supplied.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
31
32#[derive(Clone)]
33pub struct Client {
34 pub base_url: String,
35 api_key: String,
36 http_client: HttpClient,
37}
38
39impl std::fmt::Debug for Client {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("Client")
42 .field("base_url", &self.base_url)
43 .field("http_client", &self.http_client)
44 .field("api_key", &"<REDACTED>")
45 .finish()
46 }
47}
48
49impl Client {
50 pub fn new(api_key: &str) -> Self {
52 Self::from_url(api_key, DEEPSEEK_API_BASE_URL)
53 }
54
55 pub fn from_url(api_key: &str, base_url: &str) -> Self {
58 Self {
59 base_url: base_url.to_string(),
60 api_key: api_key.to_string(),
61 http_client: reqwest::Client::builder()
62 .build()
63 .expect("DeepSeek reqwest client should build"),
64 }
65 }
66
67 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
70 self.http_client = client;
71
72 self
73 }
74
75 fn post(&self, path: &str) -> reqwest::RequestBuilder {
76 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
77 self.http_client.post(url).bearer_auth(&self.api_key)
78 }
79}
80
81impl ProviderClient for Client {
82 fn from_env() -> Self {
84 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
85 Self::new(&api_key)
86 }
87
88 fn from_val(input: crate::client::ProviderValue) -> Self {
89 let crate::client::ProviderValue::Simple(api_key) = input else {
90 panic!("Incorrect provider value type")
91 };
92 Self::new(&api_key)
93 }
94}
95
96impl CompletionClient for Client {
97 type CompletionModel = DeepSeekCompletionModel;
98
99 fn completion_model(&self, model_name: &str) -> DeepSeekCompletionModel {
101 DeepSeekCompletionModel {
102 client: self.clone(),
103 model: model_name.to_string(),
104 }
105 }
106}
107
108impl_conversion_traits!(
109 AsEmbeddings,
110 AsTranscription,
111 AsImageGeneration,
112 AsAudioGeneration for Client
113);
114
115#[derive(Debug, Deserialize)]
116struct ApiErrorResponse {
117 message: String,
118}
119
120#[derive(Debug, Deserialize)]
121#[serde(untagged)]
122enum ApiResponse<T> {
123 Ok(T),
124 Err(ApiErrorResponse),
125}
126
127impl From<ApiErrorResponse> for CompletionError {
128 fn from(err: ApiErrorResponse) -> Self {
129 CompletionError::ProviderError(err.message)
130 }
131}
132
133#[derive(Clone, Debug, Serialize, Deserialize)]
135pub struct CompletionResponse {
136 pub choices: Vec<Choice>,
138 pub usage: Usage,
139 }
141
142#[derive(Clone, Debug, Serialize, Deserialize, Default)]
143pub struct Usage {
144 pub completion_tokens: u32,
145 pub prompt_tokens: u32,
146 pub prompt_cache_hit_tokens: u32,
147 pub prompt_cache_miss_tokens: u32,
148 pub total_tokens: u32,
149 #[serde(skip_serializing_if = "Option::is_none")]
150 pub completion_tokens_details: Option<CompletionTokensDetails>,
151 #[serde(skip_serializing_if = "Option::is_none")]
152 pub prompt_tokens_details: Option<PromptTokensDetails>,
153}
154
155#[derive(Clone, Debug, Serialize, Deserialize, Default)]
156pub struct CompletionTokensDetails {
157 #[serde(skip_serializing_if = "Option::is_none")]
158 pub reasoning_tokens: Option<u32>,
159}
160
161#[derive(Clone, Debug, Serialize, Deserialize, Default)]
162pub struct PromptTokensDetails {
163 #[serde(skip_serializing_if = "Option::is_none")]
164 pub cached_tokens: Option<u32>,
165}
166
167#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
168pub struct Choice {
169 pub index: usize,
170 pub message: Message,
171 pub logprobs: Option<serde_json::Value>,
172 pub finish_reason: String,
173}
174
175#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
176#[serde(tag = "role", rename_all = "lowercase")]
177pub enum Message {
178 System {
179 content: String,
180 #[serde(skip_serializing_if = "Option::is_none")]
181 name: Option<String>,
182 },
183 User {
184 content: String,
185 #[serde(skip_serializing_if = "Option::is_none")]
186 name: Option<String>,
187 },
188 Assistant {
189 content: String,
190 #[serde(skip_serializing_if = "Option::is_none")]
191 name: Option<String>,
192 #[serde(
193 default,
194 deserialize_with = "json_utils::null_or_vec",
195 skip_serializing_if = "Vec::is_empty"
196 )]
197 tool_calls: Vec<ToolCall>,
198 },
199 #[serde(rename = "tool")]
200 ToolResult {
201 tool_call_id: String,
202 content: String,
203 },
204}
205
206impl Message {
207 pub fn system(content: &str) -> Self {
208 Message::System {
209 content: content.to_owned(),
210 name: None,
211 }
212 }
213}
214
215impl From<message::ToolResult> for Message {
216 fn from(tool_result: message::ToolResult) -> Self {
217 let content = match tool_result.content.first() {
218 message::ToolResultContent::Text(text) => text.text,
219 message::ToolResultContent::Image(_) => String::from("[Image]"),
220 };
221
222 Message::ToolResult {
223 tool_call_id: tool_result.id,
224 content,
225 }
226 }
227}
228
229impl From<message::ToolCall> for ToolCall {
230 fn from(tool_call: message::ToolCall) -> Self {
231 Self {
232 id: tool_call.id,
233 index: 0,
235 r#type: ToolType::Function,
236 function: Function {
237 name: tool_call.function.name,
238 arguments: tool_call.function.arguments,
239 },
240 }
241 }
242}
243
244impl TryFrom<message::Message> for Vec<Message> {
245 type Error = message::MessageError;
246
247 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
248 match message {
249 message::Message::User { content } => {
250 let mut messages = vec![];
252
253 let tool_results = content
254 .clone()
255 .into_iter()
256 .filter_map(|content| match content {
257 message::UserContent::ToolResult(tool_result) => {
258 Some(Message::from(tool_result))
259 }
260 _ => None,
261 })
262 .collect::<Vec<_>>();
263
264 messages.extend(tool_results);
265
266 let text_messages = content
268 .into_iter()
269 .filter_map(|content| match content {
270 message::UserContent::Text(text) => Some(Message::User {
271 content: text.text,
272 name: None,
273 }),
274 message::UserContent::Document(Document { data, .. }) => {
275 Some(Message::User {
276 content: data,
277 name: None,
278 })
279 }
280 _ => None,
281 })
282 .collect::<Vec<_>>();
283 messages.extend(text_messages);
284
285 Ok(messages)
286 }
287 message::Message::Assistant { content, .. } => {
288 let mut messages: Vec<Message> = vec![];
289
290 let tool_calls = content
292 .clone()
293 .into_iter()
294 .filter_map(|content| match content {
295 message::AssistantContent::ToolCall(tool_call) => {
296 Some(ToolCall::from(tool_call))
297 }
298 _ => None,
299 })
300 .collect::<Vec<_>>();
301
302 if !tool_calls.is_empty() {
304 messages.push(Message::Assistant {
305 content: "".to_string(),
306 name: None,
307 tool_calls,
308 });
309 }
310
311 let text_content = content
313 .into_iter()
314 .filter_map(|content| match content {
315 message::AssistantContent::Text(text) => Some(Message::Assistant {
316 content: text.text,
317 name: None,
318 tool_calls: vec![],
319 }),
320 _ => None,
321 })
322 .collect::<Vec<_>>();
323
324 messages.extend(text_content);
325
326 Ok(messages)
327 }
328 }
329 }
330}
331
332#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
333pub struct ToolCall {
334 pub id: String,
335 pub index: usize,
336 #[serde(default)]
337 pub r#type: ToolType,
338 pub function: Function,
339}
340
341#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
342pub struct Function {
343 pub name: String,
344 #[serde(with = "json_utils::stringified_json")]
345 pub arguments: serde_json::Value,
346}
347
348#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
349#[serde(rename_all = "lowercase")]
350pub enum ToolType {
351 #[default]
352 Function,
353}
354
355#[derive(Clone, Debug, Deserialize, Serialize)]
356pub struct ToolDefinition {
357 pub r#type: String,
358 pub function: completion::ToolDefinition,
359}
360
361impl From<crate::completion::ToolDefinition> for ToolDefinition {
362 fn from(tool: crate::completion::ToolDefinition) -> Self {
363 Self {
364 r#type: "function".into(),
365 function: tool,
366 }
367 }
368}
369
370impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
371 type Error = CompletionError;
372
373 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
374 let choice = response.choices.first().ok_or_else(|| {
375 CompletionError::ResponseError("Response contained no choices".to_owned())
376 })?;
377 let content = match &choice.message {
378 Message::Assistant {
379 content,
380 tool_calls,
381 ..
382 } => {
383 let mut content = if content.trim().is_empty() {
384 vec![]
385 } else {
386 vec![completion::AssistantContent::text(content)]
387 };
388
389 content.extend(
390 tool_calls
391 .iter()
392 .map(|call| {
393 completion::AssistantContent::tool_call(
394 &call.id,
395 &call.function.name,
396 call.function.arguments.clone(),
397 )
398 })
399 .collect::<Vec<_>>(),
400 );
401 Ok(content)
402 }
403 _ => Err(CompletionError::ResponseError(
404 "Response did not contain a valid message or tool call".into(),
405 )),
406 }?;
407
408 let choice = OneOrMany::many(content).map_err(|_| {
409 CompletionError::ResponseError(
410 "Response contained no message or tool call (empty)".to_owned(),
411 )
412 })?;
413
414 let usage = completion::Usage {
415 input_tokens: response.usage.prompt_tokens as u64,
416 output_tokens: response.usage.completion_tokens as u64,
417 total_tokens: response.usage.total_tokens as u64,
418 };
419
420 Ok(completion::CompletionResponse {
421 choice,
422 usage,
423 raw_response: response,
424 })
425 }
426}
427
428#[derive(Clone)]
430pub struct DeepSeekCompletionModel {
431 pub client: Client,
432 pub model: String,
433}
434
435impl DeepSeekCompletionModel {
436 fn create_completion_request(
437 &self,
438 completion_request: CompletionRequest,
439 ) -> Result<serde_json::Value, CompletionError> {
440 let mut partial_history = vec![];
442
443 if let Some(docs) = completion_request.normalized_documents() {
444 partial_history.push(docs);
445 }
446
447 partial_history.extend(completion_request.chat_history);
448
449 let mut full_history: Vec<Message> = completion_request
451 .preamble
452 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
453
454 full_history.extend(
456 partial_history
457 .into_iter()
458 .map(message::Message::try_into)
459 .collect::<Result<Vec<Vec<Message>>, _>>()?
460 .into_iter()
461 .flatten()
462 .collect::<Vec<_>>(),
463 );
464
465 let request = if completion_request.tools.is_empty() {
466 json!({
467 "model": self.model,
468 "messages": full_history,
469 "temperature": completion_request.temperature,
470 })
471 } else {
472 json!({
473 "model": self.model,
474 "messages": full_history,
475 "temperature": completion_request.temperature,
476 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
477 "tool_choice": "auto",
478 })
479 };
480
481 let request = if let Some(params) = completion_request.additional_params {
482 json_utils::merge(request, params)
483 } else {
484 request
485 };
486
487 Ok(request)
488 }
489}
490
491impl CompletionModel for DeepSeekCompletionModel {
492 type Response = CompletionResponse;
493 type StreamingResponse = openai::StreamingCompletionResponse;
494
495 #[cfg_attr(feature = "worker", worker::send)]
496 async fn completion(
497 &self,
498 completion_request: CompletionRequest,
499 ) -> Result<
500 completion::CompletionResponse<CompletionResponse>,
501 crate::completion::CompletionError,
502 > {
503 let request = self.create_completion_request(completion_request)?;
504
505 tracing::debug!("DeepSeek completion request: {request:?}");
506
507 let response = self
508 .client
509 .post("/chat/completions")
510 .json(&request)
511 .send()
512 .await?;
513
514 if response.status().is_success() {
515 let t = response.text().await?;
516 tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
517
518 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
519 ApiResponse::Ok(response) => response.try_into(),
520 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
521 }
522 } else {
523 Err(CompletionError::ProviderError(response.text().await?))
524 }
525 }
526
527 #[cfg_attr(feature = "worker", worker::send)]
528 async fn stream(
529 &self,
530 completion_request: CompletionRequest,
531 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
532 let mut request = self.create_completion_request(completion_request)?;
533
534 request = merge(
535 request,
536 json!({"stream": true, "stream_options": {"include_usage": true}}),
537 );
538
539 let builder = self.client.post("/v1/chat/completions").json(&request);
540 send_compatible_streaming_request(builder).await
541 }
542}
543
/// Model name `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
552
// Deserialization tests for the provider wire types.
#[cfg(test)]
mod tests {

    use super::*;

    /// Returns the text of an assistant message; panics for any other role.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content.as_str(),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Parses a full completion response, reporting the failing JSON path on error.
    fn parse_response(data: &str) -> CompletionResponse {
        let mut deserializer = serde_json::Deserializer::from_str(data);
        let parsed: Result<CompletionResponse, _> =
            serde_path_to_error::deserialize(&mut deserializer);

        match parsed {
            Ok(response) => response,
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();

        assert_eq!(choices.len(), 1);
        assert_eq!(
            assistant_text(&choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{
            "choices":[{
                "finish_reason": "stop",
                "index": 0,
                "logprobs": null,
                "message":{"role":"assistant","content":"Hello, world!"}
            }],
            "usage": {
                "completion_tokens": 0,
                "prompt_tokens": 0,
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 0,
                "total_tokens": 0
            }
        }"#;

        let response = parse_response(data);

        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);

        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
        );
    }

    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
            index: 0,
            r#type: ToolType::Function,
        };

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}