liquid_ml/network/client.rs

//! Represents a client node in a distributed system, with implementations
//! provided for `LiquidML` use cases.
use crate::error::LiquidError;
use crate::network::{
    existing_conn_err, increment_msg_id, message, Connection, ControlMsg,
    FramedSink, FramedStream, Message, MessageCodec,
};
use futures::{
    stream::{self, SelectAll},
    SinkExt,
};
use log::{debug, info};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Mutex, Notify};
use tokio_util::codec::{FramedRead, FramedWrite};

/// Represents a `Client` node in a distributed system that is generic for
/// type `T`, where `T` is the type of messages that can be sent between
/// `Client`s. Allows directed communication to any other node that shares the
/// `Client`'s `network_name`, which enables increased concurrency due to
/// decreased lock contention.
#[derive(Debug)]
pub struct Client<T> {
    /// The `id` of this `Client`, assigned by the [`Server`] on startup.
    /// Ids are assigned in monotonically increasing order based on the order
    /// of connections.
    ///
    /// [`Server`]: struct.Server.html
    pub id: usize,
    /// The number of `Client`s in the network
    pub num_nodes: usize,
    /// The `address` of this `Client`
    pub address: SocketAddr,
    /// The id to use for the next message sent by this `Client`
    pub(crate) msg_id: usize,
    /// A directory which is a map of client id to the [`Connection`] with that
    /// `Client`
    ///
    /// [`Connection`]: struct.Connection.html
    pub(crate) directory: HashMap<usize, Connection<T>>,
    /// The connection to the [`Server`](struct.Server.html)
    server: Connection<ControlMsg>,
    /// The name of the network this `Client` will connect to. This is so that,
    /// for example, two different communication networks of
    /// `Client<DistributedDFMsg>` can be created so that separate
    /// `DistributedDataFrame`s only talk among themselves.
    ///
    /// This allows increased concurrency since each `Client` owned by a
    /// different component has its own `Mutex` around it, instead of a
    /// single `Client` being shared behind one `Mutex`.
    network_name: String,
}

// TODO: remove `DeserializeOwned + 'static`
impl<RT: Send + Sync + DeserializeOwned + Serialize + Clone + 'static>
    Client<RT>
{
    /// Create a new [`Client`] and connect to all other nodes in the network
    /// with the given `network_name`. If you wish to create multiple networks
    /// **and** preserve the `node_id`s assigned by the [`Server`], you should
    /// check out the [`register_network`] method.
    ///
    /// # Parameters
    /// - `server_addr`: The address of the [`Server`] in `IP:Port` format
    /// - `my_ip`: The `IP` of this [`Client`]
    /// - `my_port`: An optional port for this [`Client`] to listen on for new
    ///              connections. If it's `None`, the OS randomly assigns a
    ///              port.
    /// - `num_nodes`: The number of nodes in the network.
    /// - `network_name`: The name of the network to connect to; this `Client`
    ///                   will only connect with other `Client`s that have the
    ///                   same `network_name`
    ///
    /// # Returned Values
    /// This function returns a tuple of three things. The first element is the
    /// `Client`, which can then be used to send messages to any other node
    /// with the `send_msg` method. The second element is a struct that
    /// implements [`StreamExt`] and combines the streams of all messages from
    /// all other nodes (unordered), which you can use to easily process
    /// messages like this:
    ///
    /// ```ignore
    /// while let Some(Ok(msg)) = streams.next().await {
    ///     // ... process the message here according to your use case
    /// }
    /// ```
    ///
    /// The third element is a notifier that will only be notified when the
    /// [`Server`] sends a [`ControlMsg::Kill`] message to the `Client`.
    ///
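    /// A minimal end-to-end sketch; the server address, the `String` message
    /// type, the node count, and the network name below are illustrative
    /// assumptions rather than anything this crate requires:
    ///
    /// ```ignore
    /// // assumes a `Server` is already listening at 127.0.0.1:9000 and that
    /// // `futures::StreamExt` is in scope for `streams.next()`
    /// let (client, mut streams, kill_notifier) = Client::<String>::new(
    ///     "127.0.0.1:9000".to_string(),
    ///     "127.0.0.1".to_string(),
    ///     None,          // let the OS pick the port we listen on
    ///     3,             // total number of nodes in this network
    ///     "chat".to_string(),
    /// )
    /// .await?;
    ///
    /// // say hello to every other node, then process incoming messages until
    /// // the `Server` tells us to shut down
    /// client.lock().await.broadcast("hello".to_string()).await?;
    /// tokio::select! {
    ///     _ = kill_notifier.notified() => { /* clean up and exit */ }
    ///     _ = async {
    ///         while let Some(Ok(msg)) = streams.next().await {
    ///             println!("node {} says: {}", msg.sender_id, msg.msg);
    ///         }
    ///     } => {}
    /// }
    /// ```
    ///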
    /// [`Client`]: struct.Client.html
    /// [`Server`]: struct.Server.html
    /// [`StreamExt`]: https://docs.rs/futures/0.3.4/futures/stream/trait.StreamExt.html
    /// [`ControlMsg::Directory`]: enum.ControlMsg.html#variant.Directory
    /// [`ControlMsg::Introduction`]: enum.ControlMsg.html#variant.Introduction
    /// [`ControlMsg::Kill`]: enum.ControlMsg.html#variant.Kill
    /// [`register_network`]: struct.Client.html#method.register_network
    pub async fn new(
        server_addr: String,
        my_ip: String,
        my_port: Option<String>,
        num_nodes: usize,
        network_name: String,
    ) -> Result<
        (Arc<Mutex<Self>>, SelectAll<FramedStream<RT>>, Arc<Notify>),
        LiquidError,
    > {
        // Set up a `TcpListener`
        let listener;
        let my_address: SocketAddr = match my_port {
            Some(port) => {
                let addr = format!("{}:{}", my_ip, port);
                listener = TcpListener::bind(&addr).await?;
                addr.parse().unwrap()
            }
            None => {
                let addr = format!("{}:0", my_ip);
                listener = TcpListener::bind(&addr).await?;
                listener.local_addr()?
            }
        };
        // Connect to the server
        let server_stream = TcpStream::connect(server_addr).await?;
        let server_address = server_stream.peer_addr().unwrap();
        let (reader, writer) = io::split(server_stream);
        let mut stream = FramedRead::new(reader, MessageCodec::new());
        let sink = FramedWrite::new(writer, MessageCodec::new());
        let mut server = Connection {
            address: server_address,
            sink,
        };
        // Tell the server our address and network name
        server
            .sink
            .send(Message::new(
                0,
                0,
                0,
                ControlMsg::Introduction {
                    address: my_address,
                    network_name: network_name.to_string(),
                },
            ))
            .await?;
        // Server responds with the addresses of all currently connected clients
        let dir_msg = message::read_msg(&mut stream).await?;
        let dir = if let ControlMsg::Directory { dir } = dir_msg.msg {
            dir
        } else {
            return Err(LiquidError::UnexpectedMessage);
        };

        info!(
            "Client in network {} got id {} running at address {}",
            network_name, dir_msg.target_id, &my_address
        );

        // initialize `self`
        let mut c = Client {
            id: dir_msg.target_id,
            address: my_address,
            msg_id: dir_msg.msg_id + 1,
            directory: HashMap::new(),
            num_nodes,
            server,
            network_name: network_name.to_string(),
        };

        // Connect to all the currently existing clients
        let mut existing_conns = vec![];
        // note this is done serially and could be done concurrently, but
        // it doesn't make a difference since there will only ever be a
        // (relatively) small number of nodes
        for (id, addr) in dir.into_iter() {
            existing_conns.push(c.connect(id, addr).await?);
        }

        // Listen for further messages from the Server, e.g. `Kill` messages
        let kill_notifier = Arc::new(Notify::new());
        Client::<ControlMsg>::recv_server_msg(stream, kill_notifier.clone());
        // block until all the other clients start up and connect to us
        let new_conns =
            Client::accept_new_connections(&mut c, listener, num_nodes).await?;
        let read_streams = stream::select_all(
            existing_conns.into_iter().chain(new_conns.into_iter()),
        );

        let concurrent_client = Arc::new(Mutex::new(c));
        Ok((concurrent_client, read_streams, kill_notifier))
    }

    /// Given an already connected `Client` of any type, register a new network
    /// with the given `network_name`. This creates a new network of `Client`s
    /// that preserves `node_id`s across all nodes by connecting in the same
    /// order as the `parent`. The new `Client` can only talk to `Client`s with
    /// the same `network_name`, since the new network is independent of the
    /// `parent`.
    ///
    /// The tuple returned is the same as that of the `Client::new` function.
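    ///
    /// A hedged usage sketch; the `KVMsg` and `DFMsg` message types and the
    /// network name below are illustrative placeholders, not types defined by
    /// this crate:
    ///
    /// ```ignore
    /// // `kv_net` is an existing `Arc<Mutex<Client<KVMsg>>>` created by
    /// // `Client::new`; here we spin up an independent network for dataframe
    /// // messages that reuses the same node ids
    /// let (df_net, df_streams, df_kill) =
    ///     Client::<KVMsg>::register_network::<DFMsg>(
    ///         kv_net.clone(),
    ///         "dataframe-network".to_string(),
    ///     )
    ///     .await?;
    /// // node ids are preserved across the two networks
    /// assert_eq!(kv_net.lock().await.id, df_net.lock().await.id);
    /// ```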
    pub async fn register_network<
        T: Send + Sync + DeserializeOwned + Serialize + Clone + 'static,
    >(
        parent: Arc<Mutex<Self>>,
        network_name: String,
    ) -> Result<
        (
            Arc<Mutex<Client<T>>>,
            SelectAll<FramedStream<T>>,
            Arc<Notify>,
        ),
        LiquidError,
    > {
        let (server_addr, my_ip, node_id, listen_addr, num_nodes) = {
            let unlocked = parent.lock().await;
            let node_id = unlocked.id;
            let server_addr = unlocked.server.address.to_string();
            let my_ip = unlocked.address.ip().to_string();
            let num_nodes = unlocked.num_nodes;
            (server_addr, my_ip, node_id, unlocked.address, num_nodes)
        };
        if node_id == 1 {
            // connect our client right away since we want to be node 1
            let jh = tokio::spawn(async move {
                Client::<T>::new(
                    server_addr,
                    my_ip,
                    None,
                    num_nodes,
                    network_name,
                )
                .await
            });
            // Send a ready message to node 2 so that all the other nodes
            // start connecting to the Server in the correct order
            let node_2_addr = {
                let unlocked = parent.lock().await;
                unlocked.directory.get(&2).unwrap().address
            };
            let socket = TcpStream::connect(node_2_addr).await?;
            let (_, writer) = io::split(socket);
            let mut sink =
                FramedWrite::new(writer, MessageCodec::<ControlMsg>::new());
            let msg = Message::new(0, node_id, 2, ControlMsg::Ready);
            sink.send(msg).await?;
            let (network, read_streams, kill_notifier) = jh.await.unwrap()?;
            assert_eq!(1, { network.lock().await.id });
            // return the newly registered network
            Ok((network, read_streams, kill_notifier))
        } else {
            // wait to receive a `Ready` message from the node before us,
            // listening on the address of the `parent` passed in
            let mut listener = TcpListener::bind(listen_addr).await?;
            let (socket, _) = listener.accept().await?;
            let (reader, writer) = io::split(socket);
            let mut stream =
                FramedRead::new(reader, MessageCodec::<ControlMsg>::new());
            let mut sink =
                FramedWrite::new(writer, MessageCodec::<ControlMsg>::new());
            // wait for the ready message
            let msg = message::read_msg(&mut stream).await?;
            //assert_eq!(msg.sender_id, node_id);
            match msg.msg {
                ControlMsg::Ready => (),
                _ => return Err(LiquidError::UnexpectedMessage),
            };
            // The node before us has joined the network; it is now time
            // to connect
            let client_join_handle = tokio::spawn(async move {
                Client::<T>::new(
                    server_addr,
                    my_ip,
                    None,
                    num_nodes,
                    network_name,
                )
                .await
            });

            // tell the next node we are ready
            if node_id < num_nodes {
                // There is another node after us
                let msg = Message::new(0, node_id, node_id, ControlMsg::Ready);
                sink.send(msg).await?;
                let next_node_addr = {
                    let unlocked = parent.lock().await;
                    unlocked.directory.get(&(node_id + 1)).unwrap().address
                };
                let next_node_socket =
                    TcpStream::connect(next_node_addr).await?;
                let (_, next_node_writer) = io::split(next_node_socket);
                let mut next_node_sink = FramedWrite::new(
                    next_node_writer,
                    MessageCodec::<ControlMsg>::new(),
                );
                let ready_msg =
                    Message::new(0, node_id, node_id + 1, ControlMsg::Ready);
                next_node_sink.send(ready_msg).await?;
            }
            let (network, read_streams, kill_notifier) =
                client_join_handle.await.unwrap()?;
            // assert that we joined in the right order (kv node id must
            // match client node id)
            assert_eq!(node_id, { network.lock().await.id });

            // return the newly registered network
            Ok((network, read_streams, kill_notifier))
        }
    }

    /// Waits for and accepts connections from newly started `Client`s until
    /// this `Client` has connected to all nodes. When a new `Client` connects
    /// to this `Client`, we add the [`Connection`] to our directory so that we
    /// can later send messages to it.
    ///
    /// [`Connection`]: struct.Connection.html
    async fn accept_new_connections(
        &mut self,
        mut listener: TcpListener,
        num_clients: usize,
    ) -> Result<Vec<FramedStream<RT>>, LiquidError> {
        let accepted_type = self.network_name.clone();
        let mut curr_clients = self.directory.len() + 1;
        let mut streams = vec![];
        loop {
            if num_clients == curr_clients {
                return Ok(streams);
            }
            // wait on connections from new clients
            let (socket, _) = listener.accept().await?;
            let (reader, writer) = io::split(socket);
            let mut stream =
                FramedRead::new(reader, MessageCodec::<ControlMsg>::new());
            let sink = FramedWrite::new(writer, MessageCodec::<RT>::new());
            // read the introduction message from the new client
            let intro = message::read_msg(&mut stream).await?;
            let (address, network_name) = if let ControlMsg::Introduction {
                address,
                network_name,
            } = intro.msg
            {
                (address, network_name)
            } else {
                // we should only receive `ControlMsg::Introduction` msgs here
                return Err(LiquidError::UnexpectedMessage);
            };

            if accepted_type != network_name {
                // we only want to connect with other clients that are in the
                // same network as us
                return Err(LiquidError::UnexpectedMessage);
            }

            // increment the message id and check if there was an existing
            // connection
            self.msg_id = increment_msg_id(self.msg_id, intro.msg_id);
            let is_existing_conn =
                self.directory.contains_key(&intro.sender_id);

            if is_existing_conn {
                return Err(existing_conn_err(stream, sink));
            }

            // Add the connection with the new client to our directory
            let conn = Connection { address, sink };
            self.directory.insert(intro.sender_id, conn);
            // NOTE: sound despite the `unsafe` block because `MessageCodec`
            // has no fields, so changing its type parameter does not change
            // the layout or meaning of the stream
            let new_stream = unsafe {
                std::mem::transmute::<FramedStream<ControlMsg>, FramedStream<RT>>(
                    stream,
                )
            };
            streams.push(new_stream);
            info!(
                "Connected to id: {:#?} at address: {:#?}",
                intro.sender_id, address
            );
            curr_clients += 1;
        }
    }

    /// Connects this `Client` with the `Client` running at the given
    /// `(id, IP:Port)`. After connecting, the [`Connection`] to the other
    /// `Client` is added to our directory for sending messages. The returned
    /// `FramedStream<RT>` is used for reading messages via the `Stream` trait.
    ///
    /// [`Connection`]: struct.Connection.html
    #[allow(clippy::map_entry)] // clippy is being dumb
    async fn connect(
        &mut self,
        client_id: usize,
        client_addr: SocketAddr,
    ) -> Result<FramedStream<RT>, LiquidError> {
        // Connect to the given client
        let stream = TcpStream::connect(&client_addr).await?;
        let (reader, writer) = io::split(stream);
        let stream = FramedRead::new(reader, MessageCodec::<RT>::new());
        let mut sink =
            FramedWrite::new(writer, MessageCodec::<ControlMsg>::new());

        // Make the connection struct which holds the sink for sending msgs
        if self.directory.contains_key(&client_id) {
            Err(existing_conn_err(stream, sink))
        } else {
            sink.send(Message::new(
                self.msg_id,
                self.id,
                0,
                ControlMsg::Introduction {
                    address: self.address,
                    network_name: self.network_name.clone(),
                },
            ))
            .await?;
            // NOTE: sound despite the `unsafe` block because `MessageCodec`
            // has no fields, so changing its type parameter does not change
            // the layout or meaning of the sink
            let sink = unsafe {
                std::mem::transmute::<FramedSink<ControlMsg>, FramedSink<RT>>(
                    sink,
                )
            };
            let conn = Connection {
                address: client_addr,
                sink,
            };
            info!(
                "Connected to id: {:#?} at address: {:#?}",
                client_id, client_addr
            );
            // Add the connection to our directory
            self.directory.insert(client_id, conn);
            // we already sent the client our id and address in the
            // `Introduction` above so they can add us to their directory;
            // bump our message id to account for that message
            self.msg_id += 1;

            Ok(stream)
        }
    }

    /// Send the given `message` to the `Client` with the given `target_id`.
    /// Ids are automatically assigned by a [`Server`] during the registration
    /// period based on the order of connections.
    ///
    /// [`Server`]: struct.Server.html
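    ///
    /// A short sketch, assuming this network exchanges `String` messages and
    /// that a node with id `2` exists:
    ///
    /// ```ignore
    /// // `client` is the `Arc<Mutex<Client<String>>>` returned by `Client::new`
    /// client.lock().await.send_msg(2, "hello node 2".to_string()).await?;
    /// ```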
    pub async fn send_msg(
        &mut self,
        target_id: usize,
        message: RT,
    ) -> Result<(), LiquidError> {
        let m = Message::new(self.msg_id, self.id, target_id, message);
        message::send_msg(target_id, m, &mut self.directory).await?;
        debug!("sent a message with id {}", self.msg_id);
        self.msg_id += 1;
        Ok(())
    }

    /// Broadcast the given `message` to all currently connected clients
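    ///
    /// A short sketch (again assuming the network's message type is `String`):
    ///
    /// ```ignore
    /// client.lock().await.broadcast("hello everyone".to_string()).await?;
    /// ```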
    pub async fn broadcast(&mut self, message: RT) -> Result<(), LiquidError> {
        let d: Vec<usize> = self.directory.iter().map(|(k, _)| *k).collect();
        for k in d {
            self.send_msg(k, message.clone()).await?;
        }
        Ok(())
    }

    /// Spawns a `tokio` task that will handle receiving [`ControlMsg::Kill`]
    /// messages from the [`Server`]
    ///
    /// [`Server`]: struct.Server.html
    /// [`ControlMsg::Kill`]: enum.ControlMsg.html#variant.Kill
    fn recv_server_msg(
        mut reader: FramedStream<ControlMsg>,
        notifier: Arc<Notify>,
    ) {
        tokio::spawn(async move {
            let kill_msg: Message<ControlMsg> =
                message::read_msg(&mut reader).await.unwrap();
            match &kill_msg.msg {
                ControlMsg::Kill => {
                    notifier.notify();
                    Ok(())
                }
                _ => Err(LiquidError::UnexpectedMessage),
            }
        });
    }
}