Skip to main content

deepseek_sdk/chat/
client.rs

1//! Chat client implementation for `/chat/completions`.
2use crate::DeepSeekRequest;
3use crate::error::DeepSeekError;
4use crate::{api_post, api_request_stream};
5
6use super::{Chat, ChatStream, request::*};
7use futures_util::StreamExt;
8use reqwest::Method;
9use reqwest_eventsource::Event;
10use std::sync::mpsc as std_mpsc;
11use tokio::sync::mpsc;
12/// Stream item produced by chat streaming.
13pub type ChatStreamItem = Result<ChatStream, DeepSeekError>;
14
15/// Blocking iterator over streaming chat chunks.
16pub struct ChatStreamBlocking {
17    pub rx: std_mpsc::Receiver<ChatStreamItem>,
18}
19
20impl Iterator for ChatStreamBlocking {
21    type Item = ChatStreamItem;
22
23    fn next(&mut self) -> Option<Self::Item> {
24        self.rx.recv().ok()
25    }
26}
27
28impl DeepSeekRequest for ChatRequest {
29    type Response = Chat;
30    type StreamItem = ChatStreamItem;
31    type BlockingStream = ChatStreamBlocking;
32
33    async fn send(self) -> Result<Chat, DeepSeekError> {
34        let client = self.client.clone();
35        api_post("/chat/completions", &self, client).await
36    }
37
38    async fn stream(self) -> Result<mpsc::Receiver<ChatStreamItem>, DeepSeekError> {
39        let mut request = self;
40        request.stream = Some(true);
41
42        let client = request.client.clone();
43        let mut event_source = api_request_stream(
44            Method::POST,
45            "/chat/completions",
46            |builder| builder.json(&request),
47            client,
48        )
49        .await?;
50
51        let (tx, rx) = mpsc::channel(32);
52
53        tokio::spawn(async move {
54            while let Some(event) = event_source.next().await {
55                match event {
56                    Ok(Event::Open) => {}
57                    Ok(Event::Message(message)) => {
58                        if message.data == "[DONE]" {
59                            break;
60                        }
61                        match serde_json::from_str::<ChatStream>(&message.data) {
62                            Ok(chunk) => {
63                                if tx.send(Ok(chunk)).await.is_err() {
64                                    break;
65                                }
66                            }
67                            Err(err) => {
68                                let _ = tx
69                                    .send(Err(DeepSeekError::decode(err.to_string(), message.data)))
70                                    .await;
71                                break;
72                            }
73                        }
74                    }
75                    Err(err) => {
76                        let _ = tx
77                            .send(Err(DeepSeekError::decode(err.to_string(), String::new())))
78                            .await;
79                        break;
80                    }
81                }
82            }
83        });
84
85        Ok(rx)
86    }
87
88    fn stream_blocking(self) -> Result<ChatStreamBlocking, DeepSeekError> {
89        let (tx, rx) = std_mpsc::channel();
90
91        std::thread::spawn(move || {
92            let runtime = match tokio::runtime::Builder::new_current_thread()
93                .enable_all()
94                .build()
95            {
96                Ok(runtime) => runtime,
97                Err(err) => {
98                    let _ = tx.send(Err(DeepSeekError::decode(err.to_string(), String::new())));
99                    return;
100                }
101            };
102
103            runtime.block_on(async move {
104                match self.stream().await {
105                    Ok(mut stream_rx) => {
106                        while let Some(item) = stream_rx.recv().await {
107                            if tx.send(item).is_err() {
108                                break;
109                            }
110                        }
111                    }
112                    Err(err) => {
113                        let _ = tx.send(Err(err));
114                    }
115                }
116            });
117        });
118
119        Ok(ChatStreamBlocking { rx })
120    }
121}
122
#[cfg(test)]
mod tests {
    //! Live integration tests against the DeepSeek API.
    //!
    //! Every test performs real network requests and needs a key in the
    //! `DEEPSEEK_API` environment variable; `get_client` panics when it is unset.

    use super::*;
    use crate::{DEFAULT_BASE_URL, DeepSeekClient};

    /// Builds a client for the default base URL using the key from `DEEPSEEK_API`.
    fn get_client() -> DeepSeekClient {
        DeepSeekClient::new(
            std::env::var("DEEPSEEK_API").expect("DEEPSEEK_API is not set"),
            DEFAULT_BASE_URL.clone(),
        )
    }

