1use crate::json_utils::merge;
13use crate::providers::openai::send_compatible_streaming_request;
14use crate::streaming::{StreamingCompletionModel, StreamingResult};
15use crate::{
16 completion::{self, CompletionError, CompletionModel, CompletionRequest},
17 extractor::ExtractorBuilder,
18 json_utils, message, OneOrMany,
19};
20use reqwest::Client as HttpClient;
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
/// Base URL that [`Client::new`] passes to [`Client::from_url`].
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
29
30#[derive(Clone)]
31pub struct Client {
32 pub base_url: String,
33 http_client: HttpClient,
34}
35
36impl Client {
37 pub fn new(api_key: &str) -> Self {
39 Self::from_url(api_key, DEEPSEEK_API_BASE_URL)
40 }
41
42 pub fn from_env() -> Self {
44 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
45 Self::new(&api_key)
46 }
47
48 pub fn from_url(api_key: &str, base_url: &str) -> Self {
50 Self {
52 base_url: base_url.to_string(),
53 http_client: reqwest::Client::builder()
54 .default_headers({
55 let mut headers = reqwest::header::HeaderMap::new();
56 headers.insert(
57 "Authorization",
58 format!("Bearer {}", api_key)
59 .parse()
60 .expect("Bearer token should parse"),
61 );
62 headers
63 })
64 .build()
65 .expect("DeepSeek reqwest client should build"),
66 }
67 }
68
69 fn post(&self, path: &str) -> reqwest::RequestBuilder {
70 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
71 self.http_client.post(url)
72 }
73
74 pub fn completion_model(&self, model_name: &str) -> DeepSeekCompletionModel {
76 DeepSeekCompletionModel {
77 client: self.clone(),
78 model: model_name.to_string(),
79 }
80 }
81
82 pub fn agent(&self, model_name: &str) -> crate::agent::AgentBuilder<DeepSeekCompletionModel> {
84 crate::agent::AgentBuilder::new(self.completion_model(model_name))
85 }
86
87 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
89 &self,
90 model: &str,
91 ) -> ExtractorBuilder<T, DeepSeekCompletionModel> {
92 ExtractorBuilder::new(self.completion_model(model))
93 }
94}
95
96#[derive(Debug, Deserialize)]
97struct ApiErrorResponse {
98 message: String,
99}
100
101#[derive(Debug, Deserialize)]
102#[serde(untagged)]
103enum ApiResponse<T> {
104 Ok(T),
105 Err(ApiErrorResponse),
106}
107
108impl From<ApiErrorResponse> for CompletionError {
109 fn from(err: ApiErrorResponse) -> Self {
110 CompletionError::ProviderError(err.message)
111 }
112}
113
114#[derive(Clone, Debug, Serialize, Deserialize)]
116pub struct CompletionResponse {
117 pub choices: Vec<Choice>,
119 }
121
122#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
123pub struct Choice {
124 pub index: usize,
125 pub message: Message,
126 pub logprobs: Option<serde_json::Value>,
127 pub finish_reason: String,
128}
129
130#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
131#[serde(tag = "role", rename_all = "lowercase")]
132pub enum Message {
133 System {
134 content: String,
135 #[serde(skip_serializing_if = "Option::is_none")]
136 name: Option<String>,
137 },
138 User {
139 content: String,
140 #[serde(skip_serializing_if = "Option::is_none")]
141 name: Option<String>,
142 },
143 Assistant {
144 content: String,
145 #[serde(skip_serializing_if = "Option::is_none")]
146 name: Option<String>,
147 #[serde(
148 default,
149 deserialize_with = "json_utils::null_or_vec",
150 skip_serializing_if = "Vec::is_empty"
151 )]
152 tool_calls: Vec<ToolCall>,
153 },
154 #[serde(rename = "tool")]
155 ToolResult {
156 tool_call_id: String,
157 content: String,
158 },
159}
160
161impl Message {
162 pub fn system(content: &str) -> Self {
163 Message::System {
164 content: content.to_owned(),
165 name: None,
166 }
167 }
168}
169
170impl From<message::ToolResult> for Message {
171 fn from(tool_result: message::ToolResult) -> Self {
172 let content = match tool_result.content.first() {
173 message::ToolResultContent::Text(text) => text.text,
174 message::ToolResultContent::Image(_) => String::from("[Image]"),
175 };
176
177 Message::ToolResult {
178 tool_call_id: tool_result.id,
179 content,
180 }
181 }
182}
183
184impl From<message::ToolCall> for ToolCall {
185 fn from(tool_call: message::ToolCall) -> Self {
186 Self {
187 id: tool_call.id,
188 index: 0,
190 r#type: ToolType::Function,
191 function: Function {
192 name: tool_call.function.name,
193 arguments: tool_call.function.arguments,
194 },
195 }
196 }
197}
198
199impl TryFrom<message::Message> for Vec<Message> {
200 type Error = message::MessageError;
201
202 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
203 match message {
204 message::Message::User { content } => {
205 let mut messages = vec![];
207
208 let tool_results = content
209 .clone()
210 .into_iter()
211 .filter_map(|content| match content {
212 message::UserContent::ToolResult(tool_result) => {
213 Some(Message::from(tool_result))
214 }
215 _ => None,
216 })
217 .collect::<Vec<_>>();
218
219 messages.extend(tool_results);
220
221 let text_messages = content
223 .into_iter()
224 .filter_map(|content| match content {
225 message::UserContent::Text(text) => Some(Message::User {
226 content: text.text,
227 name: None,
228 }),
229 _ => None,
230 })
231 .collect::<Vec<_>>();
232 messages.extend(text_messages);
233
234 Ok(messages)
235 }
236 message::Message::Assistant { content } => {
237 let mut messages: Vec<Message> = vec![];
238
239 let tool_calls = content
241 .clone()
242 .into_iter()
243 .filter_map(|content| match content {
244 message::AssistantContent::ToolCall(tool_call) => {
245 Some(ToolCall::from(tool_call))
246 }
247 _ => None,
248 })
249 .collect::<Vec<_>>();
250
251 if !tool_calls.is_empty() {
253 messages.push(Message::Assistant {
254 content: "".to_string(),
255 name: None,
256 tool_calls,
257 });
258 }
259
260 let text_content = content
262 .into_iter()
263 .filter_map(|content| match content {
264 message::AssistantContent::Text(text) => Some(Message::Assistant {
265 content: text.text,
266 name: None,
267 tool_calls: vec![],
268 }),
269 _ => None,
270 })
271 .collect::<Vec<_>>();
272
273 messages.extend(text_content);
274
275 Ok(messages)
276 }
277 }
278 }
279}
280
281#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
282pub struct ToolCall {
283 pub id: String,
284 pub index: usize,
285 #[serde(default)]
286 pub r#type: ToolType,
287 pub function: Function,
288}
289
290#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
291pub struct Function {
292 pub name: String,
293 #[serde(with = "json_utils::stringified_json")]
294 pub arguments: serde_json::Value,
295}
296
297#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
298#[serde(rename_all = "lowercase")]
299pub enum ToolType {
300 #[default]
301 Function,
302}
303
304#[derive(Clone, Debug, Deserialize, Serialize)]
305pub struct ToolDefinition {
306 pub r#type: String,
307 pub function: completion::ToolDefinition,
308}
309
310impl From<crate::completion::ToolDefinition> for ToolDefinition {
311 fn from(tool: crate::completion::ToolDefinition) -> Self {
312 Self {
313 r#type: "function".into(),
314 function: tool,
315 }
316 }
317}
318
319impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
320 type Error = CompletionError;
321
322 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
323 let choice = response.choices.first().ok_or_else(|| {
324 CompletionError::ResponseError("Response contained no choices".to_owned())
325 })?;
326 let content = match &choice.message {
327 Message::Assistant {
328 content,
329 tool_calls,
330 ..
331 } => {
332 let mut content = if content.trim().is_empty() {
333 vec![]
334 } else {
335 vec![completion::AssistantContent::text(content)]
336 };
337
338 content.extend(
339 tool_calls
340 .iter()
341 .map(|call| {
342 completion::AssistantContent::tool_call(
343 &call.id,
344 &call.function.name,
345 call.function.arguments.clone(),
346 )
347 })
348 .collect::<Vec<_>>(),
349 );
350 Ok(content)
351 }
352 _ => Err(CompletionError::ResponseError(
353 "Response did not contain a valid message or tool call".into(),
354 )),
355 }?;
356
357 let choice = OneOrMany::many(content).map_err(|_| {
358 CompletionError::ResponseError(
359 "Response contained no message or tool call (empty)".to_owned(),
360 )
361 })?;
362
363 Ok(completion::CompletionResponse {
364 choice,
365 raw_response: response,
366 })
367 }
368}
369
370#[derive(Clone)]
372pub struct DeepSeekCompletionModel {
373 pub client: Client,
374 pub model: String,
375}
376
377impl DeepSeekCompletionModel {
378 fn create_completion_request(
379 &self,
380 completion_request: CompletionRequest,
381 ) -> Result<serde_json::Value, CompletionError> {
382 let mut partial_history = vec![];
384 if let Some(docs) = completion_request.normalized_documents() {
385 partial_history.push(docs);
386 }
387 partial_history.extend(completion_request.chat_history);
388
389 let mut full_history: Vec<Message> = completion_request
391 .preamble
392 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
393
394 full_history.extend(
396 partial_history
397 .into_iter()
398 .map(message::Message::try_into)
399 .collect::<Result<Vec<Vec<Message>>, _>>()?
400 .into_iter()
401 .flatten()
402 .collect::<Vec<_>>(),
403 );
404
405 let request = if completion_request.tools.is_empty() {
406 json!({
407 "model": self.model,
408 "messages": full_history,
409 "temperature": completion_request.temperature,
410 })
411 } else {
412 json!({
413 "model": self.model,
414 "messages": full_history,
415 "temperature": completion_request.temperature,
416 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
417 "tool_choice": "auto",
418 })
419 };
420
421 let request = if let Some(params) = completion_request.additional_params {
422 json_utils::merge(request, params)
423 } else {
424 request
425 };
426
427 Ok(request)
428 }
429}
430
431impl CompletionModel for DeepSeekCompletionModel {
432 type Response = CompletionResponse;
433
434 #[cfg_attr(feature = "worker", worker::send)]
435 async fn completion(
436 &self,
437 completion_request: CompletionRequest,
438 ) -> Result<
439 completion::CompletionResponse<CompletionResponse>,
440 crate::completion::CompletionError,
441 > {
442 let request = self.create_completion_request(completion_request)?;
443
444 let response = self
445 .client
446 .post("/chat/completions")
447 .json(&request)
448 .send()
449 .await?;
450
451 if response.status().is_success() {
452 let t = response.text().await?;
453 tracing::debug!(target: "rig", "DeepSeek completion: {}", t);
454
455 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
456 ApiResponse::Ok(response) => response.try_into(),
457 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
458 }
459 } else {
460 Err(CompletionError::ProviderError(response.text().await?))
461 }
462 }
463}
464
465impl StreamingCompletionModel for DeepSeekCompletionModel {
466 async fn stream(
467 &self,
468 completion_request: CompletionRequest,
469 ) -> Result<StreamingResult, CompletionError> {
470 let mut request = self.create_completion_request(completion_request)?;
471
472 request = merge(request, json!({"stream": true}));
473
474 let builder = self.client.post("/v1/chat/completions").json(&request);
475 send_compatible_streaming_request(builder).await
476 }
477}
478
/// Model name `deepseek-chat`.
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model name `deepseek-reasoner`.
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
487
#[cfg(test)]
mod tests {

