deepseek_sdk/chat/
client.rs1use crate::DeepSeekRequest;
3use crate::error::DeepSeekError;
4use crate::{api_post, api_request_stream};
5
6use super::{Chat, ChatStream, request::*};
7use futures_util::StreamExt;
8use reqwest::Method;
9use reqwest_eventsource::Event;
10use std::sync::mpsc as std_mpsc;
11use tokio::sync::mpsc;
12pub type ChatStreamItem = Result<ChatStream, DeepSeekError>;
14
15pub struct ChatStreamBlocking {
17 pub rx: std_mpsc::Receiver<ChatStreamItem>,
18}
19
20impl Iterator for ChatStreamBlocking {
21 type Item = ChatStreamItem;
22
23 fn next(&mut self) -> Option<Self::Item> {
24 self.rx.recv().ok()
25 }
26}
27
28impl DeepSeekRequest for ChatRequest {
29 type Response = Chat;
30 type StreamItem = ChatStreamItem;
31 type BlockingStream = ChatStreamBlocking;
32
33 async fn send(self) -> Result<Chat, DeepSeekError> {
34 let client = self.client.clone();
35 api_post("/chat/completions", &self, client).await
36 }
37
38 async fn stream(self) -> Result<mpsc::Receiver<ChatStreamItem>, DeepSeekError> {
39 let mut request = self;
40 request.stream = Some(true);
41
42 let client = request.client.clone();
43 let mut event_source = api_request_stream(
44 Method::POST,
45 "/chat/completions",
46 |builder| builder.json(&request),
47 client,
48 )
49 .await?;
50
51 let (tx, rx) = mpsc::channel(32);
52
53 tokio::spawn(async move {
54 while let Some(event) = event_source.next().await {
55 match event {
56 Ok(Event::Open) => {}
57 Ok(Event::Message(message)) => {
58 if message.data == "[DONE]" {
59 break;
60 }
61 match serde_json::from_str::<ChatStream>(&message.data) {
62 Ok(chunk) => {
63 if tx.send(Ok(chunk)).await.is_err() {
64 break;
65 }
66 }
67 Err(err) => {
68 let _ = tx
69 .send(Err(DeepSeekError::decode(err.to_string(), message.data)))
70 .await;
71 break;
72 }
73 }
74 }
75 Err(err) => {
76 let _ = tx
77 .send(Err(DeepSeekError::decode(err.to_string(), String::new())))
78 .await;
79 break;
80 }
81 }
82 }
83 });
84
85 Ok(rx)
86 }
87
88 fn stream_blocking(self) -> Result<ChatStreamBlocking, DeepSeekError> {
89 let (tx, rx) = std_mpsc::channel();
90
91 std::thread::spawn(move || {
92 let runtime = match tokio::runtime::Builder::new_current_thread()
93 .enable_all()
94 .build()
95 {
96 Ok(runtime) => runtime,
97 Err(err) => {
98 let _ = tx.send(Err(DeepSeekError::decode(err.to_string(), String::new())));
99 return;
100 }
101 };
102
103 runtime.block_on(async move {
104 match self.stream().await {
105 Ok(mut stream_rx) => {
106 while let Some(item) = stream_rx.recv().await {
107 if tx.send(item).is_err() {
108 break;
109 }
110 }
111 }
112 Err(err) => {
113 let _ = tx.send(Err(err));
114 }
115 }
116 });
117 });
118
119 Ok(ChatStreamBlocking { rx })
120 }
121}
122
// Integration tests against the live DeepSeek API. They need network access
// and read the API key from the `DEEPSEEK_API` environment variable (they
// panic if it is unset).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DEFAULT_BASE_URL, DeepSeekClient};

    /// Builds a client from the `DEEPSEEK_API` key and the default base URL.
    fn get_client() -> DeepSeekClient {
        DeepSeekClient::new(
            std::env::var("DEEPSEEK_API").expect("DEEPSEEK_API is not set"),
            DEFAULT_BASE_URL.clone(),
        )
    }

    /// Shared request builder: live client, `deepseek-v4-flash`, thinking disabled.
    fn get_builder() -> ChatRequestBuilder {
        ChatRequestBuilder::default()
            .client(get_client())
            .model("deepseek-v4-flash")
            .thinking(Thinking::disabled())
    }

    #[tokio::test]
    async fn chat() {
        let req = get_builder()
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .max_tokens(5_u32)
            .logprobs(true)
            .top_logprobs(2_u32)
            .build()
            .unwrap();
        let response = req.send().await.unwrap();
        println!("{:#?}", response);
        assert!(!response.choices.is_empty());
    }

    #[tokio::test]
    async fn api_error() {
        let mut req = get_builder()
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .build()
            .unwrap();
        // A reasoning effort combined with disabled thinking is rejected by
        // the server, which exercises the API-error path.
        req.reasoning_effort = Some(ReasoningEffort::Max);

        let err = match req.send().await {
            Ok(response) => panic!("Expected an API error, got {response:#?}"),
            Err(err) => err,
        };
        match err {
            DeepSeekError::Api {
                error,
                status,
                body,
            } => {
                assert_eq!(status, Some(400));
                assert!(body.is_some());
                assert_eq!(
                    error.message,
                    "thinking options type cannot be disabled when reasoning_effort is set"
                );
                assert_eq!(error.error_type, "invalid_request_error");
                assert_eq!(error.param.as_deref(), None);
                assert_eq!(error.code.as_deref(), Some("invalid_request_error"));
            }
            other => panic!("Expected DeepSeekError::Api, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn chat_tool_call() {
        let mut messages = vec![ChatMessage::User {
            content: "How's the weather in Hangzhou, Zhejiang?".to_string(),
            name: None,
        }];
        let req_tool = Tool::new(
            "get_weather",
            "Get weather of a location, the user should supply a location first.",
            Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                },
                "required": ["location"]
            })),
        );
        let req = get_builder()
            .tool(req_tool.clone())
            .messages(messages.clone())
            .build()
            .unwrap();
        let message = req.send().await.unwrap().choices[0].clone().message;
        // The model may answer without calling the tool; nothing more to check then.
        let Some(tool_calls) = message.tool_calls.clone() else {
            return;
        };
        let tool_call = tool_calls[0].clone();
        messages.push(ChatMessage::Assistant {
            content: message.content,
            name: None,
            tool_calls: Some(tool_calls),
        });
        messages.push(ChatMessage::Tool {
            tool_call_id: tool_call.id,
            content: "24°C".to_string(),
        });

        let req2 = get_builder()
            .tool(req_tool)
            .messages(messages)
            .build()
            .unwrap();
        let response = req2.send().await.unwrap();
        println!("{:#?}", response);
        assert!(
            response.choices[0]
                .message
                .content
                .as_ref()
                .unwrap()
                .contains("24°C")
        );
    }

    #[tokio::test]
    async fn chat_stream_async() {
        let req = get_builder()
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .max_tokens(16_u32)
            .build()
            .unwrap();

        let mut rx = req.stream().await.unwrap();
        let mut content = String::new();
        while let Some(item) = rx.recv().await {
            // Fail the test on any stream error instead of only printing it.
            let chunk = match item {
                Ok(chunk) => chunk,
                Err(err) => panic!("Error>\t {:#?}", err),
            };
            println!("Model>\t {:#?}", chunk);
            for choice in chunk.choices {
                if let Some(delta_content) = choice.delta.content {
                    content.push_str(&delta_content);
                }
            }
        }

        assert!(!content.is_empty());
    }

    #[test]
    fn chat_stream_blocking() {
        let req = get_builder()
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .max_tokens(16_u32)
            .build()
            .unwrap();

        let stream = req.stream_blocking().unwrap();
        let mut content = String::new();

        for item in stream.take(50) {
            let chunk = item.unwrap();
            for choice in chunk.choices {
                if let Some(delta_content) = choice.delta.content {
                    content.push_str(&delta_content);
                }
            }
        }

        println!("Model>\t {}", content);
        assert!(!content.is_empty());
    }
}