Skip to main content

dingtalk_stream/client/stream_/
lifecycle.rs

1use crate::client::{ConnectionResponse, Subscription};
2use crate::frames::down_message::MessageTopic;
3use crate::utils::get_local_ip;
4use crate::{DingTalkStream, GATEWAY_URL};
5use anyhow::anyhow;
6use futures_util::{SinkExt, Stream, StreamExt};
7use std::fmt::Display;
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::sync::mpsc::Receiver;
13use tokio::task::JoinHandle;
14use tokio::time::sleep;
15use tokio_tungstenite::connect_async;
16use tokio_tungstenite::tungstenite::Message;
17use log::{debug, error, info, warn};
18
19impl DingTalkStream {
20    /// Start the client and run forever (auto-reconnect)
21    pub async fn start(
22        self: Arc<Self>,
23    ) -> crate::Result<(Arc<Self>, JoinHandle<crate::Result<()>>)> {
24        info!("Starting DingTalk Stream client...");
25        let self_ = Arc::clone(&self);
26        let join_handle = tokio::spawn(async move {
27            let _ = self_.lifecycle_listener.on_start(Arc::clone(&self_)).await;
28            loop {
29                let result = Arc::clone(&self_).connect().await;
30                let _ = self_
31                    .lifecycle_listener
32                    .on_disconnected(Arc::clone(&self_), &result)
33                    .await;
34                match result {
35                    Ok(_) => {
36                        info!("Connection closed normally");
37                        let _ = self_.lifecycle_listener.on_stopped(Arc::clone(&self_));
38                        return Ok(());
39                    }
40                    Err(e) => {
41                        error!("Connection error: {}", e);
42                        if self_.config.auto_reconnect {
43                            info!(
44                                "Reconnecting in {} seconds...",
45                                self_.config.reconnect_interval.as_secs()
46                            );
47                            sleep(self_.config.reconnect_interval).await;
48                        } else {
49                            let _ = self_.lifecycle_listener.on_stopped(Arc::clone(&self_));
50                            return Err(anyhow!(e));
51                        }
52                    }
53                }
54            }
55        });
56        Ok((self, join_handle))
57    }
58}
59
60impl DingTalkStream {
61    /// Connect to DingTalk WebSocket
62    async fn connect(self: Arc<Self>) -> crate::Result<()> {
63        let connection = self.open_connection().await.map_err(|err| anyhow!(err))?;
64        let ws_url = format!("{}?ticket={}", connection.endpoint, connection.ticket);
65        info!("Connecting to WebSocket: {}", ws_url);
66        let _ = self
67            .lifecycle_listener
68            .on_connecting(Arc::clone(&self), &ws_url)
69            .await;
70        // Connect to WebSocket
71        let (ws_stream, _) = connect_async(&ws_url).await?;
72        let (ws_write, ws_read) = ws_stream.split();
73        self.connected.store(true, Ordering::SeqCst);
74        info!("Connected to DingTalk WebSocket {}", ws_url);
75        let _ = self
76            .lifecycle_listener
77            .on_connected(Arc::clone(&self), &ws_url)
78            .await;
79        let (ws_write_join_handle, ws_read_handle) = {
80            // Create channel for sending messages
81            let (msg_stream_sender, msg_stream_receiver) = mpsc::channel::<String>(256);
82            let ws_write_join_handle = Arc::clone(&self)
83                .ws_write(ws_write, msg_stream_receiver)
84                .await;
85            // Spawn keep-alive task if enabled
86            let _ = Arc::clone(&self).keepalive(msg_stream_sender.clone()).await;
87            let ws_read_handle = Arc::clone(&self).ws_read(ws_read, msg_stream_sender).await;
88            (ws_write_join_handle, ws_read_handle)
89        };
90
91        if let Ok(exit_normally) = ws_read_handle.await {
92            exit_normally?;
93        }
94        if let Ok(exit_normally) = ws_write_join_handle.await {
95            exit_normally?;
96        }
97        self.connected.store(false, Ordering::SeqCst);
98        self.registered.store(false, Ordering::SeqCst);
99        Ok(())
100    }
101
102    async fn ws_write<Sink>(
103        self: Arc<Self>,
104        mut ws_write: Sink,
105        mut msg_stream_receiver: Receiver<String>,
106    ) -> JoinHandle<crate::Result<()>>
107    where
108        Sink: SinkExt<Message> + Unpin + Send + Sync + 'static,
109        <Sink as futures_util::Sink<Message>>::Error: Display + Into<anyhow::Error> + Send + Sync,
110    {
111        tokio::spawn(async move {
112            while let Some(ref msg) = msg_stream_receiver.recv().await {
113                let result = Arc::clone(&self)
114                    .ws_write_with_retry(&mut ws_write, msg)
115                    .await;
116                let _ = self
117                    .lifecycle_listener
118                    .on_websocket_write(Arc::clone(&self), msg, &result)
119                    .await;
120                match result {
121                    Ok(_) => {}
122                    Err(err) => {
123                        return Err(anyhow!(err));
124                    }
125                }
126            }
127            Ok(())
128        })
129    }
130    async fn ws_write_with_retry<W>(
131        self: Arc<Self>,
132        ws_write: &mut W,
133        msg: &str,
134    ) -> crate::Result<()>
135    where
136        W: SinkExt<Message> + Unpin,
137        <W as futures_util::Sink<Message>>::Error: Display + Into<anyhow::Error>,
138    {
139        const TRY_INTERVAL: Duration = Duration::from_secs(1);
140        const TRY_MAX: u8 = 8;
141        info!("Sending message to WebSocket, msg: {}", msg);
142        let mut cnt = 1;
143        loop {
144            let result = ws_write
145                .send(Message::Text(msg.into()))
146                .await
147                .map_err(|err| anyhow!(err));
148            let _ = self
149                .lifecycle_listener
150                .on_websocket_write_with_retry(Arc::clone(&self), msg, cnt, &result)
151                .await;
152            match result {
153                Ok(_) => {
154                    info!("Success to send message to WebSocket, {}", msg);
155                    return Ok(());
156                }
157                Err(err) => {
158                    if {
159                        cnt += 1;
160                        cnt
161                    } > TRY_MAX
162                    {
163                        warn!("Failed to send message to WebSocket, retrying in 1 second, err: {}, msg: {}", err, msg);
164                        tokio::time::sleep(TRY_INTERVAL).await;
165                        continue;
166                    }
167                    error!(
168                        "Failed to send message to WebSocket, after {} retries, err: {}, msg: {}",
169                        err, cnt, msg
170                    );
171                    return Err(err);
172                }
173            }
174        }
175    }
176
177    async fn ws_read<R, E>(
178        self: Arc<Self>,
179        mut ws_read: R,
180        msg_stream_sender: mpsc::Sender<String>,
181    ) -> JoinHandle<crate::Result<()>>
182    where
183        E: Display + Into<anyhow::Error> + Send + Sync,
184        R: Stream<Item = Result<Message, E>> + Unpin + Send + Sync + 'static,
185    {
186        tokio::spawn(async move {
187            while let Some(result) = ws_read.next().await {
188                let result = result.map_err(|err| anyhow!(err));
189                let _ = self
190                    .lifecycle_listener
191                    .on_websocket_read(Arc::clone(&self), &result)
192                    .await;
193                match result {
194                    Ok(Message::Text(text)) => {
195                        info!("Received text message: {}", text);
196                        if let Err(e) = Self::handle_message(
197                            Arc::clone(&self),
198                            &text,
199                            msg_stream_sender.clone(),
200                        )
201                        .await
202                        {
203                            error!("Error handling message: {}", e);
204                        }
205                    }
206                    Ok(Message::Close(_)) => {
207                        warn!("Received close message: WebSocket connection will be closed!!!");
208                        return Ok(());
209                    }
210                    Err(err) => {
211                        error!("WebSocket error: {}", err);
212                        return Err(err);
213                    }
214                    _ => continue,
215                }
216            }
217            unreachable!()
218        })
219    }
220
221    async fn keepalive(self: Arc<Self>, msg_stream_sender: mpsc::Sender<String>) -> JoinHandle<()> {
222        tokio::spawn(async move {
223            loop {
224                tokio::time::sleep(self.config.keep_alive_interval).await;
225                const PING: &str = r#"{"code": 200,"message": "ping"}"#;
226                let result = msg_stream_sender
227                    .send(PING.into())
228                    .await
229                    .map_err(|err| anyhow!(err));
230                let _ = &self
231                    .lifecycle_listener
232                    .on_keepalive(Arc::clone(&self), PING, &result)
233                    .await;
234                match result {
235                    Ok(_) => {
236                        continue;
237                    }
238                    Err(err) => {
239                        warn!("stream_tx dropped error, keepalive task stopping. err: {err}");
240                        return;
241                    }
242                }
243            }
244        })
245    }
246
247    /// Open connection to DingTalk
248    async fn open_connection(
249        &self,
250    ) -> Result<ConnectionResponse, Box<dyn std::error::Error + Send + Sync>> {
251        let subscriptions = self.build_subscriptions()?;
252
253        let client = &self.http_client;
254        let local_ip = get_local_ip().unwrap_or_else(|| "127.0.0.1".to_string());
255
256        let request_body = serde_json::json!({
257            "clientId": self.credential.client_id,
258            "clientSecret": self.credential.client_secret,
259            "subscriptions": subscriptions,
260            "ua": self.config.ua,
261            "localIp": local_ip,
262        });
263
264        info!("Opening connection to {}", GATEWAY_URL);
265        debug!("Request body: {:?}", request_body);
266
267        let response = client
268            .post(GATEWAY_URL)
269            .header("Accept", "application/json")
270            .header("Content-Type", "application/json")
271            .json(&request_body)
272            .send()
273            .await?;
274
275        if !response.status().is_success() {
276            let text = response.text().await?;
277            error!("Failed to open connection: {}", text);
278            return Err(format!("Failed to open connection: {}", text).into());
279        }
280
281        let connection: ConnectionResponse = response.json().await?;
282
283        info!("Connection established: {:?}", connection);
284
285        Ok(connection)
286    }
287
288    /// Build subscription list
289    fn build_subscriptions(
290        &self,
291    ) -> Result<Vec<Subscription>, Box<dyn std::error::Error + Send + Sync>> {
292        let mut topics = Vec::new();
293
294        // Add event subscription if event handler is registered
295        {
296            let handler = &self.event_handler;
297            if handler.is_some() {
298                topics.push(Subscription {
299                    sub_type: "EVENT".to_string(),
300                    topic: MessageTopic::Callback("*".to_string()),
301                });
302            }
303        }
304
305        // Add callback subscriptions
306        {
307            for topic in self.callback_handlers.keys() {
308                topics.push(Subscription {
309                    sub_type: "CALLBACK".to_string(),
310                    topic: topic.clone(),
311                });
312            }
313        }
314
315        if topics.is_empty() {
316            // Default to all events if no handlers registered
317            topics.push(Subscription {
318                sub_type: "EVENT".to_string(),
319                topic: MessageTopic::Callback("*".to_string()),
320            });
321        }
322        Ok(topics)
323    }
324}