1use crate::client::{CompletionClient, ProviderClient};
13use crate::json_utils::merge;
14use crate::message::Document;
15use crate::providers::openai;
16use crate::providers::openai::send_compatible_streaming_request;
17use crate::streaming::StreamingCompletionResponse;
18use crate::{
19 OneOrMany,
20 completion::{self, CompletionError, CompletionModel, CompletionRequest},
21 impl_conversion_traits, json_utils, message,
22};
23use reqwest::Client as HttpClient;
24use serde::{Deserialize, Serialize};
25use serde_json::json;
26
/// Base URL that [`Client::new`] passes to [`Client::from_url`].
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
31
/// DeepSeek API client: bundles the base URL, the API key and the HTTP client
/// used to build requests.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to.
    pub base_url: String,
    /// API key attached as a bearer token to every request built by `post`.
    api_key: String,
    /// Underlying `reqwest` client used to issue requests.
    http_client: HttpClient,
}
38
39impl std::fmt::Debug for Client {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("Client")
42 .field("base_url", &self.base_url)
43 .field("http_client", &self.http_client)
44 .field("api_key", &"<REDACTED>")
45 .finish()
46 }
47}
48
49impl Client {
50 pub fn new(api_key: &str) -> Self {
52 Self::from_url(api_key, DEEPSEEK_API_BASE_URL)
53 }
54
55 pub fn from_url(api_key: &str, base_url: &str) -> Self {
58 Self {
59 base_url: base_url.to_string(),
60 api_key: api_key.to_string(),
61 http_client: reqwest::Client::builder()
62 .build()
63 .expect("DeepSeek reqwest client should build"),
64 }
65 }
66
67 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
70 self.http_client = client;
71
72 self
73 }
74
75 fn post(&self, path: &str) -> reqwest::RequestBuilder {
76 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
77 self.http_client.post(url).bearer_auth(&self.api_key)
78 }
79}
80
81impl ProviderClient for Client {
82 fn from_env() -> Self {
84 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
85 Self::new(&api_key)
86 }
87}
88
89impl CompletionClient for Client {
90 type CompletionModel = DeepSeekCompletionModel;
91
92 fn completion_model(&self, model_name: &str) -> DeepSeekCompletionModel {
94 DeepSeekCompletionModel {
95 client: self.clone(),
96 model: model_name.to_string(),
97 }
98 }
99}
100
// Invokes the crate's `impl_conversion_traits!` macro for `Client` with the
// `AsEmbeddings`, `AsTranscription`, `AsImageGeneration` and
// `AsAudioGeneration` traits.
// NOTE(review): presumably this generates "not supported" implementations,
// since this module only defines a completion model — confirm against the
// macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
107
/// Error payload shape: a JSON object carrying a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text; forwarded verbatim into `CompletionError::ProviderError`.
    message: String,
}
112
/// Response body that is either a successful payload `T` or an error payload.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body matched the expected payload type.
    Ok(T),
    /// The body matched the error payload shape instead.
    Err(ApiErrorResponse),
}
119
120impl From<ApiErrorResponse> for CompletionError {
121 fn from(err: ApiErrorResponse) -> Self {
122 CompletionError::ProviderError(err.message)
123 }
124}
125
/// Deserialized chat-completion response body.
///
/// Only `choices` is modelled; any other fields in the JSON body are ignored
/// during deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Completion choices returned by the API.
    pub choices: Vec<Choice>,
}
133
/// A single completion choice inside a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Position of this choice as reported in the response.
    pub index: usize,
    /// The message produced for this choice.
    pub message: Message,
    /// Raw `logprobs` value, kept as untyped JSON; `None` when null or absent.
    pub logprobs: Option<serde_json::Value>,
    /// Finish reason string as reported by the API (e.g. `"stop"`, `"tool_calls"`).
    pub finish_reason: String,
}
141
/// Chat message in the wire format used by this provider.
///
/// The variant is selected by the `role` field, whose value is the lowercase
/// variant name unless a variant is explicitly renamed.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message with role `system`.
    System {
        content: String,
        // Left out of the serialized JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Message with role `user`.
    User {
        content: String,
        // Left out of the serialized JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Message with role `assistant`, optionally carrying tool calls.
    Assistant {
        content: String,
        // Left out of the serialized JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Defaults to an empty vector when the field is missing and is left
        // out of the serialized JSON when empty.
        // NOTE(review): `json_utils::null_or_vec` presumably maps JSON `null`
        // to an empty vector — confirm in `json_utils`.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool invocation; serialized with role `tool`.
    #[serde(rename = "tool")]
    ToolResult {
        /// Identifier of the tool call this result belongs to.
        tool_call_id: String,
        content: String,
    },
}
172
173impl Message {
174 pub fn system(content: &str) -> Self {
175 Message::System {
176 content: content.to_owned(),
177 name: None,
178 }
179 }
180}
181
182impl From<message::ToolResult> for Message {
183 fn from(tool_result: message::ToolResult) -> Self {
184 let content = match tool_result.content.first() {
185 message::ToolResultContent::Text(text) => text.text,
186 message::ToolResultContent::Image(_) => String::from("[Image]"),
187 };
188
189 Message::ToolResult {
190 tool_call_id: tool_result.id,
191 content,
192 }
193 }
194}
195
196impl From<message::ToolCall> for ToolCall {
197 fn from(tool_call: message::ToolCall) -> Self {
198 Self {
199 id: tool_call.id,
200 index: 0,
202 r#type: ToolType::Function,
203 function: Function {
204 name: tool_call.function.name,
205 arguments: tool_call.function.arguments,
206 },
207 }
208 }
209}
210
211impl TryFrom<message::Message> for Vec<Message> {
212 type Error = message::MessageError;
213
214 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
215 match message {
216 message::Message::User { content } => {
217 let mut messages = vec![];
219
220 let tool_results = content
221 .clone()
222 .into_iter()
223 .filter_map(|content| match content {
224 message::UserContent::ToolResult(tool_result) => {
225 Some(Message::from(tool_result))
226 }
227 _ => None,
228 })
229 .collect::<Vec<_>>();
230
231 messages.extend(tool_results);
232
233 let text_messages = content
235 .into_iter()
236 .filter_map(|content| match content {
237 message::UserContent::Text(text) => Some(Message::User {
238 content: text.text,
239 name: None,
240 }),
241 message::UserContent::Document(Document { data, .. }) => {
242 Some(Message::User {
243 content: data,
244 name: None,
245 })
246 }
247 _ => None,
248 })
249 .collect::<Vec<_>>();
250 messages.extend(text_messages);
251
252 Ok(messages)
253 }
254 message::Message::Assistant { content, .. } => {
255 let mut messages: Vec<Message> = vec![];
256
257 let tool_calls = content
259 .clone()
260 .into_iter()
261 .filter_map(|content| match content {
262 message::AssistantContent::ToolCall(tool_call) => {
263 Some(ToolCall::from(tool_call))
264 }
265 _ => None,
266 })
267 .collect::<Vec<_>>();
268
269 if !tool_calls.is_empty() {
271 messages.push(Message::Assistant {
272 content: "".to_string(),
273 name: None,
274 tool_calls,
275 });
276 }
277
278 let text_content = content
280 .into_iter()
281 .filter_map(|content| match content {
282 message::AssistantContent::Text(text) => Some(Message::Assistant {
283 content: text.text,
284 name: None,
285 tool_calls: vec![],
286 }),
287 _ => None,
288 })
289 .collect::<Vec<_>>();
290
291 messages.extend(text_content);
292
293 Ok(messages)
294 }
295 }
296 }
297}
298
/// A tool call attached to an assistant message, in wire format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of the tool call.
    pub id: String,
    /// Index of the call as reported in the payload.
    pub index: usize,
    // Falls back to `ToolType::Function` when `type` is missing from the JSON.
    #[serde(default)]
    pub r#type: ToolType,
    /// Name and arguments of the function being called.
    pub function: Function,
}
307
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    // On the wire `arguments` is a JSON document encoded as a string (e.g.
    // "{\"x\":2,\"y\":5}"); `json_utils::stringified_json` converts between
    // that string and a structured `serde_json::Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
