Skip to main content

dingtalk_stream/client/stream_/
lifecycle.rs

1use crate::client::{ConnectionResponse, Subscription};
2use crate::frames::down_message::MessageTopic;
3use crate::utils::get_local_ip;
4use crate::{DingTalkStream, GATEWAY_URL};
5use anyhow::anyhow;
6use futures_util::{SinkExt, Stream, StreamExt};
7use std::fmt::Display;
8use std::sync::atomic::Ordering;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::sync::mpsc::Receiver;
13use tokio::task::JoinHandle;
14use tokio::time::sleep;
15use tokio_tungstenite::connect_async;
16use tokio_tungstenite::tungstenite::Message;
17use tracing::{debug, error, info, warn};
18
19impl DingTalkStream {
20    /// Start the client and run forever (auto-reconnect)
21    pub async fn start(
22        self: Arc<Self>,
23    ) -> crate::Result<(Arc<Self>, JoinHandle<crate::Result<()>>)> {
24        info!("Starting DingTalk Stream client...");
25        let self_ = Arc::clone(&self);
26        let join_handle = tokio::spawn(async move {
27            let _ = self_.lifecycle_listener.on_start(&self_).await;
28            loop {
29                let result = Arc::clone(&self_).connect().await;
30                let _ = self_
31                    .lifecycle_listener
32                    .on_disconnected(&self_, &result)
33                    .await;
34                match result {
35                    Ok(_) => {
36                        info!("Connection closed normally");
37                        let _ = self_.lifecycle_listener.on_stopped(&self_);
38                        return Ok(());
39                    }
40                    Err(e) => {
41                        error!("Connection error: {}", e);
42                        if self_.config.auto_reconnect {
43                            info!(
44                                "Reconnecting in {} seconds...",
45                                self_.config.reconnect_interval.as_secs()
46                            );
47                            sleep(self_.config.reconnect_interval).await;
48                        } else {
49                            let _ = self_.lifecycle_listener.on_stopped(&self_);
50                            return Err(anyhow!(e));
51                        }
52                    }
53                }
54            }
55        });
56        Ok((self, join_handle))
57    }
58}
59
60impl DingTalkStream {
61    /// Connect to DingTalk WebSocket
62    async fn connect(self: Arc<Self>) -> crate::Result<()> {
63        let connection = self.open_connection().await.map_err(|err| anyhow!(err))?;
64        let ws_url = format!("{}?ticket={}", connection.endpoint, connection.ticket);
65        info!("Connecting to WebSocket: {}", ws_url);
66        let _ = self.lifecycle_listener.on_connecting(&self, &ws_url).await;
67        // Connect to WebSocket
68        let (ws_stream, _) = connect_async(&ws_url).await?;
69        let (ws_write, ws_read) = ws_stream.split();
70        self.connected.store(true, Ordering::SeqCst);
71        info!("Connected to DingTalk WebSocket {}", ws_url);
72        let _ = self.lifecycle_listener.on_connected(&self, &ws_url).await;
73        let (ws_write_join_handle, ws_read_handle) = {
74            // Create channel for sending messages
75            let (msg_stream_sender, msg_stream_receiver) = mpsc::channel::<String>(256);
76            let ws_write_join_handle = Arc::clone(&self)
77                .ws_write(ws_write, msg_stream_receiver)
78                .await;
79            // Spawn keep-alive task if enabled
80            let _ = Arc::clone(&self).keepalive(msg_stream_sender.clone()).await;
81            let ws_read_handle = Arc::clone(&self).ws_read(ws_read, msg_stream_sender).await;
82            (ws_write_join_handle, ws_read_handle)
83        };
84
85        if let Ok(exit_normally) = ws_read_handle.await {
86            exit_normally?;
87        }
88        if let Ok(exit_normally) = ws_write_join_handle.await {
89            exit_normally?;
90        }
91        self.connected.store(false, Ordering::SeqCst);
92        self.registered.store(false, Ordering::SeqCst);
93        Ok(())
94    }
95
96    async fn ws_write<Sink>(
97        self: Arc<Self>,
98        mut ws_write: Sink,
99        mut msg_stream_receiver: Receiver<String>,
100    ) -> JoinHandle<crate::Result<()>>
101    where
102        Sink: SinkExt<Message> + Unpin + Send + Sync + 'static,
103        <Sink as futures_util::Sink<Message>>::Error: Display + Into<anyhow::Error> + Send + Sync,
104    {
105        tokio::spawn(async move {
106            while let Some(ref msg) = msg_stream_receiver.recv().await {
107                let result = self.ws_write_with_retry(&mut ws_write, msg).await;
108                let _ = self
109                    .lifecycle_listener
110                    .on_websocket_write(&self, msg, &result)
111                    .await;
112                match result {
113                    Ok(_) => {}
114                    Err(err) => {
115                        return Err(anyhow!(err));
116                    }
117                }
118            }
119            Ok(())
120        })
121    }
122    async fn ws_write_with_retry<W>(&self, ws_write: &mut W, msg: &str) -> crate::Result<()>
123    where
124        W: SinkExt<Message> + Unpin,
125        <W as futures_util::Sink<Message>>::Error: Display + Into<anyhow::Error>,
126    {
127        const TRY_INTERVAL: Duration = Duration::from_secs(1);
128        const TRY_MAX: u8 = 8;
129        info!("Sending message to WebSocket, msg: {}", msg);
130        let mut cnt = 1;
131        loop {
132            let result = ws_write
133                .send(Message::Text(msg.into()))
134                .await
135                .map_err(|err| anyhow!(err));
136            let _ = self
137                .lifecycle_listener
138                .on_websocket_write_with_retry(self, msg, cnt, &result)
139                .await;
140            match result {
141                Ok(_) => {
142                    info!("Success to send message to WebSocket, {}", msg);
143                    return Ok(());
144                }
145                Err(err) => {
146                    if {
147                        cnt += 1;
148                        cnt
149                    } > TRY_MAX
150                    {
151                        warn!("Failed to send message to WebSocket, retrying in 1 second, err: {}, msg: {}", err, msg);
152                        tokio::time::sleep(TRY_INTERVAL).await;
153                        continue;
154                    }
155                    error!(
156                        "Failed to send message to WebSocket, after {} retries, err: {}, msg: {}",
157                        err, cnt, msg
158                    );
159                    return Err(err);
160                }
161            }
162        }
163    }
164
165    async fn ws_read<R, E>(
166        self: Arc<Self>,
167        mut ws_read: R,
168        msg_stream_sender: mpsc::Sender<String>,
169    ) -> JoinHandle<crate::Result<()>>
170    where
171        E: Display + Into<anyhow::Error> + Send + Sync,
172        R: Stream<Item = Result<Message, E>> + Unpin + Send + Sync + 'static,
173    {
174        tokio::spawn(async move {
175            while let Some(result) = ws_read.next().await {
176                let result = result.map_err(|err| anyhow!(err));
177                let _ = self
178                    .lifecycle_listener
179                    .on_websocket_read(&self, &result)
180                    .await;
181                match result {
182                    Ok(Message::Text(text)) => {
183                        info!("Received text message: {}", text);
184                        if let Err(e) = self.handle_message(&text, msg_stream_sender.clone()).await
185                        {
186                            error!("Error handling message: {}", e);
187                        }
188                    }
189                    Ok(Message::Close(_)) => {
190                        warn!("Received close message: WebSocket connection will be closed!!!");
191                        return Ok(());
192                    }
193                    Err(err) => {
194                        error!("WebSocket error: {}", err);
195                        return Err(err);
196                    }
197                    _ => continue,
198                }
199            }
200            unreachable!()
201        })
202    }
203
204    async fn keepalive(self: Arc<Self>, msg_stream_sender: mpsc::Sender<String>) -> JoinHandle<()> {
205        tokio::spawn(async move {
206            loop {
207                tokio::time::sleep(self.config.keep_alive_interval).await;
208                const PING: &str = r#"{"code": 200,"message": "ping"}"#;
209                let result = msg_stream_sender
210                    .send(PING.into())
211                    .await
212                    .map_err(|err| anyhow!(err));
213                let _ = &self
214                    .lifecycle_listener
215                    .on_keepalive(&self, PING, &result)
216                    .await;
217                match result {
218                    Ok(_) => {
219                        continue;
220                    }
221                    Err(err) => {
222                        warn!("stream_tx dropped error, keepalive task stopping. err: {err}");
223                        return;
224                    }
225                }
226            }
227        })
228    }
229
230    /// Open connection to DingTalk
231    async fn open_connection(
232        &self,
233    ) -> Result<ConnectionResponse, Box<dyn std::error::Error + Send + Sync>> {
234        let subscriptions = self.build_subscriptions()?;
235
236        let client = &self.http_client;
237        let local_ip = get_local_ip().unwrap_or_else(|| "127.0.0.1".to_string());
238
239        let request_body = serde_json::json!({
240            "clientId": self.credential.client_id,
241            "clientSecret": self.credential.client_secret,
242            "subscriptions": subscriptions,
243            "ua": self.config.ua,
244            "localIp": local_ip,
245        });
246
247        info!("Opening connection to {}", GATEWAY_URL);
248        debug!("Request body: {:?}", request_body);
249
250        let response = client
251            .post(GATEWAY_URL)
252            .header("Accept", "application/json")
253            .header("Content-Type", "application/json")
254            .json(&request_body)
255            .send()
256            .await?;
257
258        if !response.status().is_success() {
259            let text = response.text().await?;
260            error!("Failed to open connection: {}", text);
261            return Err(format!("Failed to open connection: {}", text).into());
262        }
263
264        let connection: ConnectionResponse = response.json().await?;
265
266        info!("Connection established: {:?}", connection);
267
268        Ok(connection)
269    }
270
271    /// Build subscription list
272    fn build_subscriptions(
273        &self,
274    ) -> Result<Vec<Subscription>, Box<dyn std::error::Error + Send + Sync>> {
275        let mut topics = Vec::new();
276
277        // Add event subscription if event handler is registered
278        {
279            let handler = &self.event_handler;
280            if handler.is_some() {
281                topics.push(Subscription {
282                    sub_type: "EVENT".to_string(),
283                    topic: MessageTopic::Callback("*".to_string()),
284                });
285            }
286        }
287
288        // Add callback subscriptions
289        {
290            for topic in self.callback_handlers.keys() {
291                topics.push(Subscription {
292                    sub_type: "CALLBACK".to_string(),
293                    topic: topic.clone(),
294                });
295            }
296        }
297
298        if topics.is_empty() {
299            // Default to all events if no handlers registered
300            topics.push(Subscription {
301                sub_type: "EVENT".to_string(),
302                topic: MessageTopic::Callback("*".to_string()),
303            });
304        }
305        Ok(topics)
306    }
307}