1use crate::{
12 completion::{self, CompletionError, CompletionModel, CompletionRequest},
13 extractor::ExtractorBuilder,
14 json_utils, message, OneOrMany,
15};
16use reqwest::Client as HttpClient;
17use schemars::JsonSchema;
18use serde::{Deserialize, Serialize};
19use serde_json::json;
20
/// Base URL of the public DeepSeek API; used by [`Client::new`].
const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com";
25
/// Client for the DeepSeek API.
///
/// Build one with [`Client::new`], [`Client::from_env`] or [`Client::from_url`], then obtain
/// models, agents or extractors from it.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to (e.g. `https://api.deepseek.com`).
    pub base_url: String,
    // Configured in `from_url` with an `Authorization: Bearer <key>` default header.
    http_client: HttpClient,
}
31
32impl Client {
33 pub fn new(api_key: &str) -> Self {
35 Self::from_url(api_key, DEEPSEEK_API_BASE_URL)
36 }
37
38 pub fn from_env() -> Self {
40 let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set");
41 Self::new(&api_key)
42 }
43
44 pub fn from_url(api_key: &str, base_url: &str) -> Self {
46 Self {
48 base_url: base_url.to_string(),
49 http_client: reqwest::Client::builder()
50 .default_headers({
51 let mut headers = reqwest::header::HeaderMap::new();
52 headers.insert(
53 "Authorization",
54 format!("Bearer {}", api_key)
55 .parse()
56 .expect("Bearer token should parse"),
57 );
58 headers
59 })
60 .build()
61 .expect("DeepSeek reqwest client should build"),
62 }
63 }
64
65 fn post(&self, path: &str) -> reqwest::RequestBuilder {
66 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
67 self.http_client.post(url)
68 }
69
70 pub fn completion_model(&self, model_name: &str) -> DeepSeekCompletionModel {
72 DeepSeekCompletionModel {
73 client: self.clone(),
74 model: model_name.to_string(),
75 }
76 }
77
78 pub fn agent(&self, model_name: &str) -> crate::agent::AgentBuilder<DeepSeekCompletionModel> {
80 crate::agent::AgentBuilder::new(self.completion_model(model_name))
81 }
82
83 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
85 &self,
86 model: &str,
87 ) -> ExtractorBuilder<T, DeepSeekCompletionModel> {
88 ExtractorBuilder::new(self.completion_model(model))
89 }
90}
91
/// Error body returned by the API. Only a top-level `message` field is read.
// NOTE(review): assumes the API reports errors with a top-level `message`; confirm against
// DeepSeek's actual error payload shape.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
96
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// Being `untagged`, serde tries the variants in declaration order: `Ok(T)` first, and
/// `Err` only if the body does not deserialize as `T`. Keep `Ok` listed first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
103
104impl From<ApiErrorResponse> for CompletionError {
105 fn from(err: ApiErrorResponse) -> Self {
106 CompletionError::ProviderError(err.message)
107 }
108}
109
/// Raw response body of DeepSeek's `/chat/completions` endpoint.
///
/// Only `choices` is modelled. Other top-level fields (such as `id`, `model` or `usage`)
/// are ignored when deserializing.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub choices: Vec<Choice>,
}
117
/// One generated alternative inside a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
    /// Index of this choice within `choices`.
    pub index: usize,
    /// The generated message (an assistant message in practice).
    pub message: Message,
    /// Log-probability information, kept as untyped JSON; `null` when not requested.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped, e.g. `"stop"` or `"tool_calls"`.
    pub finish_reason: String,
}
125
126#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
127#[serde(tag = "role", rename_all = "lowercase")]
128pub enum Message {
129 System {
130 content: String,
131 #[serde(skip_serializing_if = "Option::is_none")]
132 name: Option<String>,
133 },
134 User {
135 content: String,
136 #[serde(skip_serializing_if = "Option::is_none")]
137 name: Option<String>,
138 },
139 Assistant {
140 content: String,
141 #[serde(skip_serializing_if = "Option::is_none")]
142 name: Option<String>,
143 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
144 tool_calls: Vec<ToolCall>,
145 },
146 #[serde(rename = "Tool")]
147 ToolResult {
148 tool_call_id: String,
149 content: String,
150 },
151}
152
153impl Message {
154 pub fn system(content: &str) -> Self {
155 Message::System {
156 content: content.to_owned(),
157 name: None,
158 }
159 }
160}
161
162impl From<message::ToolResult> for Message {
163 fn from(tool_result: message::ToolResult) -> Self {
164 let content = match tool_result.content.first() {
165 message::ToolResultContent::Text(text) => text.text,
166 message::ToolResultContent::Image(_) => String::from("[Image]"),
167 };
168
169 Message::ToolResult {
170 tool_call_id: tool_result.id,
171 content,
172 }
173 }
174}
175
176impl From<message::ToolCall> for ToolCall {
177 fn from(tool_call: message::ToolCall) -> Self {
178 Self {
179 id: tool_call.id,
180 index: 0,
182 r#type: ToolType::Function,
183 function: Function {
184 name: tool_call.function.name,
185 arguments: tool_call.function.arguments,
186 },
187 }
188 }
189}
190
191impl TryFrom<message::Message> for Vec<Message> {
192 type Error = message::MessageError;
193
194 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
195 match message {
196 message::Message::User { content } => {
197 let mut messages = vec![];
199
200 let tool_results = content
201 .clone()
202 .into_iter()
203 .filter_map(|content| match content {
204 message::UserContent::ToolResult(tool_result) => {
205 Some(Message::from(tool_result))
206 }
207 _ => None,
208 })
209 .collect::<Vec<_>>();
210
211 messages.extend(tool_results);
212
213 let text_messages = content
215 .into_iter()
216 .filter_map(|content| match content {
217 message::UserContent::Text(text) => Some(Message::User {
218 content: text.text,
219 name: None,
220 }),
221 _ => None,
222 })
223 .collect::<Vec<_>>();
224 messages.extend(text_messages);
225
226 Ok(messages)
227 }
228 message::Message::Assistant { content } => {
229 let mut messages: Vec<Message> = vec![];
230
231 let tool_calls = content
233 .clone()
234 .into_iter()
235 .filter_map(|content| match content {
236 message::AssistantContent::ToolCall(tool_call) => {
237 Some(ToolCall::from(tool_call))
238 }
239 _ => None,
240 })
241 .collect::<Vec<_>>();
242
243 if !tool_calls.is_empty() {
245 messages.push(Message::Assistant {
246 content: "".to_string(),
247 name: None,
248 tool_calls,
249 });
250 }
251
252 let text_content = content
254 .into_iter()
255 .filter_map(|content| match content {
256 message::AssistantContent::Text(text) => Some(Message::Assistant {
257 content: text.text,
258 name: None,
259 tool_calls: vec![],
260 }),
261 _ => None,
262 })
263 .collect::<Vec<_>>();
264
265 messages.extend(text_content);
266
267 Ok(messages)
268 }
269 }
270 }
271}
272
/// A function call requested by the model, or echoed back to it in the chat history.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Call id; the corresponding tool result refers to it via `tool_call_id`.
    pub id: String,
    // Presumably the call's position within the message's `tool_calls` list — TODO confirm
    // against the DeepSeek API reference.
    pub index: usize,
    // Falls back to `ToolType::Function` (the `#[default]` variant) when the field is absent.
    #[serde(default)]
    pub r#type: ToolType,
    pub function: Function,
}
281
/// Name and arguments of a called function.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    // On the wire `arguments` is a JSON document encoded as a string (e.g. "{\"x\":2}");
    // `json_utils::stringified_json` converts it to and from a parsed `serde_json::Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
