rs_pkg/network/websocket/client/client.rs

1use super::{
2    super::{ErrorHandlerType, MessageHandlerType},
3    WebSocketClientConfig,
4};
5use crate::{
6    async_fn::wrap_fn,
7    cron::{Cron, CronConfig},
8    network::websocket::BytesGenerator,
9};
10use bytes::Bytes;
11use futures_util::{SinkExt, StreamExt};
12use std::sync::Arc;
13use std::{error::Error, time::Duration};
14use tokio::{
15    select,
16    sync::{
17        Mutex,
18        mpsc::{Receiver, Sender, error::SendError},
19    },
20};
21use tokio_tungstenite::{connect_async, tungstenite::Message};
22use tracing::{debug, error, warn};
23
24#[derive(Clone)]
25pub struct Client {
26    name: String,
27    addr: String,
28    message_handler: Arc<MessageHandlerType<Message>>,
29    error_handler: Arc<ErrorHandlerType>,
30    ping_payload: Arc<BytesGenerator>,
31    // worker: Arc<Mutex<Worker<()>>>,
32    ping_interval: String,
33
34    client_close: Arc<Sender<()>>,
35    client_done: Arc<Mutex<Receiver<()>>>,
36
37    reconnect: bool,
38    reconnect_sender: Arc<Sender<()>>,
39    reconnect_receiver: Arc<Mutex<Receiver<()>>>,
40
41    message_sender: Arc<Sender<Message>>,
42    message_receiver: Arc<Mutex<Receiver<Message>>>,
43}
44
45impl Client {
46    pub fn new(name: &str, cfg: &WebSocketClientConfig) -> Self {
47        let (sender, receiver) = tokio::sync::mpsc::channel(1);
48        let (client_close, client_done) = tokio::sync::mpsc::channel(1);
49        let (reconnect_sender, reconnect_receiver) = tokio::sync::mpsc::channel(1);
50        Self {
51            name: name.to_string(),
52            addr: cfg.addr.clone(),
53            message_handler: wrap_fn(|msg| async {
54                match msg {
55                    Message::Text(t) => debug!("Received text: {}", t),
56                    Message::Binary(b) => debug!("Received binary: {:?}", b),
57                    Message::Ping(p) => debug!("Received ping: {:?}", p),
58                    Message::Pong(p) => debug!("Received pong: {:?}", p),
59                    Message::Close(c) => debug!("Received close: {:?}", c),
60                    Message::Frame(f) => debug!("Received frame: {:?}", f),
61                }
62                None
63            }),
64            error_handler: wrap_fn(|e| async move { error!("Received error: {}", e) }),
65            ping_payload: wrap_fn(|_| async {
66                let ts = chrono::Utc::now().timestamp().to_string();
67                Bytes::from(ts)
68            }),
69
70            // worker: Arc::new(Mutex::new(Worker::new(name, 1))),
71            ping_interval: cfg.ping_interval.clone(),
72
73            reconnect: cfg.reconnect,
74            reconnect_sender: Arc::new(reconnect_sender),
75            reconnect_receiver: Arc::new(Mutex::new(reconnect_receiver)),
76
77            client_close: Arc::new(client_close),
78            client_done: Arc::new(Mutex::new(client_done)),
79
80            message_sender: Arc::new(sender),
81            message_receiver: Arc::new(Mutex::new(receiver)),
82        }
83    }
84
85    pub async fn stop(&self) -> Result<(), SendError<()>> {
86        self.client_close.send(()).await
87    }
88
89    pub fn with_message_handler<F, Fut>(mut self, h: F) -> Self
90    where
91        F: Fn(Message) -> Fut + Send + Sync + 'static,
92        Fut: Future<Output = Option<Message>> + Send + Sync + 'static,
93    {
94        self.message_handler = wrap_fn(h);
95        self
96    }
97
98    pub fn with_error_handler<F, Fut>(mut self, h: F) -> Self
99    where
100        F: Fn(Box<dyn Error + Send + Sync + 'static>) -> Fut + Send + Sync + 'static,
101        Fut: Future<Output = ()> + Send + Sync + 'static,
102    {
103        self.error_handler = wrap_fn(h);
104        self
105    }
106
107    pub fn with_ping_payload<F, Fut>(mut self, h: F) -> Self
108    where
109        F: Fn() -> Fut + Send + Sync + 'static,
110        Fut: Future<Output = Bytes> + Send + Sync + 'static,
111    {
112        let h = Arc::new(h);
113        self.ping_payload = wrap_fn(move |_| {
114            let h = h.clone();
115            async move { h().await }
116        });
117        self
118    }
119
120    async fn connect(&mut self, done: Arc<Mutex<Receiver<()>>>) {
121        let reconnect_sender = self.reconnect_sender.clone();
122
123        if let Ok((stream, _)) = connect_async(&self.addr)
124            .await
125            .inspect_err(|e| error!("[{}] connect to {} failed: {}", self.name, self.addr, e))
126        {
127            let (sink, stream) = stream.split();
128            let sink = Arc::new(Mutex::new(sink));
129            let stream = Arc::new(Mutex::new(stream));
130            let msg_handler = self.message_handler.clone();
131            let msg_receiver = self.message_receiver.clone();
132            let err_handler_ping = self.error_handler.clone();
133            let err_handler_main = self.error_handler.clone();
134
135            let mut cron_cfg = CronConfig::default();
136            cron_cfg.interval = self.ping_interval.clone();
137            cron_cfg.run_after_start = self.ping_interval.clone();
138            cron_cfg.interval_after_finish = false;
139
140            let cron = Cron::new("PING", &cron_cfg);
141            let msg_sender = self.message_sender.clone();
142            cron.run(move || {
143                let msg_sender = msg_sender.clone();
144                let err_handler_ping = err_handler_ping.clone();
145                let now = chrono::Utc::now().timestamp_millis().to_string();
146                let ping = Message::Ping(Bytes::from(now));
147                async move {
148                    if let Err(err) = msg_sender.send(ping).await {
149                        _ = err_handler_ping(Box::new(err)).await;
150                    }
151                }
152            })
153            .await;
154
155            tokio::spawn(async move {
156                let mut sink = sink.lock().await;
157                let mut guard = stream.lock().await;
158                let mut done = done.lock().await;
159                let mut msg_receiver = msg_receiver.lock().await;
160                loop {
161                    select! {
162                        _ = done.recv() => {
163                            warn!("Conn Exit with done");
164                            return
165                        },
166
167                        msg = msg_receiver.recv() => {
168                            debug!("msg_receiver receive: {:?}", msg);
169                            match msg {
170                                Some(msg) => {
171                                    if let Err(e) = sink.send(msg).await {
172                                        err_handler_main(Box::new(e)).await;
173                                        _ = reconnect_sender.send(()).await;
174                                        return
175                                    };
176                                },
177
178                                None => {
179                                    _ = reconnect_sender.send(()).await;
180                                    return
181                                },
182                            }
183                        }
184
185                        t = guard.next() => {
186                            debug!("stream receive: {:?}", t);
187                            match t {
188                                Some(Ok(msg)) => {
189                                    if let Some(msg) = msg_handler(msg).await {
190                                        if let Err(e) = sink.send(msg).await {
191                                            err_handler_main(Box::new(e)).await;
192                                            _ = reconnect_sender.send(()).await;
193                                            return
194                                        };
195                                    }
196                                },
197                                Some(Err(err)) => {
198                                    err_handler_main(Box::new(err)).await;
199                                    _ = reconnect_sender.send(()).await;
200                                    return
201                                },
202                                None => {
203                                    _ = reconnect_sender.send(()).await;
204                                    return
205                                }
206                            }
207                        }
208                    }
209                }
210            });
211            return;
212        }
213
214        _ = reconnect_sender.send(()).await;
215    }
216
217    pub async fn send_message(&self, msg: Message) {
218        _ = self.message_sender.send(msg).await;
219    }
220
221    pub async fn run(&self) {
222        let s = self.clone();
223        s.clone().connect(self.client_done.clone()).await;
224
225        if self.reconnect {
226            tokio::spawn(async move {
227                let s = s.clone();
228                let mut reconnect_guard = s.reconnect_receiver.lock().await;
229                let done = s.client_done.clone();
230                loop {
231                    _ = reconnect_guard.recv().await;
232                    tokio::time::sleep(Duration::from_secs(1)).await;
233                    s.clone().connect(done.clone()).await;
234                }
235            });
236        }
237    }
238}