atc/
lib.rs

1pub mod libs;
2
3pub use libs::{
4    command::{ChannelCommand, ServerCommand},
5    connector::create_connector,
6    frame::Frame,
7    listener::create_listener,
8};
9
10use log::{debug, error, info, warn};
11use queues::{IsQueue, Queue};
12use std::{collections::BTreeMap, sync::Arc};
13use tokio::{
14    io,
15    sync::{
16        mpsc::{channel, Receiver, Sender},
17        Mutex,
18    },
19    time::Instant,
20};
21
/// A message waiting in a per-channel delivery queue.
#[derive(Clone, Debug)]
pub struct QueuedMessage {
    /// Message payload.
    pub content: String,
    /// Number of times this message has been handed out by
    /// [`ChannelInfo::head`] without being dequeued.
    hit_count: u32,
}

impl QueuedMessage {
    /// Wraps `content` in a queue entry that has not been visited yet.
    pub fn new(content: String) -> Self {
        Self {
            content,
            hit_count: 0,
        }
    }
}

/// Pending messages of a single channel together with the time of the last
/// recorded write.
pub struct ChannelInfo {
    /// Set on construction and refreshed by [`ChannelInfo::update_instant`].
    last_write: Instant,
    /// FIFO of pending messages; the front is the oldest one.
    ///
    /// A `VecDeque` is used because `head` has to mutate the stored head
    /// element in place to count visits.
    messages: std::collections::VecDeque<QueuedMessage>,
}

impl ChannelInfo {
    /// How many times `head` hands out the same message before that message
    /// becomes eligible for forced removal.
    const MAX_HITS: u32 = 16;

    /// Creates a queue that already contains `message` and stamps the current
    /// time as the last write.
    pub fn new(message: String) -> Self {
        let mut messages = std::collections::VecDeque::new();
        messages.push_back(QueuedMessage::new(message));
        Self {
            last_write: Instant::now(),
            messages,
        }
    }

    /// Add element into a queue.
    pub fn enqueue(&mut self, message: String) {
        self.messages.push_back(QueuedMessage::new(message));
    }

    /// Get head element of a queue, this doesn't remove it from queue.
    ///
    /// Every call counts as one visit of the returned message. A head message
    /// that has already been visited 16 times is forced out of the queue as
    /// long as another message is waiting behind it, and the next message is
    /// returned instead.
    ///
    /// The last remaining message is kept regardless of its visit count, so a
    /// non-empty queue always yields `Some`; `None` is returned only when the
    /// queue is empty.
    pub fn head(&mut self) -> Option<String> {
        // Drop exhausted heads while there is a successor to deliver instead.
        while self.messages.len() > 1
            && self
                .messages
                .front()
                .map_or(false, |m| m.hit_count >= Self::MAX_HITS)
        {
            self.messages.pop_front();
        }

        // Count the visit on the stored element itself, not on a copy.
        let front = self.messages.front_mut()?;
        front.hit_count = front.hit_count.saturating_add(1);
        Some(front.content.clone())
    }

    /// Records "now" as the time of the last write.
    pub fn update_instant(&mut self) {
        self.last_write = Instant::now();
    }

    /// You should call head to get the head element and then call this method
    /// to remove it from queue.
    ///
    /// Returns `true` when an element was removed, `false` when the queue was
    /// already empty.
    pub fn dequeue(&mut self) -> bool {
        self.messages.pop_front().is_some()
    }

    /// Number of messages currently queued.
    pub fn len(&self) -> usize {
        self.messages.len()
    }