288
/// Kind of a tool call. Only `function` exists; it serializes as the lowercase string
/// `"function"` and is also the `Default`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
295
/// A tool offered to the model in the request's `tools` array, serialized as
/// `{"type": <type>, "function": <rig tool definition>}`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind; set to `"function"` by the `From<completion::ToolDefinition>` conversion.
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
301
302impl From<crate::completion::ToolDefinition> for ToolDefinition {
303 fn from(tool: crate::completion::ToolDefinition) -> Self {
304 Self {
305 r#type: "function".into(),
306 function: tool,
307 }
308 }
309}
310
311impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
312 type Error = CompletionError;
313
314 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
315 let choice = response.choices.first().ok_or_else(|| {
316 CompletionError::ResponseError("Response contained no choices".to_owned())
317 })?;
318 let content = match &choice.message {
319 Message::Assistant {
320 content,
321 tool_calls,
322 ..
323 } => {
324 let mut content = if content.trim().is_empty() {
325 vec![]
326 } else {
327 vec![completion::AssistantContent::text(content)]
328 };
329
330 content.extend(
331 tool_calls
332 .iter()
333 .map(|call| {
334 completion::AssistantContent::tool_call(
335 &call.function.name,
336 &call.function.name,
337 call.function.arguments.clone(),
338 )
339 })
340 .collect::<Vec<_>>(),
341 );
342 Ok(content)
343 }
344 _ => Err(CompletionError::ResponseError(
345 "Response did not contain a valid message or tool call".into(),
346 )),
347 }?;
348
349 let choice = OneOrMany::many(content).map_err(|_| {
350 CompletionError::ResponseError(
351 "Response contained no message or tool call (empty)".to_owned(),
352 )
353 })?;
354
355 Ok(completion::CompletionResponse {
356 choice,
357 raw_response: response,
358 })
359 }
360}
361
/// A DeepSeek chat model bound to a [`Client`]; implements [`CompletionModel`].
#[derive(Clone)]
pub struct DeepSeekCompletionModel {
    /// Client used to send requests.
    pub client: Client,
    /// Model identifier sent as `model`, e.g. [`DEEPSEEK_CHAT`] or [`DEEPSEEK_REASONER`].
    pub model: String,
}
368
369impl CompletionModel for DeepSeekCompletionModel {
370 type Response = CompletionResponse;
371
372 #[cfg_attr(feature = "worker", worker::send)]
373 async fn completion(
374 &self,
375 completion_request: CompletionRequest,
376 ) -> Result<
377 completion::CompletionResponse<CompletionResponse>,
378 crate::completion::CompletionError,
379 > {
380 let mut full_history: Vec<Message> = match &completion_request.preamble {
382 Some(preamble) => vec![Message::system(preamble)],
383 None => vec![],
384 };
385
386 let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
388
389 let chat_history: Vec<Message> = completion_request
391 .chat_history
392 .into_iter()
393 .map(|message| message.try_into())
394 .collect::<Result<Vec<Vec<Message>>, _>>()?
395 .into_iter()
396 .flatten()
397 .collect();
398
399 full_history.extend(chat_history);
401 full_history.extend(prompt);
402
403 let request = if completion_request.tools.is_empty() {
404 json!({
405 "model": self.model,
406 "messages": full_history,
407 "temperature": completion_request.temperature,
408 })
409 } else {
410 json!({
411 "model": self.model,
412 "messages": full_history,
413 "temperature": completion_request.temperature,
414 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
415 "tool_choice": "auto",
416 })
417 };
418
419 let response = self
420 .client
421 .post("/chat/completions")
422 .json(
423 &if let Some(params) = completion_request.additional_params {
424 json_utils::merge(request, params)
425 } else {
426 request
427 },
428 )
429 .send()
430 .await?;
431
432 if response.status().is_success() {
433 let t = response.text().await?;
434 tracing::debug!(target: "rig", "OpenAI completion error: {}", t);
435
436 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
437 ApiResponse::Ok(response) => response.try_into(),
438 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
439 }
440 } else {
441 Err(CompletionError::ProviderError(response.text().await?))
442 }
443 }
444}
445
/// Model identifier `deepseek-chat`, for use with [`Client::completion_model`] / [`Client::agent`].
pub const DEEPSEEK_CHAT: &str = "deepseek-chat";
/// Model identifier `deepseek-reasoner`, for use with [`Client::completion_model`] / [`Client::agent`].
pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
454
// Deserialization tests for the DeepSeek wire types.
#[cfg(test)]
mod tests {