314
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool (`"function"` on the wire); also the default value.
    #[default]
    Function,
}
321
/// Tool definition as placed in the request's `tools` array: the crate's
/// generic tool definition wrapped together with a `type` string.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool type; the `From<completion::ToolDefinition>` conversion sets it to `"function"`.
    pub r#type: String,
    /// The wrapped generic tool definition.
    pub function: completion::ToolDefinition,
}
327
328impl From<crate::completion::ToolDefinition> for ToolDefinition {
329 fn from(tool: crate::completion::ToolDefinition) -> Self {
330 Self {
331 r#type: "function".into(),
332 function: tool,
333 }
334 }
335}
336
337impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
338 type Error = CompletionError;
339
340 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
341 let choice = response.choices.first().ok_or_else(|| {
342 CompletionError::ResponseError("Response contained no choices".to_owned())
343 })?;
344 let content = match &choice.message {
345 Message::Assistant {
346 content,
347 tool_calls,
348 ..
349 } => {
350 let mut content = if content.trim().is_empty() {
351 vec![]
352 } else {
353 vec![completion::AssistantContent::text(content)]
354 };
355
356 content.extend(
357 tool_calls
358 .iter()
359 .map(|call| {
360 completion::AssistantContent::tool_call(
361 &call.id,
362 &call.function.name,
363 call.function.arguments.clone(),
364 )
365 })
366 .collect::<Vec<_>>(),
367 );
368 Ok(content)
369 }
370 _ => Err(CompletionError::ResponseError(
371 "Response did not contain a valid message or tool call".into(),
372 )),
373 }?;
374
375 let choice = OneOrMany::many(content).map_err(|_| {
376 CompletionError::ResponseError(
377 "Response contained no message or tool call (empty)".to_owned(),
378 )
379 })?;
380
381 Ok(completion::CompletionResponse {
382 choice,
383 raw_response: response,
384 })
385 }
386}
387
/// Completion model handle: a DeepSeek [`Client`] paired with a model name.
#[derive(Clone)]
pub struct DeepSeekCompletionModel {
    /// Client used to send requests.
    pub client: Client,
    /// Model name placed in the `model` field of every request.
    pub model: String,
}
394
395impl DeepSeekCompletionModel {
396 fn create_completion_request(
397 &self,
398 completion_request: CompletionRequest,
399 ) -> Result<serde_json::Value, CompletionError> {
400 let mut partial_history = vec![];
402
403 if let Some(docs) = completion_request.normalized_documents() {
404 partial_history.push(docs);
405 }
406
407 partial_history.extend(completion_request.chat_history);
408
409 let mut full_history: Vec<Message> = completion_request
411 .preamble
412 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
413
414 full_history.extend(
416 partial_history
417 .into_iter()
418 .map(message::Message::try_into)
419 .collect::<Result<Vec<Vec<Message>>, _>>()?
420 .into_iter()
421 .flatten()
422 .collect::<Vec<_>>(),
423 );
424
425 let request = if completion_request.tools.is_empty() {
426 json!({
427 "model": self.model,
428 "messages": full_history,
429 "temperature": completion_request.temperature,
430 })
431 } else {
432 json!({
433 "model": self.model,
434 "messages": full_history,
435 "temperature": completion_request.temperature,
436 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
437 "tool_choice": "auto",
438 })
439 };
440
441 let request = if let Some(params) = completion_request.additional_params {
442 json_utils::merge(request, params)
443 } else {
444 request
445 };
446
447 Ok(request)
448 }
449}
450
451impl CompletionModel for DeepSeekCompletionModel {
452 type Response = CompletionResponse;
453 type StreamingResponse = openai::StreamingCompletionResponse;
454
455 #[cfg_attr(feature = "worker", worker::send)]
456 async fn completion(
457 &self,
458 completion_request: CompletionRequest,
459 ) -> Result<
460 completion::CompletionResponse<CompletionResponse>,
461 crate::completion::CompletionError,
462 > {
463 let request = self.create_completion_request(completion_request)?;
464
465 tracing::debug!("DeepSeek completion request: {request:?}");
466
467 let response = self
468 .client
469 .post("/chat/completions")
470 .json(&request)
471 .send()
472 .await?;
473
474 if response.status().is_success() {
475 let t = response.text().await?;
476 tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
477
478 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
479 ApiResponse::Ok(response) => response.try_into(),
480 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
481 }
482 } else {
483 Err(CompletionError::ProviderError(response.text().await?))
484 }
485 }
486
487 #[cfg_attr(feature = "worker", worker::send)]
488 async fn stream(
489 &self,
490 completion_request: CompletionRequest,
491 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
492 let mut request = self.create_completion_request(completion_request)?;
493
494 request = merge(
495 request,
496 json!({"stream": true, "stream_options": {"include_usage": true}}),
497 );
498
499 let builder = self.client.post("/v1/chat/completions").json(&request);
500 send_compatible_streaming_request(builder).await
501 }
502}
503
/// Model name `deepseek-chat`, usable as the `model_name` of a completion model.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name `deepseek-reasoner`, usable as the `model_name` of a completion model.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
512
// Deserialization tests for the wire types defined in this module.
#[cfg(test)]
mod tests {

    use super::*;

    // A bare JSON array of choices deserializes into `Vec<Choice>`, and an
    // assistant message without `tool_calls` is accepted.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    // A minimal response object containing only `choices` deserializes into
    // `CompletionResponse`; `serde_path_to_error` reports the failing path on
    // error.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{"choices":[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]}"#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A full example response, including fields this module does not model
    // (`id`, `usage`, `system_fingerprint`, ...), still deserializes, and
    // non-ASCII content survives intact.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A choice carrying a tool call deserializes into the expected `Choice`;
    // in particular the string-encoded `arguments` become a structured JSON
    // value.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}