1use crate::client::{CompletionClient, ProviderClient};
13use crate::json_utils::merge;
14use crate::providers::openai;
15use crate::providers::openai::send_compatible_streaming_request;
16use crate::streaming::StreamingCompletionResponse;
17use crate::{
18 completion::{self, CompletionError, CompletionModel, CompletionRequest},
19 impl_conversion_traits, json_utils, message, OneOrMany,
20};
21use reqwest::Client as HttpClient;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
/// Default DeepSeek API base URL, used by `Client::new`.
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
29
/// DeepSeek API client.
///
/// Holds the API base URL plus a `reqwest` client whose default headers carry
/// the `Authorization: Bearer …` credential (configured in `Client::from_url`).
#[derive(Clone, Debug)]
pub struct Client {
    /// Base URL that request paths are joined onto (defaults to
    /// `https://api.deepseek.com` when built with `Client::new`).
    pub base_url: String,
    http_client: HttpClient,
}
35
36impl Client {
37 pub fn new(api_key: &str) -> Self {
39 Self::from_url(api_key, DEEPSEEK_API_BASE_URL)
40 }
41
42 pub fn from_url(api_key: &str, base_url: &str) -> Self {
44 Self {
46 base_url: base_url.to_string(),
47 http_client: reqwest::Client::builder()
48 .default_headers({
49 let mut headers = reqwest::header::HeaderMap::new();
50 headers.insert(
51 "Authorization",
52 format!("Bearer {api_key}")
53 .parse()
54 .expect("Bearer token should parse"),
55 );
56 headers
57 })
58 .build()
59 .expect("DeepSeek reqwest client should build"),
60 }
61 }
62
63 fn post(&self, path: &str) -> reqwest::RequestBuilder {
64 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
65 self.http_client.post(url)
66 }
67}
68
69impl ProviderClient for Client {
70 fn from_env() -> Self {
72 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
73 Self::new(&api_key)
74 }
75}
76
77impl CompletionClient for Client {
78 type CompletionModel = DeepSeekCompletionModel;
79
80 fn completion_model(&self, model_name: &str) -> DeepSeekCompletionModel {
82 DeepSeekCompletionModel {
83 client: self.clone(),
84 model: model_name.to_string(),
85 }
86 }
87}
88
// NOTE(review): this file only implements completions. Presumably this macro
// generates the `As*` conversion impls for the capabilities `Client` does not
// provide (embeddings, transcription, image and audio generation). Confirm
// against the macro's definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
95
/// Error payload returned in a successful-status body; only its top-level
/// `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
100
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `#[serde(untagged)]`, so serde tries `Ok(T)` first and falls
/// back to `Err` only when the body does not parse as `T`. The variant order
/// is therefore significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
107
108impl From<ApiErrorResponse> for CompletionError {
109 fn from(err: ApiErrorResponse) -> Self {
110 CompletionError::ProviderError(err.message)
111 }
112}
113
/// Raw DeepSeek chat-completion response body.
///
/// Only `choices` is modeled. Other top-level fields (`id`, `model`, `usage`, …)
/// are ignored during deserialization, since unknown fields are not denied.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub choices: Vec<Choice>,
}
121
/// One candidate completion within a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Position of this choice in `choices`.
    pub index: usize,
    pub message: Message,
    /// Log-probability data, kept as untyped JSON (`null` in the fixtures below).
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped, e.g. `"stop"` or `"tool_calls"`.
    pub finish_reason: String,
}
129
130#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
131#[serde(tag = "role", rename_all = "lowercase")]
132pub enum Message {
133 System {
134 content: String,
135 #[serde(skip_serializing_if = "Option::is_none")]
136 name: Option<String>,
137 },
138 User {
139 content: String,
140 #[serde(skip_serializing_if = "Option::is_none")]
141 name: Option<String>,
142 },
143 Assistant {
144 content: String,
145 #[serde(skip_serializing_if = "Option::is_none")]
146 name: Option<String>,
147 #[serde(
148 default,
149 deserialize_with = "json_utils::null_or_vec",
150 skip_serializing_if = "Vec::is_empty"
151 )]
152 tool_calls: Vec<ToolCall>,
153 },
154 #[serde(rename = "tool")]
155 ToolResult {
156 tool_call_id: String,
157 content: String,
158 },
159}
160
161impl Message {
162 pub fn system(content: &str) -> Self {
163 Message::System {
164 content: content.to_owned(),
165 name: None,
166 }
167 }
168}
169
170impl From<message::ToolResult> for Message {
171 fn from(tool_result: message::ToolResult) -> Self {
172 let content = match tool_result.content.first() {
173 message::ToolResultContent::Text(text) => text.text,
174 message::ToolResultContent::Image(_) => String::from("[Image]"),
175 };
176
177 Message::ToolResult {
178 tool_call_id: tool_result.id,
179 content,
180 }
181 }
182}
183
184impl From<message::ToolCall> for ToolCall {
185 fn from(tool_call: message::ToolCall) -> Self {
186 Self {
187 id: tool_call.id,
188 index: 0,
190 r#type: ToolType::Function,
191 function: Function {
192 name: tool_call.function.name,
193 arguments: tool_call.function.arguments,
194 },
195 }
196 }
197}
198
199impl TryFrom<message::Message> for Vec<Message> {
200 type Error = message::MessageError;
201
202 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
203 match message {
204 message::Message::User { content } => {
205 let mut messages = vec![];
207
208 let tool_results = content
209 .clone()
210 .into_iter()
211 .filter_map(|content| match content {
212 message::UserContent::ToolResult(tool_result) => {
213 Some(Message::from(tool_result))
214 }
215 _ => None,
216 })
217 .collect::<Vec<_>>();
218
219 messages.extend(tool_results);
220
221 let text_messages = content
223 .into_iter()
224 .filter_map(|content| match content {
225 message::UserContent::Text(text) => Some(Message::User {
226 content: text.text,
227 name: None,
228 }),
229 _ => None,
230 })
231 .collect::<Vec<_>>();
232 messages.extend(text_messages);
233
234 Ok(messages)
235 }
236 message::Message::Assistant { content } => {
237 let mut messages: Vec<Message> = vec![];
238
239 let tool_calls = content
241 .clone()
242 .into_iter()
243 .filter_map(|content| match content {
244 message::AssistantContent::ToolCall(tool_call) => {
245 Some(ToolCall::from(tool_call))
246 }
247 _ => None,
248 })
249 .collect::<Vec<_>>();
250
251 if !tool_calls.is_empty() {
253 messages.push(Message::Assistant {
254 content: "".to_string(),
255 name: None,
256 tool_calls,
257 });
258 }
259
260 let text_content = content
262 .into_iter()
263 .filter_map(|content| match content {
264 message::AssistantContent::Text(text) => Some(Message::Assistant {
265 content: text.text,
266 name: None,
267 tool_calls: vec![],
268 }),
269 _ => None,
270 })
271 .collect::<Vec<_>>();
272
273 messages.extend(text_content);
274
275 Ok(messages)
276 }
277 }
278 }
279}
280
/// A tool invocation requested by (or replayed to) the assistant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier that the matching `tool` message echoes as `tool_call_id`.
    pub id: String,
    pub index: usize,
    /// Defaults to [`ToolType::Function`] when absent from the JSON.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
