Skip to main content

deepseek_sdk/chat/
client.rs

1//! Chat client implementation for `/chat/completions`.
2use crate::DeepSeekRequest;
3use crate::error::DeepSeekError;
4use crate::{api_post, api_request_stream};
5
6use super::{Chat, ChatStream, request::*};
7use futures_util::StreamExt;
8use reqwest::Method;
9use reqwest_eventsource::Event;
10use std::sync::mpsc as std_mpsc;
11use tokio::sync::mpsc;
12/// Stream item produced by chat streaming.
13pub type ChatStreamItem = Result<ChatStream, DeepSeekError>;
14
15/// Blocking iterator over streaming chat chunks.
16pub struct ChatStreamBlocking {
17    pub rx: std_mpsc::Receiver<ChatStreamItem>,
18}
19
20impl Iterator for ChatStreamBlocking {
21    type Item = ChatStreamItem;
22
23    fn next(&mut self) -> Option<Self::Item> {
24        self.rx.recv().ok()
25    }
26}
27
28impl DeepSeekRequest for ChatRequest {
29    type Response = Chat;
30    type StreamItem = ChatStreamItem;
31    type BlockingStream = ChatStreamBlocking;
32
33    async fn send(self) -> Result<Chat, DeepSeekError> {
34        let client = self.client.clone();
35        api_post("/chat/completions", &self, client).await
36    }
37
38    async fn stream(self) -> Result<mpsc::Receiver<ChatStreamItem>, DeepSeekError> {
39        let mut request = self;
40        request.stream = Some(true);
41
42        let client = request.client.clone();
43        let mut event_source = api_request_stream(
44            Method::POST,
45            "/chat/completions",
46            |builder| builder.json(&request),
47            client,
48        )
49        .await?;
50
51        let (tx, rx) = mpsc::channel(32);
52
53        tokio::spawn(async move {
54            while let Some(event) = event_source.next().await {
55                match event {
56                    Ok(Event::Open) => {}
57                    Ok(Event::Message(message)) => {
58                        if message.data == "[DONE]" {
59                            break;
60                        }
61                        match serde_json::from_str::<ChatStream>(&message.data) {
62                            Ok(chunk) => {
63                                if tx.send(Ok(chunk)).await.is_err() {
64                                    break;
65                                }
66                            }
67                            Err(err) => {
68                                let _ = tx
69                                    .send(Err(DeepSeekError::decode(err.to_string(), message.data)))
70                                    .await;
71                                break;
72                            }
73                        }
74                    }
75                    Err(err) => {
76                        let _ = tx
77                            .send(Err(DeepSeekError::decode(err.to_string(), String::new())))
78                            .await;
79                        break;
80                    }
81                }
82            }
83        });
84
85        Ok(rx)
86    }
87
88    fn stream_blocking(self) -> Result<ChatStreamBlocking, DeepSeekError> {
89        let (tx, rx) = std_mpsc::channel();
90
91        std::thread::spawn(move || {
92            let runtime = match tokio::runtime::Builder::new_current_thread()
93                .enable_all()
94                .build()
95            {
96                Ok(runtime) => runtime,
97                Err(err) => {
98                    let _ = tx.send(Err(DeepSeekError::decode(err.to_string(), String::new())));
99                    return;
100                }
101            };
102
103            runtime.block_on(async move {
104                match self.stream().await {
105                    Ok(mut stream_rx) => {
106                        while let Some(item) = stream_rx.recv().await {
107                            if tx.send(item).is_err() {
108                                break;
109                            }
110                        }
111                    }
112                    Err(err) => {
113                        let _ = tx.send(Err(err));
114                    }
115                }
116            });
117        });
118
119        Ok(ChatStreamBlocking { rx })
120    }
121}
122
#[cfg(test)]
mod tests {
    //! Live tests against the DeepSeek API.
    //!
    //! Every test sends real HTTP requests using the key read from the `DEEPSEEK_API`
    //! environment variable.

    use super::*;
    use crate::{DEFAULT_BASE_URL, DeepSeekClient};