    /// Shared request builder: `deepseek-v4-flash` with thinking disabled.
    fn get_builder() -> ChatRequestBuilder {
        ChatRequestBuilder::default()
            .client(get_client())
            .model("deepseek-v4-flash")
            .thinking(Thinking::disabled())
    }

    /// Appends every `delta.content` fragment of one streamed chunk to `out`.
    fn append_delta(out: &mut String, chunk: ChatStream) {
        for choice in chunk.choices {
            if let Some(delta_content) = choice.delta.content {
                out.push_str(&delta_content);
            }
        }
    }

    #[tokio::test]
    async fn chat() {
        let req = get_builder()
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .max_tokens(5_u32)
            .logprobs(true)
            .top_logprobs(2_u32)
            .build()
            .unwrap();
        let response = req.send().await.unwrap();
        println!("{:#?}", response);
    }

    #[tokio::test]
    async fn api_error() {
        let mut req = get_builder()
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .build()
            .unwrap();
        // `get_builder` disables thinking, and the server rejects a reasoning
        // effort in that mode, so this request must come back as an API error.
        req.reasoning_effort = Some(ReasoningEffort::Max);
        match req.send().await {
            Err(DeepSeekError::Api {
                error,
                status,
                body,
            }) => {
                assert_eq!(status, Some(400));
                assert!(body.is_some());
                assert_eq!(
                    error.message,
                    "thinking options type cannot be disabled when reasoning_effort is set"
                );
                assert_eq!(error.error_type, "invalid_request_error");
                assert_eq!(error.param.as_deref(), None);
                assert_eq!(error.code.as_deref(), Some("invalid_request_error"));
            }
            other => panic!("Expected DeepSeekError::Api, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn chat_tool_call() {
        let mut messages = vec![ChatMessage::User {
            content: "How's the weather in Hangzhou, Zhejiang?".to_string(),
            name: None,
        }];
        let req_tool = Tool::new(
            "get_weather",
            "Get weather of a location, the user should supply a location first.",
            Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                },
                "required": ["location"]
            })),
        );
        let req = get_builder()
            .tool(req_tool.clone())
            .messages(messages.clone())
            .build()
            .unwrap();
        let message = req.send().await.unwrap().choices[0].clone().message;
        // The model may answer without calling the tool (or with an empty call
        // list); there is no round trip to exercise in either case.
        let Some(tool_calls) = message.tool_calls.clone() else {
            return;
        };
        let Some(tool_call) = tool_calls.first().cloned() else {
            return;
        };
        messages.push(ChatMessage::Assistant {
            content: message.content,
            name: None,
            tool_calls: Some(tool_calls),
        });
        messages.push(ChatMessage::Tool {
            tool_call_id: tool_call.id,
            content: "24°C".to_string(),
        });

        let req2 = get_builder()
            .tool(req_tool)
            .messages(messages)
            .build()
            .unwrap();
        let response = req2.send().await.unwrap();
        println!("{:#?}", response);
        assert!(
            response.choices[0]
                .message
                .content
                .as_ref()
                .unwrap()
                .contains("24°C")
        );
    }

    #[tokio::test]
    async fn chat_stream_async() {
        let req = get_builder()
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .max_tokens(16_u32)
            .build()
            .unwrap();

        let mut rx = req.stream().await.unwrap();
        let mut content = String::new();
        while let Some(item) = rx.recv().await {
            // A stream error must fail the test instead of being logged and ignored.
            let chunk = item.expect("stream yielded an error");
            println!("Model>\t {:#?}", chunk);
            append_delta(&mut content, chunk);
        }

        println!("Model>\t {}", content);
        assert!(!content.is_empty());
    }

    #[test]
    fn chat_stream_blocking() {
        let req = get_builder()
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .max_tokens(16_u32)
            .build()
            .unwrap();

        let stream = req.stream_blocking().unwrap();
        let mut content = String::new();

        // `take` bounds the loop in case the stream never terminates.
        for item in stream.take(50) {
            append_delta(&mut content, item.unwrap());
        }

        println!("Model>\t {}", content);
        assert!(!content.is_empty());
    }
}