deepseek_sdk/chat/
client.rs1use crate::DeepSeekRequest;
3use crate::error::DeepSeekError;
4use crate::{api_post, api_request_stream};
5
6use super::{Chat, ChatStream, request::*};
7use futures_util::StreamExt;
8use reqwest::Method;
9use reqwest_eventsource::Event;
10use std::sync::mpsc as std_mpsc;
11use tokio::sync::mpsc;
12pub type ChatStreamItem = Result<ChatStream, DeepSeekError>;
14
15pub struct ChatStreamBlocking {
17 pub rx: std_mpsc::Receiver<ChatStreamItem>,
18}
19
20impl Iterator for ChatStreamBlocking {
21 type Item = ChatStreamItem;
22
23 fn next(&mut self) -> Option<Self::Item> {
24 self.rx.recv().ok()
25 }
26}
27
28impl DeepSeekRequest for ChatRequest {
29 type Response = Chat;
30 type StreamItem = ChatStreamItem;
31 type BlockingStream = ChatStreamBlocking;
32
33 async fn send(self) -> Result<Chat, DeepSeekError> {
34 let client = self.client.clone();
35 api_post("/chat/completions", &self, client).await
36 }
37
38 async fn stream(self) -> Result<mpsc::Receiver<ChatStreamItem>, DeepSeekError> {
39 let mut request = self;
40 request.stream = Some(true);
41
42 let client = request.client.clone();
43 let mut event_source = api_request_stream(
44 Method::POST,
45 "/chat/completions",
46 |builder| builder.json(&request),
47 client,
48 )
49 .await?;
50
51 let (tx, rx) = mpsc::channel(32);
52
53 tokio::spawn(async move {
54 while let Some(event) = event_source.next().await {
55 match event {
56 Ok(Event::Open) => {}
57 Ok(Event::Message(message)) => {
58 if message.data == "[DONE]" {
59 break;
60 }
61 match serde_json::from_str::<ChatStream>(&message.data) {
62 Ok(chunk) => {
63 if tx.send(Ok(chunk)).await.is_err() {
64 break;
65 }
66 }
67 Err(err) => {
68 let _ = tx
69 .send(Err(DeepSeekError::decode(err.to_string(), message.data)))
70 .await;
71 break;
72 }
73 }
74 }
75 Err(err) => {
76 let _ = tx
77 .send(Err(DeepSeekError::decode(err.to_string(), String::new())))
78 .await;
79 break;
80 }
81 }
82 }
83 });
84
85 Ok(rx)
86 }
87
88 fn stream_blocking(self) -> Result<ChatStreamBlocking, DeepSeekError> {
89 let (tx, rx) = std_mpsc::channel();
90
91 std::thread::spawn(move || {
92 let runtime = match tokio::runtime::Builder::new_current_thread()
93 .enable_all()
94 .build()
95 {
96 Ok(runtime) => runtime,
97 Err(err) => {
98 let _ = tx.send(Err(DeepSeekError::decode(err.to_string(), String::new())));
99 return;
100 }
101 };
102
103 runtime.block_on(async move {
104 match self.stream().await {
105 Ok(mut stream_rx) => {
106 while let Some(item) = stream_rx.recv().await {
107 if tx.send(item).is_err() {
108 break;
109 }
110 }
111 }
112 Err(err) => {
113 let _ = tx.send(Err(err));
114 }
115 }
116 });
117 });
118
119 Ok(ChatStreamBlocking { rx })
120 }
121}
122
#[cfg(test)]
mod tests {
    // Live tests: every request goes to the API at `DEFAULT_BASE_URL` and needs the
    // `DEEPSEEK_API` environment variable to hold an API key.
    use super::*;
    use crate::{DEFAULT_BASE_URL, DeepSeekClient};

    /// Client authenticated with the key stored in `DEEPSEEK_API`.
    ///
    /// # Panics
    /// Panics if `DEEPSEEK_API` is not set.
    fn get_client() -> DeepSeekClient {
        DeepSeekClient::new(
            std::env::var("DEEPSEEK_API").expect("DEEPSEEK_API is not set"),
            DEFAULT_BASE_URL.clone(),
        )
    }