    /// Builds a client from `DEEPSEEK_API` and the crate's default base URL.
    fn get_client() -> DeepSeekClient {
        DeepSeekClient::new(
            std::env::var("DEEPSEEK_API").expect("DEEPSEEK_API is not set"),
            DEFAULT_BASE_URL.clone(),
        )
    }

    /// Request builder shared by all tests: `deepseek-v4-flash` with thinking disabled.
    fn get_builder() -> ChatRequestBuilder {
        ChatRequestBuilder::default()
            .client(get_client())
            .model("deepseek-v4-flash")
            .thinking(Thinking::disabled())
    }

    /// Appends the text delta of every choice in `chunk` to `content`.
    fn push_delta_text(content: &mut String, chunk: ChatStream) {
        for choice in chunk.choices {
            if let Some(delta_content) = choice.delta.content {
                content.push_str(&delta_content);
            }
        }
    }

    #[tokio::test]
    async fn chat() {
        let req = get_builder()
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .max_tokens(5_u32)
            .logprobs(true)
            .top_logprobs(2_u32)
            .build()
            .unwrap();
        let response = req.send().await.unwrap();
        println!("{:#?}", response);
    }

    #[tokio::test]
    async fn api_error() {
        // Send a request with a deliberately invalid model to verify API error handling.
        let req = get_builder()
            .model("invalid-model-name")
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .build()
            .unwrap();
        let err = req
            .send()
            .await
            .expect_err("an invalid model name should be rejected");
        match err {
            DeepSeekError::Api {
                error,
                status,
                body,
            } => {
                assert_eq!(status, Some(400));
                assert!(body.is_some());
                assert_eq!(error.error_type, "invalid_request_error");
                assert_eq!(error.code.as_deref(), Some("invalid_request_error"));
            }
            other => panic!("expected DeepSeekError::Api, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn chat_tool_call() {
        let mut messages = vec![ChatMessage::User {
            content: "How's the weather in Hangzhou, Zhejiang?".to_string(),
            name: None,
        }];
        let req_tool = Tool::new(
            "get_weather",
            "Get weather of a location, the user should supply a location first.",
            Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA"
                    },
                },
                "required": ["location"]
            })),
        );
        let req = get_builder()
            .tool(req_tool.clone())
            .messages(messages.clone())
            .build()
            .unwrap();
        let message = req.send().await.unwrap().choices[0].clone().message;
        // The model is free to answer without calling the tool; nothing more to check then.
        let Some(tool_calls) = message.tool_calls.clone() else {
            return;
        };
        let Some(tool_call) = tool_calls.first().cloned() else {
            return;
        };
        messages.push(ChatMessage::Assistant {
            content: message.content,
            name: None,
            tool_calls: Some(tool_calls),
        });
        messages.push(ChatMessage::Tool {
            tool_call_id: tool_call.id,
            content: "24°C".to_string(),
        });

        let req2 = get_builder()
            .tool(req_tool)
            .messages(messages)
            .build()
            .unwrap();
        let response = req2.send().await.unwrap();
        println!("{:#?}", response);
        assert!(
            response.choices[0]
                .message
                .content
                .as_ref()
                .unwrap()
                .contains("24°C")
        );
    }

    #[tokio::test]
    async fn chat_stream_async() {
        let req = get_builder()
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .max_tokens(16_u32)
            .build()
            .unwrap();

        let mut rx = req.stream().await.unwrap();
        let mut content = String::new();
        while let Some(item) = rx.recv().await {
            push_delta_text(&mut content, item.unwrap());
        }

        println!("Model>\t {}", content);
        assert!(!content.is_empty());
    }

    #[test]
    fn chat_stream_blocking() {
        let req = get_builder()
            .message(ChatMessage::User {
                content: "Hi".to_string(),
                name: None,
            })
            .max_tokens(16_u32)
            .build()
            .unwrap();

        let stream = req.stream_blocking().unwrap();
        let mut content = String::new();

        // Bound the number of chunks consumed in case the stream runs long.
        for item in stream.take(50) {
            push_delta_text(&mut content, item.unwrap());
        }

        println!("Model>\t {}", content);
        assert!(!content.is_empty());
    }
}