1use crate::credential::Credential;
4use crate::error::{Error, Result};
5use crate::handlers::callback::CallbackHandler;
6use crate::handlers::chatbot::{AsyncChatbotHandler, ChatbotReplier, async_raw_process};
7use crate::handlers::event::EventHandler;
8use crate::handlers::system::{DefaultSystemHandler, SystemHandler};
9use crate::messages::frames::{AckMessage, StreamMessage, SystemMessage};
10use crate::transport::http::HttpClient;
11use crate::transport::token::TokenManager;
12use futures_util::{SinkExt, StreamExt};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio_tungstenite::connect_async;
16use tokio_tungstenite::tungstenite::Message;
17use url::form_urlencoded;
18
19enum CallbackEntry {
21 Sync(Arc<dyn CallbackHandler>),
22 Async(Arc<dyn AsyncChatbotHandler>),
23}
24
25pub struct DingTalkStreamClient {
27 credential: Credential,
28 event_handler: Option<Arc<dyn EventHandler>>,
29 callback_handlers: HashMap<String, CallbackEntry>,
30 system_handler: Arc<dyn SystemHandler>,
31 http_client: HttpClient,
32 token_manager: Arc<TokenManager>,
33 is_event_required: bool,
34 pre_started: bool,
35}
36
37impl DingTalkStreamClient {
38 pub fn builder(credential: Credential) -> ClientBuilder {
40 ClientBuilder::new(credential)
41 }
42
43 pub async fn get_access_token(&self) -> Result<String> {
45 self.token_manager.get_access_token().await
46 }
47
48 pub async fn reset_access_token(&self) {
50 self.token_manager.reset().await;
51 }
52
53 pub fn chatbot_replier(&self) -> ChatbotReplier {
55 ChatbotReplier::new(
56 self.http_client.clone(),
57 Arc::clone(&self.token_manager),
58 self.credential.client_id.clone(),
59 )
60 }
61
62 pub async fn upload_to_dingtalk(
64 &self,
65 image_content: &[u8],
66 filetype: &str,
67 filename: &str,
68 mimetype: &str,
69 ) -> Result<String> {
70 let access_token = self.token_manager.get_access_token().await?;
71 let result = self
72 .http_client
73 .upload_file(&access_token, image_content, filetype, filename, mimetype)
74 .await;
75
76 if let Err(Error::Auth(_)) = &result {
77 self.token_manager.reset().await;
78 }
79
80 result
81 }
82
83 pub async fn start(&mut self) -> Result<()> {
85 self.pre_start();
86
87 loop {
88 match self.run_once().await {
89 Ok(()) => {
90 tracing::info!("connection closed, reconnecting in 3s...");
91 }
92 Err(e) => {
93 tracing::error!(error = %e, "connection error, reconnecting in 10s...");
94 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
95 continue;
96 }
97 }
98 tokio::time::sleep(std::time::Duration::from_secs(3)).await;
99 }
100 }
101
102 pub fn start_forever(&mut self) -> Result<()> {
104 let rt = tokio::runtime::Runtime::new()
105 .map_err(|e| Error::Connection(format!("failed to create runtime: {e}")))?;
106 rt.block_on(self.start())
107 }
108
109 fn pre_start(&mut self) {
110 if self.pre_started {
111 return;
112 }
113 self.pre_started = true;
114
115 if let Some(ref handler) = self.event_handler {
116 handler.pre_start();
117 }
118 self.system_handler.pre_start();
119 for entry in self.callback_handlers.values() {
120 match entry {
121 CallbackEntry::Sync(h) => h.pre_start(),
122 CallbackEntry::Async(h) => h.pre_start(),
123 }
124 }
125 }
126
127 async fn run_once(&self) -> Result<()> {
129 let connection = self.open_connection().await?;
130
131 let endpoint = connection
132 .get("endpoint")
133 .and_then(|v| v.as_str())
134 .ok_or_else(|| Error::Connection("endpoint not found".to_owned()))?;
135 let ticket = connection
136 .get("ticket")
137 .and_then(|v| v.as_str())
138 .ok_or_else(|| Error::Connection("ticket not found".to_owned()))?;
139
140 let encoded_ticket: String = form_urlencoded::Serializer::new(String::new())
141 .append_pair("ticket", ticket)
142 .finish();
143 let uri = format!("{}?{}", endpoint, encoded_ticket);
144
145 tracing::info!(endpoint = %endpoint, "connecting to WebSocket");
146
147 let (ws_stream, _) = connect_async(&uri).await?;
148 let (write, read) = ws_stream.split();
149
150 let write = Arc::new(tokio::sync::Mutex::new(write));
151 let write_keepalive = Arc::clone(&write);
152
153 let keepalive_handle = tokio::spawn(async move {
155 loop {
156 tokio::time::sleep(std::time::Duration::from_secs(60)).await;
157 let mut w = write_keepalive.lock().await;
158 if w.send(Message::Ping(Vec::new().into())).await.is_err() {
159 break;
160 }
161 }
162 });
163
164 let mut read = read;
166 while let Some(msg_result) = read.next().await {
167 match msg_result {
168 Ok(Message::Text(text)) => {
169 let route_result = self.route_message(&text).await;
170 match route_result {
171 Ok((ack_opt, should_disconnect)) => {
172 if let Some(ack) = ack_opt {
173 let ack_json = serde_json::to_string(&ack).unwrap_or_default();
174 let mut w = write.lock().await;
175 if let Err(e) = w.send(Message::Text(ack_json.into())).await {
176 tracing::error!(error = %e, "failed to send ack");
177 break;
178 }
179 }
180 if should_disconnect {
181 tracing::info!("received disconnect, closing connection");
182 let mut w = write.lock().await;
183 let _ = w.close().await;
184 break;
185 }
186 }
187 Err(e) => {
188 tracing::error!(error = %e, "route message failed");
189 }
190 }
191 }
192 Ok(Message::Pong(_)) => {}
193 Ok(Message::Close(_)) => {
194 tracing::info!("WebSocket closed by server");
195 break;
196 }
197 Err(e) => {
198 tracing::error!(error = %e, "WebSocket read error");
199 break;
200 }
201 _ => {}
202 }
203 }
204
205 keepalive_handle.abort();
206 Ok(())
207 }
208
209 async fn route_message(&self, raw: &str) -> Result<(Option<AckMessage>, bool)> {
211 let msg: StreamMessage = serde_json::from_str(raw)?;
212 let mut should_disconnect = false;
213
214 let ack = match msg {
215 StreamMessage::System(body) => {
216 let ack = self.system_handler.raw_process(&body).await;
217 if body.headers.topic.as_deref() == Some(SystemMessage::TOPIC_DISCONNECT) {
218 should_disconnect = true;
219 tracing::info!(
220 topic = ?body.headers.topic,
221 "received disconnect"
222 );
223 } else {
224 tracing::warn!(
225 topic = ?body.headers.topic,
226 "unknown system message topic"
227 );
228 }
229 Some(ack)
230 }
231 StreamMessage::Event(body) => {
232 if let Some(ref handler) = self.event_handler {
233 Some(handler.raw_process(&body).await)
234 } else {
235 tracing::warn!("no event handler registered");
236 None
237 }
238 }
239 StreamMessage::Callback(body) => {
240 let topic = body.headers.topic.as_deref().unwrap_or("");
241 if let Some(entry) = self.callback_handlers.get(topic) {
242 match entry {
243 CallbackEntry::Sync(handler) => Some(handler.raw_process(&body).await),
244 CallbackEntry::Async(handler) => {
245 Some(async_raw_process(Arc::clone(handler), body).await)
246 }
247 }
248 } else {
249 tracing::warn!(topic = %topic, "unknown callback topic");
250 None
251 }
252 }
253 };
254
255 Ok((ack, should_disconnect))
256 }
257
258 async fn open_connection(&self) -> Result<serde_json::Value> {
260 let url = format!(
261 "{}/v1.0/gateway/connections/open",
262 self.http_client.openapi_endpoint()
263 );
264
265 tracing::info!(url = %url, "opening connection");
266
267 let mut topics: Vec<serde_json::Value> = Vec::new();
268 if self.is_event_required {
269 topics.push(serde_json::json!({"type": "EVENT", "topic": "*"}));
270 }
271 for topic in self.callback_handlers.keys() {
272 topics.push(serde_json::json!({"type": "CALLBACK", "topic": topic}));
273 }
274
275 let body = serde_json::json!({
276 "clientId": self.credential.client_id,
277 "clientSecret": self.credential.client_secret,
278 "subscriptions": topics,
279 "ua": format!("dingtalk-sdk-rust/v{}-union", env!("CARGO_PKG_VERSION")),
280 "localIp": get_host_ip(),
281 });
282
283 self.http_client.post_raw(&url, &body).await
284 }
285}
286
287pub struct ClientBuilder {
289 credential: Credential,
290 event_handler: Option<Arc<dyn EventHandler>>,
291 callback_handlers: HashMap<String, CallbackEntry>,
292 system_handler: Option<Arc<dyn SystemHandler>>,
293}
294
295impl ClientBuilder {
296 pub fn new(credential: Credential) -> Self {
298 Self {
299 credential,
300 event_handler: None,
301 callback_handlers: HashMap::new(),
302 system_handler: None,
303 }
304 }
305
306 pub fn register_all_event_handler(mut self, handler: impl EventHandler + 'static) -> Self {
308 self.event_handler = Some(Arc::new(handler));
309 self
310 }
311
312 pub fn register_callback_handler(
314 mut self,
315 topic: &str,
316 handler: impl CallbackHandler + 'static,
317 ) -> Self {
318 self.callback_handlers
319 .insert(topic.to_owned(), CallbackEntry::Sync(Arc::new(handler)));
320 self
321 }
322
323 pub fn register_async_chatbot_handler(
325 mut self,
326 topic: &str,
327 handler: impl AsyncChatbotHandler + 'static,
328 ) -> Self {
329 self.callback_handlers
330 .insert(topic.to_owned(), CallbackEntry::Async(Arc::new(handler)));
331 self
332 }
333
334 pub fn register_system_handler(mut self, handler: impl SystemHandler + 'static) -> Self {
336 self.system_handler = Some(Arc::new(handler));
337 self
338 }
339
340 pub fn build(self) -> DingTalkStreamClient {
342 let http_client = HttpClient::new();
343 let token_manager = Arc::new(TokenManager::new(
344 self.credential.clone(),
345 http_client.clone(),
346 ));
347
348 let is_event_required = self.event_handler.is_some();
349
350 DingTalkStreamClient {
351 credential: self.credential,
352 event_handler: self.event_handler,
353 callback_handlers: self.callback_handlers,
354 system_handler: self
355 .system_handler
356 .unwrap_or_else(|| Arc::new(DefaultSystemHandler)),
357 http_client,
358 token_manager,
359 is_event_required,
360 pre_started: false,
361 }
362 }
363}
364
/// Best-effort discovery of this host's outbound IPv4 address.
///
/// Binds an ephemeral UDP socket, connects it to `8.8.8.8:80`, and reports the
/// socket's local IP. Returns an empty string if any of these steps fails.
fn get_host_ip() -> String {
    use std::net::UdpSocket;

    let probe_local_addr = || -> std::io::Result<std::net::SocketAddr> {
        let sock = UdpSocket::bind("0.0.0.0:0")?;
        sock.connect("8.8.8.8:80")?;
        sock.local_addr()
    };

    match probe_local_addr() {
        Ok(addr) => addr.ip().to_string(),
        Err(_) => String::new(),
    }
}