289
/// Name and arguments of a function-style tool call.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// Held as parsed JSON in memory. On the wire it is a JSON-encoded string
    /// (e.g. `"{\"x\":2,\"y\":5}"`), handled by `json_utils::stringified_json`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
296
/// Kind of tool call; only `"function"` is modeled, and it is the default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
303
/// Tool advertised to the model in the request's `tools` array: a `type` tag
/// wrapping rig's generic tool definition.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Set to `"function"` by the `From<completion::ToolDefinition>` conversion.
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
309
310impl From<crate::completion::ToolDefinition> for ToolDefinition {
311 fn from(tool: crate::completion::ToolDefinition) -> Self {
312 Self {
313 r#type: "function".into(),
314 function: tool,
315 }
316 }
317}
318
319impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
320 type Error = CompletionError;
321
322 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
323 let choice = response.choices.first().ok_or_else(|| {
324 CompletionError::ResponseError("Response contained no choices".to_owned())
325 })?;
326 let content = match &choice.message {
327 Message::Assistant {
328 content,
329 tool_calls,
330 ..
331 } => {
332 let mut content = if content.trim().is_empty() {
333 vec![]
334 } else {
335 vec![completion::AssistantContent::text(content)]
336 };
337
338 content.extend(
339 tool_calls
340 .iter()
341 .map(|call| {
342 completion::AssistantContent::tool_call(
343 &call.id,
344 &call.function.name,
345 call.function.arguments.clone(),
346 )
347 })
348 .collect::<Vec<_>>(),
349 );
350 Ok(content)
351 }
352 _ => Err(CompletionError::ResponseError(
353 "Response did not contain a valid message or tool call".into(),
354 )),
355 }?;
356
357 let choice = OneOrMany::many(content).map_err(|_| {
358 CompletionError::ResponseError(
359 "Response contained no message or tool call (empty)".to_owned(),
360 )
361 })?;
362
363 Ok(completion::CompletionResponse {
364 choice,
365 raw_response: response,
366 })
367 }
368}
369
/// Completion model handle: a DeepSeek [`Client`] paired with a model name
/// (e.g. [`DEEPSEEK_CHAT`]) that is sent as the request's `model` field.
#[derive(Clone)]
pub struct DeepSeekCompletionModel {
    pub client: Client,
    pub model: String,
}
376
377impl DeepSeekCompletionModel {
378 fn create_completion_request(
379 &self,
380 completion_request: CompletionRequest,
381 ) -> Result<serde_json::Value, CompletionError> {
382 let mut partial_history = vec![];
384 if let Some(docs) = completion_request.normalized_documents() {
385 partial_history.push(docs);
386 }
387 partial_history.extend(completion_request.chat_history);
388
389 let mut full_history: Vec<Message> = completion_request
391 .preamble
392 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
393
394 full_history.extend(
396 partial_history
397 .into_iter()
398 .map(message::Message::try_into)
399 .collect::<Result<Vec<Vec<Message>>, _>>()?
400 .into_iter()
401 .flatten()
402 .collect::<Vec<_>>(),
403 );
404
405 let request = if completion_request.tools.is_empty() {
406 json!({
407 "model": self.model,
408 "messages": full_history,
409 "temperature": completion_request.temperature,
410 })
411 } else {
412 json!({
413 "model": self.model,
414 "messages": full_history,
415 "temperature": completion_request.temperature,
416 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
417 "tool_choice": "auto",
418 })
419 };
420
421 let request = if let Some(params) = completion_request.additional_params {
422 json_utils::merge(request, params)
423 } else {
424 request
425 };
426
427 Ok(request)
428 }
429}
430
431impl CompletionModel for DeepSeekCompletionModel {
432 type Response = CompletionResponse;
433 type StreamingResponse = openai::StreamingCompletionResponse;
434
435 #[cfg_attr(feature = "worker", worker::send)]
436 async fn completion(
437 &self,
438 completion_request: CompletionRequest,
439 ) -> Result<
440 completion::CompletionResponse<CompletionResponse>,
441 crate::completion::CompletionError,
442 > {
443 let request = self.create_completion_request(completion_request)?;
444
445 let response = self
446 .client
447 .post("/chat/completions")
448 .json(&request)
449 .send()
450 .await?;
451
452 if response.status().is_success() {
453 let t = response.text().await?;
454 tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
455
456 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
457 ApiResponse::Ok(response) => response.try_into(),
458 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
459 }
460 } else {
461 Err(CompletionError::ProviderError(response.text().await?))
462 }
463 }
464
465 #[cfg_attr(feature = "worker", worker::send)]
466 async fn stream(
467 &self,
468 completion_request: CompletionRequest,
469 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
470 let mut request = self.create_completion_request(completion_request)?;
471
472 request = merge(
473 request,
474 json!({"stream": true, "stream_options": {"include_usage": true}}),
475 );
476
477 let builder = self.client.post("/v1/chat/completions").json(&request);
478 send_compatible_streaming_request(builder).await
479 }
480}
481
/// Model identifier `deepseek-chat`, for use with `Client::completion_model`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier `deepseek-reasoner`, for use with `Client::completion_model`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
490
#[cfg(test)]
mod tests {

    use super::*;

    // A bare JSON array of choices deserializes into `Vec<Choice>`.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    // A minimal `{"choices": [...]}` body deserializes into `CompletionResponse`.
    // `serde_path_to_error` reports the failing JSON path on error.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{"choices":[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]}"#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A full response body with extra top-level fields (`id`, `usage`, …) still
    // deserializes, and non-ASCII content survives intact.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A tool-call choice: `arguments` arrives as a JSON-encoded string and is
    // compared against the parsed JSON value.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
        {
            "finish_reason": "tool_calls",
            "index": 0,
            "logprobs": null,
            "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                    {
                        "function": {
                            "arguments": "{\"x\":2,\"y\":5}",
                            "name": "subtract"
                        },
                        "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                        "index": 0,
                        "type": "function"
                    }
                ]
            }
        }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}