    /// Builder preset shared by every test: `deepseek-v4-flash` with thinking disabled.
    fn get_builder() -> ChatRequestBuilder {
        ChatRequestBuilder::default()
            .client(get_client())
            .model("deepseek-v4-flash")
            .thinking(Thinking::disabled())
    }

    /// A plain user message without a `name`.
    fn user_message(content: &str) -> ChatMessage {
        ChatMessage::User {
            content: content.to_string(),
            name: None,
        }
    }

    /// Appends the `delta.content` of every choice in `chunk` to `out`.
    fn append_delta(out: &mut String, chunk: ChatStream) {
        for choice in chunk.choices {
            if let Some(delta_content) = choice.delta.content {
                out.push_str(&delta_content);
            }
        }
    }

    #[tokio::test]
    async fn chat() {
        let req = get_builder()
            .message(user_message("Hi"))
            .max_tokens(5_u32)
            .logprobs(true)
            .top_logprobs(2_u32)
            .build()
            .unwrap();
        let response = req.send().await.unwrap();
        println!("{:#?}", response);
    }

    #[tokio::test]
    async fn api_error() {
        let req = get_builder()
            .model("invalid-model-name")
            .message(user_message("Hi"))
            .build()
            .unwrap();
        match req.send().await {
            Err(DeepSeekError::Api {
                error,
                status,
                body,
            }) => {
                assert_eq!(status, Some(400));
                assert!(body.is_some());
                assert_eq!(error.error_type, "invalid_request_error");
                assert_eq!(error.code.as_deref(), Some("invalid_request_error"));
            }
            other => panic!("Expected DeepSeekError::Api, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn chat_tool_call() {
        let mut messages = vec![user_message("How's the weather in Hangzhou, Zhejiang?")];
        let req_tool = Tool::new(
            "get_weather",
            "Get weather of a location, the user should supply a location first.",
            Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                },
                "required": ["location"]
            })),
        );
        let req = get_builder()
            .tool(req_tool.clone())
            .messages(messages.clone())
            .build()
            .unwrap();
        let message = req.send().await.unwrap().choices[0].clone().message;
        // The model may answer directly instead of calling the tool; then there is no
        // round-trip to verify.
        let Some(tool_calls) = message.tool_calls.clone() else {
            return;
        };
        let tool_call = tool_calls[0].clone();
        messages.push(ChatMessage::Assistant {
            content: message.content,
            name: None,
            tool_calls: Some(tool_calls),
        });
        messages.push(ChatMessage::Tool {
            tool_call_id: tool_call.id,
            content: "24°C".to_string(),
        });

        let req2 = get_builder()
            .tool(req_tool)
            .messages(messages)
            .build()
            .unwrap();
        let response = req2.send().await.unwrap();
        println!("{:#?}", response);
        assert!(
            response.choices[0]
                .message
                .content
                .as_ref()
                .unwrap()
                .contains("24°C")
        );
    }

    #[tokio::test]
    async fn chat_stream_async() {
        let req = get_builder()
            .message(user_message("Hi"))
            .max_tokens(16_u32)
            .build()
            .unwrap();

        let mut rx = req.stream().await.unwrap();
        let mut content = String::new();
        while let Some(item) = rx.recv().await {
            // A decode or transport error must fail the test, not just be printed.
            append_delta(&mut content, item.unwrap());
        }

        println!("Model>\t {}", content);
        assert!(!content.is_empty());
    }

    #[test]
    fn chat_stream_blocking() {
        let req = get_builder()
            .message(user_message("Hi"))
            .max_tokens(16_u32)
            .build()
            .unwrap();

        let mut content = String::new();
        // `take(50)` caps the loop in case the stream misbehaves and never ends.
        for item in req.stream_blocking().unwrap().take(50) {
            append_delta(&mut content, item.unwrap());
        }

        println!("Model>\t {}", content);
        assert!(!content.is_empty());
    }
}
292}