    /// Returns `true` when no message is queued.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
}
93
94pub struct Server {
95    socket_map: Arc<Mutex<BTreeMap<String, Sender<ChannelCommand>>>>,
96    rx_upstream: Arc<Mutex<Receiver<(String, Option<Sender<ChannelCommand>>)>>>,
97    tx_upstream: Sender<(String, Option<Sender<ChannelCommand>>)>,
98    rx_control: Arc<Mutex<Receiver<ServerCommand>>>,
99    uri: String,
100    server_started: Arc<Mutex<bool>>,
101    flag_interrupt: Arc<Mutex<bool>>,
102}
103
104impl Server {
105    pub fn new(uri: String, rx_ctrl: Receiver<ServerCommand>) -> Self {
106        let (tx, rx) = channel(1024);
107        Self {
108            uri,
109            socket_map: Arc::new(Mutex::new(BTreeMap::<String, Sender<ChannelCommand>>::new())),
110            rx_upstream: Arc::new(Mutex::new(rx)),
111            tx_upstream: tx,
112            server_started: Arc::new(Mutex::new(false)),
113            flag_interrupt: Arc::new(Mutex::new(false)),
114            rx_control: Arc::new(Mutex::new(rx_ctrl)),
115        }
116    }
117
118    pub async fn start(&mut self) -> io::Result<()> {
119        let rx_upstream = self.rx_upstream.clone();
120        let socket_map = self.socket_map.clone();
121        let flag_int = self.flag_interrupt.clone();
122        let rx_control = self.rx_control.clone();
123
124        let message_queue: Arc<Mutex<BTreeMap<String, ChannelInfo>>> =
125            Arc::new(Mutex::new(BTreeMap::new()));
126
127        tokio::spawn(async move {
128            loop {
129                let flag_int_guard = flag_int.lock().await;
130                if *flag_int_guard == true {
131                    break;
132                }
133                drop(flag_int_guard);
134
135                let mut rx_upstream_guard = rx_upstream.lock().await;
136                if let Ok((channel_id, sender)) = rx_upstream_guard.try_recv() {
137                    match sender {
138                        None => {
139                            debug!(target: "atc-listener", "Removing {} from socket map", channel_id);
140                            socket_map.lock().await.remove(&channel_id);
141                        }
142                        Some(sender) => {
143                            debug!(target: "atc-listener", "Adding {} to socket map", channel_id);
144                            socket_map.lock().await.insert(channel_id.clone(), sender);
145                        }
146                    };
147                }
148                drop(rx_upstream_guard);
149
150                // Handle queued messages.
151                // For each channel with queued messages, only one message will
152                // be sent to the channel.
153                let mut message_queue_guard = message_queue.lock().await;
154                let socket_map_guard = socket_map.lock().await;
155                let mut should_drop_channel_ids = vec![];
156                for (channel_id, channel_info) in message_queue_guard.iter_mut() {
157                    if channel_info.last_write.elapsed().as_secs() > 60 {
158                        // Last write operation was 60 seconds ago, should not
159                        // queue this anymore, drop all existing messages.
160                        should_drop_channel_ids.push(channel_id.clone());
161                        warn!(target: "atc-listener", "Message queue of channel (`{}`) will be dropped due to inactivity for more than 60 seconds", channel_id);
162                        continue;
163                    }
164                    if channel_info.len() > 128 {
165                        // Too many queued messages, drop all existing messages.
166                        should_drop_channel_ids.push(channel_id.clone());
167                        warn!(target: "atc-listener", "Message queue of channel (`{}`) will be dropped due to exceeds 128 message limit", channel_id);
168                        continue;
169                    }
170                    if socket_map_guard.contains_key(channel_id) && channel_info.len() > 0 {
171                        // Should try send message.
172                        let oldest_msg = channel_info.head().unwrap();
173                        let sender = socket_map_guard.get(channel_id).unwrap().clone();
174                        if let Ok(_) = sender
175                            .send(ChannelCommand::ChannelMessage((
176                                channel_id.clone(),
177                                oldest_msg,
178                            )))
179                            .await
180                        {
181                            channel_info.dequeue();
182                            channel_info.update_instant();
183                            debug!(target: "atc-listener", "One queued message sent to existing channel `{}`.", channel_id);
184                        } else {
185                            warn!(target: "atc-listener", "One queued message not sent to existing channel `{}`.", channel_id);
186                        }
187                    }
188                }
189                should_drop_channel_ids.iter().for_each(|id| {
190                    message_queue_guard.remove(id);
191                });
192                drop(socket_map_guard);
193                drop(message_queue_guard);
194
195                let mut rx_control_guard = rx_control.lock().await;
196                if let Ok(server_cmd) = rx_control_guard.try_recv() {
197                    if let ServerCommand::Terminate = server_cmd.clone() {
198                        let mut flag_int_guard = flag_int.lock().await;
199                        *flag_int_guard = true;
200                        drop(flag_int_guard);
201                        return;
202                    }
203
204                    let (target_channel_id, msg) = match server_cmd.clone() {
205                        ServerCommand::Message(t, m) => (t, m),
206                        _ => {
207                            panic!("Need to be handled before entering this LOC")
208                        }
209                    };
210
211                    // Check if target channel id exists in `socket_map`, or
212                    // Add to queue and wait for a while.
213                    if let Some(target_channel_id) = target_channel_id {
214                        let socket_map_guard = socket_map.lock().await;
215                        if socket_map_guard.contains_key(&target_channel_id) {
216                            let sender = socket_map_guard.get(&target_channel_id).unwrap().clone();
217                            if let Err(e) = sender
218                                .send(ChannelCommand::ChannelMessage((
219                                    target_channel_id.clone(),
220                                    msg.clone(),
221                                )))
222                                .await
223                            {
224                                warn!(target: "atc-listener", "Error sending to message channel [will be queued]: {:?}", e);
225
226                                // Put to message queue.
227                                let mut queue = message_queue.lock().await;
228                                if queue.contains_key(&target_channel_id) {
229                                    queue.get_mut(&target_channel_id).unwrap().enqueue(msg);
230                                } else {
231                                    queue.insert(target_channel_id, ChannelInfo::new(msg));
232                                }
233                                drop(queue);
234                            } else {
235                                debug!(target: "atc-listener", "Message sent to `{}`", target_channel_id);
236                            };
237                        } else {
238                            // Channel ID doesn't exist, put into queue.
239                            let mut queue = message_queue.lock().await;
240                            if queue.contains_key(&target_channel_id) {
241                                queue.get_mut(&target_channel_id).unwrap().enqueue(msg);
242                            } else {
243                                queue.insert(target_channel_id, ChannelInfo::new(msg));
244                            }
245                            drop(queue);
246                        }
247                        drop(socket_map_guard);
248                    } else {
249                        // Note: messages without target channel id will not be
250                        // queued.
251                    }
252                }
253                drop(rx_control_guard);
254            }
255        });
256
257        let uri = self.uri.clone();
258        let server_started = self.server_started.clone();
259        let tx_upstream = self.tx_upstream.clone();
260
261        info!(target: "atc-listener", "Ready to start server `{}`:", uri.clone());
262        *server_started.lock().await = true;
263        if let Err(e) = create_listener(uri.clone(), tx_upstream, self.flag_interrupt.clone()).await
264        {
265            error!(target: "atc-listener", "Unable to bind to `{}`: `{:?}`", uri.clone(), e );
266            *server_started.lock().await = false;
267        }
268
269        Ok(())
270    }
271}
272
273pub struct Client {
274    rx_control: Arc<Mutex<Receiver<ChannelCommand>>>,
275    tx_control: Sender<ChannelCommand>,
276    rx_message: Arc<Mutex<Receiver<ChannelCommand>>>,
277    tx_message: Sender<ChannelCommand>,
278    rx_outer_control: Arc<Mutex<Receiver<ServerCommand>>>,
279    uri: String,
280    pub id: String,
281    callback_handler: Arc<Mutex<Option<ClientCallbackHandler>>>,
282    flag_interupt: Arc<Mutex<bool>>,
283    should_reconnect: bool,
284}
285
286pub enum ClientCallbackHandler {
287    Closure(Box<dyn FnMut(String, String) + Send>),
288    Channel(Sender<(String, String)>),
289}
290
291impl Client {
292    pub fn new(uri: String, id: String, rx_outer_ctrl: Receiver<ServerCommand>) -> Self {
293        let (tx_ctrl, rx_ctrl) = channel::<ChannelCommand>(1);
294        let (tx_msg, rx_msg) = channel::<ChannelCommand>(1);
295        Self {
296            rx_control: Arc::new(Mutex::new(rx_ctrl)),
297            tx_control: tx_ctrl,
298            rx_message: Arc::new(Mutex::new(rx_msg)),
299            tx_message: tx_msg,
300            uri,
301            id: id,
302            callback_handler: Arc::new(Mutex::new(None)),
303            rx_outer_control: Arc::new(Mutex::new(rx_outer_ctrl)),
304            flag_interupt: Arc::new(Mutex::new(false)),
305            should_reconnect: false,
306        }
307    }
308
309    pub fn reconnect(self, should_reconnect: bool) -> Self {
310        Self {
311            should_reconnect,
312            ..self
313        }
314    }
315
316    pub async fn callback(self, cb: impl FnMut(String, String) + Send + 'static) -> Self {
317        *self.callback_handler.lock().await = Some(ClientCallbackHandler::Closure(Box::new(cb)));
318        self
319    }
320
321    pub async fn sender(self, sender: Sender<(String, String)>) -> Self {
322        *self.callback_handler.lock().await = Some(ClientCallbackHandler::Channel(sender));
323        self
324    }
325
326    pub async fn connect(&mut self) -> io::Result<()> {
327        // Create another async task that always execute.
328        // Move the message channel receiver and move a clone of control channel
329        // sender into the task
330        let tx_ctrl = self.tx_control.clone();
331        let rx_msg = self.rx_message.clone();
332        let rx_outer_ctrl = self.rx_outer_control.clone();
333        let callback_handler = self.callback_handler.clone();
334        let flag_int = self.flag_interupt.clone();
335
336        tokio::spawn(async move {
337            let mut last_ping = Instant::now();
338
339            loop {
340                {
341                    if *flag_int.lock().await == true {
342                        return;
343                    }
344                }
345                if last_ping.elapsed().as_secs() >= 5 {
346                    match tx_ctrl.send(ChannelCommand::Ping).await {
347                        Ok(_) => {
348                            last_ping = Instant::now();
349                        }
350                        Err(e) => {
351                            error!(target: "atc-connector", "Unable to initualize PING command: {:?}", e);
352                            return;
353                        }
354                    };
355                }
356                {
357                    if let Ok(data) = rx_outer_ctrl.lock().await.try_recv() {
358                        if data == ServerCommand::Terminate {
359                            *flag_int.lock().await = true;
360                            return;
361                        } else if let ServerCommand::Identify(id) = data {
362                            tx_ctrl.send(ChannelCommand::Identify(id)).await.unwrap();
363                        }
364                    }
365                }
366
367                {
368                    if let Ok(data) = rx_msg.lock().await.try_recv() {
369                        if let ChannelCommand::ChannelMessage((channel_id, message)) = data.clone()
370                        {
371                            if let Some(handler) = callback_handler.lock().await.as_mut() {
372                                match handler {
373                                    ClientCallbackHandler::Closure(closure) => {
374                                        closure(channel_id, message)
375                                    }
376                                    ClientCallbackHandler::Channel(sender) => {
377                                        sender.send((channel_id, message)).await.unwrap()
378                                    }
379                                }
380                            }
381                        }
382                    };
383                }
384            }
385        });
386        let mut reconnect_attempts = 0;
387        loop {
388            let id = self.id.clone();
389            let rx_ctrl = self.rx_control.clone();
390            let tx_msg = self.tx_message.clone();
391            let uri = self.uri.clone();
392            let flag_int_clone = self.flag_interupt.clone();
393
394            tokio::spawn(async move {
395                if let Err(e) = create_connector(uri.clone(), id, rx_ctrl, tx_msg, flag_int_clone).await {
396                    error!(target: "atc-connector", "Unable to connect to remote server `{}`: {:?}",uri, e);
397                }
398            }).await.unwrap();
399
400            if !self.should_reconnect || *self.flag_interupt.lock().await {
401                warn!(target: "atc-connector", "Reconnect not enabled or user requested termination from client side.");
402                break;
403            }
404            if reconnect_attempts > 8 {
405                warn!(target: "atc-connector", "No more reconnecting after 8 attempts.");
406                break;
407            }
408
409            // `should_reconnect` flag set to true, and interrupt flag not set
410            // will restart another connection.
411            info!(target: "atc-connector", "Client connection restarting");
412            reconnect_attempts += 1;
413        }
414        Ok(())
415    }
416}