    use super::*;

    // A bare `choices` array deserializes. An assistant message without a `tool_calls`
    // field still yields its text content.
    #[test]
    fn test_deserialize_vec_choice() {
        let data = r#"[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]"#;

        let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
        assert_eq!(choices.len(), 1);
        match &choices.first().unwrap().message {
            Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
            _ => panic!("Expected assistant message"),
        }
    }

    // Minimal response object. `serde_path_to_error` makes a failure report the JSON path
    // of the offending field.
    #[test]
    fn test_deserialize_deepseek_response() {
        let data = r#"{"choices":[{
            "finish_reason": "stop",
            "index": 0,
            "logprobs": null,
            "message":{"role":"assistant","content":"Hello, world!"}
        }]}"#;

        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A complete API response. It includes fields `CompletionResponse` does not model
    // (`id`, `usage`, `system_fingerprint`, ...), which must be ignored. It also contains
    // non-ASCII text and an escaped newline.
    #[test]
    fn test_deserialize_example_response() {
        let data = r#"
        {
            "id": "e45f6c68-9d9e-43de-beb4-4f402b850feb",
            "object": "chat.completion",
            "created": 0,
            "model": "deepseek-chat",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                    },
                    "logprobs": null,
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 13,
                "completion_tokens": 32,
                "total_tokens": 45,
                "prompt_tokens_details": {
                    "cached_tokens": 0
                },
                "prompt_cache_hit_tokens": 0,
                "prompt_cache_miss_tokens": 13
            },
            "system_fingerprint": "fp_4b6881f2c5"
        }
        "#;
        let jd = &mut serde_json::Deserializer::from_str(data);
        let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);

        match result {
            Ok(response) => match &response.choices.first().unwrap().message {
                Message::Assistant { content, .. } => assert_eq!(
                    content,
                    "Why don’t skeletons fight each other?  \nBecause they don’t have the guts! 😄"
                ),
                _ => panic!("Expected assistant message"),
            },
            Err(err) => {
                panic!("Deserialization error at {}: {}", err.path(), err);
            }
        }
    }

    // A tool-calling choice. Its stringified `arguments` must come out as a parsed JSON
    // value. Despite the name, only deserialization is exercised here.
    #[test]
    fn test_serialize_deserialize_tool_call_message() {
        let tool_call_choice_json = r#"
            {
              "finish_reason": "tool_calls",
              "index": 0,
              "logprobs": null,
              "message": {
                "content": "",
                "role": "assistant",
                "tool_calls": [
                  {
                    "function": {
                      "arguments": "{\"x\":2,\"y\":5}",
                      "name": "subtract"
                    },
                    "id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
                    "index": 0,
                    "type": "function"
                  }
                ]
              }
            }
        "#;

        let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

        let expected_choice: Choice = Choice {
            finish_reason: "tool_calls".to_string(),
            index: 0,
            logprobs: None,
            message: Message::Assistant {
                content: "".to_string(),
                name: None,
                tool_calls: vec![ToolCall {
                    id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
                    },
                    index: 0,
                    r#type: ToolType::Function,
                }],
            },
        };

        assert_eq!(choice, expected_choice);
    }
}