    use super::*;

    /// Returns the text of an assistant message; panics for any other role.
    fn assistant_text(message: &Message) -> &str {
        match message {
            Message::Assistant { content, .. } => content.as_str(),
            _ => panic!("Expected assistant message"),
        }
    }

    /// Deserializes a `CompletionResponse`, panicking with the failing JSON
    /// path when the input does not match.
    fn parse_response(data: &str) -> CompletionResponse {
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        result.unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err))
    }

    // A bare list of choices deserializes without the surrounding response.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        assert_eq!(
            assistant_text(&choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    // A minimal response containing only `choices` deserializes.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{"choices":[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]}"#;

        let response = parse_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Hello, world!"
        );
    }

    // A full response with extra fields (id, usage, ...) deserializes; the
    // unmodelled fields are ignored.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;

        let response = parse_response(data);
        assert_eq!(
            assistant_text(&response.choices.first().unwrap().message),
            "Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
        );
    }

    // A choice carrying a tool call deserializes, with the stringified
    // arguments parsed into JSON.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_call = ToolCall {
            id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
            index: 0,
            r#type: ToolType::Function,
            function: Function {
                name: "subtract".to_string(),
                arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
            },
        };

        let expected_choice = Choice {
            index: 0,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![expected_call],
            },
            logprobs: None,
            finish_reason: "tool_calls".to_string(),
        };

        assert_eq!(choice, expected_choice);
    }
}