liquid_ml/network/client.rs
//! Represents a client node in a distributed system, with implementations
//! provided for `LiquidML` use cases.
use crate::error::LiquidError;
use crate::network::{
    existing_conn_err, increment_msg_id, message, Connection, ControlMsg,
    FramedSink, FramedStream, Message, MessageCodec,
};
use futures::{
    stream::{self, SelectAll},
    SinkExt,
};
use log::{debug, info};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Mutex, Notify};
use tokio_util::codec::{FramedRead, FramedWrite};

/// Represents a `Client` node in a distributed system that is generic over
/// type `T`, where `T` is the type of messages that can be sent between
/// `Client`s. Allows directed communication to any other node that shares
/// this `Client`'s `network_name`, which enables increased concurrency due to
/// decreased lock contention.
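///
/// A minimal sketch of defining a message type for a `Client` (the `MyMsg`
/// enum below is purely illustrative; any `Send + Sync + Serialize +
/// DeserializeOwned + Clone + 'static` type works):
///
/// ```ignore
/// use serde::{Deserialize, Serialize};
///
/// // A hypothetical message type, for illustration only
/// #[derive(Debug, Clone, Serialize, Deserialize)]
/// enum MyMsg {
///     Put { key: String, value: Vec<u8> },
///     Get { key: String },
/// }
///
/// // `Client<MyMsg>` can now exchange `MyMsg` values with every other
/// // node that registered with the same `network_name`
/// type MyClient = Client<MyMsg>;
/// ```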
#[derive(Debug)]
pub struct Client<T> {
    /// The `id` of this `Client`, assigned by the [`Server`] on startup.
    /// Ids are monotonically increasing based on the order of connections
    ///
    /// [`Server`]: struct.Server.html
    pub id: usize,
    /// The number of `Client`s in the network
    pub num_nodes: usize,
    /// The `address` of this `Client`
    pub address: SocketAddr,
    /// The id of the current message
    pub(crate) msg_id: usize,
    /// A directory which is a map of client id to the [`Connection`] with that
    /// `Client`
    ///
    /// [`Connection`]: struct.Connection.html
    pub(crate) directory: HashMap<usize, Connection<T>>,
    /// The connection to the [`Server`](struct.Server.html)
    server: Connection<ControlMsg>,
    /// The name of the network this `Client` will connect to. This is so that,
    /// for example, two different communication networks of
    /// `Client<DistributedDFMsg>` can be created so that separate
    /// `DistributedDataFrame`s only talk to themselves.
    ///
    /// This allows increased concurrency since each `Client` owned by a
    /// different component has its own `Mutex`, instead of a single `Client`
    /// with one `Mutex`.
    network_name: String,
}

// TODO: remove `DeserializeOwned + 'static`
impl<RT: Send + Sync + DeserializeOwned + Serialize + Clone + 'static>
    Client<RT>
{
    /// Create a new [`Client`] and connect to all other nodes in the network
    /// with the given `network_name`. If you wish to create multiple networks
    /// **and** preserve the `node_id`s assigned by the [`Server`], you should
    /// check out the [`register_network`] method
    ///
    /// # Parameters
    /// - `server_addr`: The address of the [`Server`] in `IP:Port` format
    /// - `my_ip`: The `IP` of this [`Client`]
    /// - `my_port`: An optional port for this [`Client`] to listen on for new
    ///              connections. If it's `None`, the OS assigns a random
    ///              port.
    /// - `num_nodes`: The number of nodes in the network.
    /// - `network_name`: The name of the network to connect with; will only
    ///                   connect with other `Client`s with the same
    ///                   `network_name`
    /// # Returned Values
    /// This function returns a tuple of three things: the first element is the
    /// `Client`, which can then be used to send messages to any other node
    /// with the `send_msg` method. The second element is a struct that
    /// implements [`StreamExt`] and combines the streams of all messages from
    /// all other nodes (unordered), which you can use to easily process
    /// messages like this:
    ///
    /// ```ignore
    /// while let Some(Ok(msg)) = streams.next().await {
    ///     // ... process the message here according to your use case
    /// }
    /// ```
    ///
    /// The third element is a notifier that will only be notified when the
    /// [`Server`] sends a [`ControlMsg::Kill`] message to the `Client`.
    ///
    /// [`Client`]: struct.Client.html
    /// [`Server`]: struct.Server.html
    /// [`StreamExt`]: https://docs.rs/futures/0.3.4/futures/stream/trait.StreamExt.html
    /// [`ControlMsg::Directory`]: enum.ControlMsg.html#variant.Directory
    /// [`ControlMsg::Introduction`]: enum.ControlMsg.html#variant.Introduction
    /// [`ControlMsg::Kill`]: enum.ControlMsg.html#variant.Kill
    /// [`register_network`]: struct.Client.html#method.register_network
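    ///
    /// A minimal usage sketch, assuming a [`Server`] is already running at
    /// `192.168.0.1:9000` (the addresses and the `MyMsg` type here are
    /// illustrative only):
    ///
    /// ```ignore
    /// let (client, mut streams, kill_notifier) = Client::<MyMsg>::new(
    ///     "192.168.0.1:9000".to_string(), // address of the registration Server
    ///     "192.168.0.2".to_string(),      // the IP of this node
    ///     Some("9001".to_string()),       // listen for peers on a fixed port
    ///     3,                              // total number of nodes
    ///     "my-network".to_string(),       // only talk to clients in this network
    /// )
    /// .await?;
    ///
    /// // `streams` yields messages from all peers (see the loop above), and
    /// // `kill_notifier.notified().await` completes once the Server sends
    /// // a `ControlMsg::Kill` message
    /// ```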
    pub async fn new(
        server_addr: String,
        my_ip: String,
        my_port: Option<String>,
        num_nodes: usize,
        network_name: String,
    ) -> Result<
        (Arc<Mutex<Self>>, SelectAll<FramedStream<RT>>, Arc<Notify>),
        LiquidError,
    > {
        // Set up a TcpListener to accept connections from other clients
        let listener;
        let my_address: SocketAddr = match my_port {
            Some(port) => {
                let addr = format!("{}:{}", my_ip, port);
                listener = TcpListener::bind(&addr).await?;
                addr.parse().unwrap()
            }
            None => {
                let addr = format!("{}:0", my_ip);
                listener = TcpListener::bind(&addr).await?;
                listener.local_addr()?
            }
        };
        // Connect to the server
        let server_stream = TcpStream::connect(server_addr).await?;
        let server_address = server_stream.peer_addr().unwrap();
        let (reader, writer) = io::split(server_stream);
        let mut stream = FramedRead::new(reader, MessageCodec::new());
        let sink = FramedWrite::new(writer, MessageCodec::new());
        let mut server = Connection {
            address: server_address,
            sink,
        };
        // Tell the server our address and network name
        server
            .sink
            .send(Message::new(
                0,
                0,
                0,
                ControlMsg::Introduction {
                    address: my_address,
                    network_name: network_name.to_string(),
                },
            ))
            .await?;
        // The server responds with the addresses of all currently connected
        // clients
        let dir_msg = message::read_msg(&mut stream).await?;
        let dir = if let ControlMsg::Directory { dir } = dir_msg.msg {
            dir
        } else {
            return Err(LiquidError::UnexpectedMessage);
        };

        info!(
            "Client in network {} got id {} running at address {}",
            network_name, dir_msg.target_id, &my_address
        );

        // Construct the `Client`
        let mut c = Client {
            id: dir_msg.target_id,
            address: my_address,
            msg_id: dir_msg.msg_id + 1,
            directory: HashMap::new(),
            num_nodes,
            server,
            network_name: network_name.to_string(),
        };

        // Connect to all the currently existing clients
        let mut existing_conns = vec![];
        // NOTE: this is done serially and could be done concurrently, but
        // it doesn't make a difference since there will only ever be a
        // (relatively) small number of nodes
        for (id, addr) in dir.into_iter() {
            existing_conns.push(c.connect(id, addr).await?);
        }

        // Listen for further messages from the Server, e.g. `Kill` messages
        let kill_notifier = Arc::new(Notify::new());
        Client::<ControlMsg>::recv_server_msg(stream, kill_notifier.clone());
        // Block until all the other clients start up and connect to us
        let new_conns =
            Client::accept_new_connections(&mut c, listener, num_nodes).await?;
        let read_streams = stream::select_all(
            existing_conns.into_iter().chain(new_conns.into_iter()),
        );

        let concurrent_client = Arc::new(Mutex::new(c));
        Ok((concurrent_client, read_streams, kill_notifier))
    }

    /// Given an already connected `Client` of any type, register a new network
    /// with the given `network_name`. This creates a new network of `Client`s
    /// that preserves `node_id`s across all nodes by connecting in the same
    /// order as the `parent`. The new `Client` can only talk to `Client`s with
    /// the same `network_name`, as the new network is independent of the
    /// `parent`.
    ///
    /// The tuple returned is the same as in the `Client::new` function.
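    ///
    /// A minimal sketch, assuming `parent` is the `Arc<Mutex<Client<_>>>`
    /// returned by `Client::new` and `DistributedDFMsg` is the message type
    /// of the new network (the names are illustrative):
    ///
    /// ```ignore
    /// let (df_client, df_streams, df_kill_notifier) =
    ///     Client::register_network::<DistributedDFMsg>(
    ///         parent.clone(),
    ///         "dataframe-1".to_string(),
    ///     )
    ///     .await?;
    /// ```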
    pub async fn register_network<
        T: Send + Sync + DeserializeOwned + Serialize + Clone + 'static,
    >(
        parent: Arc<Mutex<Self>>,
        network_name: String,
    ) -> Result<
        (
            Arc<Mutex<Client<T>>>,
            SelectAll<FramedStream<T>>,
            Arc<Notify>,
        ),
        LiquidError,
    > {
        let (server_addr, my_ip, node_id, listen_addr, num_nodes) = {
            let unlocked = parent.lock().await;
            let node_id = unlocked.id;
            let server_addr = unlocked.server.address.to_string();
            let my_ip = unlocked.address.ip().to_string();
            let num_nodes = unlocked.num_nodes;
            (server_addr, my_ip, node_id, unlocked.address, num_nodes)
        };
        if node_id == 1 {
            // connect our client right away since we want to be node 1
            let jh = tokio::spawn(async move {
                Client::<T>::new(
                    server_addr,
                    my_ip,
                    None,
                    num_nodes,
                    network_name,
                )
                .await
            });
            // Send a ready message to node 2 so that all the other nodes
            // start connecting to the Server in the correct order
            let node_2_addr = {
                let unlocked = parent.lock().await;
                unlocked.directory.get(&2).unwrap().address
            };
            let socket = TcpStream::connect(node_2_addr).await?;
            let (_, writer) = io::split(socket);
            let mut sink =
                FramedWrite::new(writer, MessageCodec::<ControlMsg>::new());
            let msg = Message::new(0, node_id, 2, ControlMsg::Ready);
            sink.send(msg).await?;
            let (network, read_streams, kill_notifier) = jh.await.unwrap()?;
            assert_eq!(1, { network.lock().await.id });
            // return the newly registered network
            Ok((network, read_streams, kill_notifier))
        } else {
            // wait to receive a `Ready` message from the node before us in
            // the `parent` network
            let mut listener = TcpListener::bind(listen_addr).await?;
            let (socket, _) = listener.accept().await?;
            let (reader, writer) = io::split(socket);
            let mut stream =
                FramedRead::new(reader, MessageCodec::<ControlMsg>::new());
            let mut sink =
                FramedWrite::new(writer, MessageCodec::<ControlMsg>::new());
            // wait for the ready message
            let msg = message::read_msg(&mut stream).await?;
            //assert_eq!(msg.sender_id, node_id);
            match msg.msg {
                ControlMsg::Ready => (),
                _ => return Err(LiquidError::UnexpectedMessage),
            };
            // The node before us has joined the network; it is now time
            // to connect
            let client_join_handle = tokio::spawn(async move {
                Client::<T>::new(
                    server_addr,
                    my_ip,
                    None,
                    num_nodes,
                    network_name,
                )
                .await
            });

            // tell the next node we are ready
            if node_id < num_nodes {
                // There is another node after us
                let msg = Message::new(0, node_id, node_id, ControlMsg::Ready);
                sink.send(msg).await?;
                let next_node_addr = {
                    let unlocked = parent.lock().await;
                    unlocked.directory.get(&(node_id + 1)).unwrap().address
                };
                let next_node_socket =
                    TcpStream::connect(next_node_addr).await?;
                let (_, next_node_writer) = io::split(next_node_socket);
                let mut next_node_sink = FramedWrite::new(
                    next_node_writer,
                    MessageCodec::<ControlMsg>::new(),
                );
                let ready_msg =
                    Message::new(0, node_id, node_id + 1, ControlMsg::Ready);
                next_node_sink.send(ready_msg).await?;
            }
            let (network, read_streams, kill_notifier) =
                client_join_handle.await.unwrap()?;
            // assert that we joined in the right order (kv node id must
            // match client node id)
            assert_eq!(node_id, { network.lock().await.id });

            // return the newly registered network
            Ok((network, read_streams, kill_notifier))
        }
    }

    /// Waits for and accepts connections from newly started `Client`s until
    /// this `Client` has connected to all nodes. When a new `Client` connects
    /// to this `Client`, we add the [`Connection`] to our directory so we can
    /// later send messages to it if we want to.
    ///
    /// [`Connection`]: struct.Connection.html
    async fn accept_new_connections(
        &mut self,
        mut listener: TcpListener,
        num_clients: usize,
    ) -> Result<Vec<FramedStream<RT>>, LiquidError> {
        let accepted_type = self.network_name.clone();
        let mut curr_clients = self.directory.len() + 1;
        let mut streams = vec![];
        loop {
            if num_clients == curr_clients {
                return Ok(streams);
            }
            // wait on connections from new clients
            let (socket, _) = listener.accept().await?;
            let (reader, writer) = io::split(socket);
            let mut stream =
                FramedRead::new(reader, MessageCodec::<ControlMsg>::new());
            let sink = FramedWrite::new(writer, MessageCodec::<RT>::new());
            // read the introduction message from the new client
            let intro = message::read_msg(&mut stream).await?;
            let (address, network_name) = if let ControlMsg::Introduction {
                address,
                network_name,
            } = intro.msg
            {
                (address, network_name)
            } else {
                // we should only receive `ControlMsg::Introduction` msgs here
                return Err(LiquidError::UnexpectedMessage);
            };

            if accepted_type != network_name {
                // we only want to connect with other clients that are in the
                // same network as us
                return Err(LiquidError::UnexpectedMessage);
            }

            // increment the message id and check if there was an existing
            // connection
            self.msg_id = increment_msg_id(self.msg_id, intro.msg_id);
            let is_existing_conn =
                self.directory.contains_key(&intro.sender_id);

            if is_existing_conn {
                return Err(existing_conn_err(stream, sink));
            }

            // Add the connection with the new client to our directory
            let conn = Connection { address, sink };
            self.directory.insert(intro.sender_id, conn);
            // NOTE: this transmute is sound because `MessageCodec` has no
            // fields, so changing its type parameter does not change the
            // layout of the stream
            let new_stream = unsafe {
                std::mem::transmute::<FramedStream<ControlMsg>, FramedStream<RT>>(
                    stream,
                )
            };
            streams.push(new_stream);
            info!(
                "Connected to id: {:#?} at address: {:#?}",
                intro.sender_id, address
            );
            curr_clients += 1;
        }
    }

    /// Connects this `Client` with the `Client` running at the given
    /// `(id, IP:Port)`. After connecting, adds the [`Connection`] to the other
    /// `Client` to our directory so that we can send it messages. The returned
    /// `FramedStream<RT>` is used for reading messages via the `Stream` trait.
    ///
    /// [`Connection`]: struct.Connection.html
    // the `entry` API doesn't fit here since we `await` between the check
    // and the insert
    #[allow(clippy::map_entry)]
    async fn connect(
        &mut self,
        client_id: usize,
        client_addr: SocketAddr,
    ) -> Result<FramedStream<RT>, LiquidError> {
        // Connect to the given client
        let stream = TcpStream::connect(&client_addr).await?;
        let (reader, writer) = io::split(stream);
        let stream = FramedRead::new(reader, MessageCodec::<RT>::new());
        let mut sink =
            FramedWrite::new(writer, MessageCodec::<ControlMsg>::new());

        // Make the connection struct which holds the sink for sending msgs
        if self.directory.contains_key(&client_id) {
            Err(existing_conn_err(stream, sink))
        } else {
            // send the client our id and address so they can add us to
            // their directory
            sink.send(Message::new(
                self.msg_id,
                self.id,
                0,
                ControlMsg::Introduction {
                    address: self.address,
                    network_name: self.network_name.clone(),
                },
            ))
            .await?;
            // NOTE: this transmute is sound because `MessageCodec` has no
            // fields, so changing its type parameter does not change the
            // layout of the sink
            let sink = unsafe {
                std::mem::transmute::<FramedSink<ControlMsg>, FramedSink<RT>>(
                    sink,
                )
            };
            let conn = Connection {
                address: client_addr,
                sink,
            };
            info!(
                "Connected to id: {:#?} at address: {:#?}",
                client_id, client_addr
            );
            // Add the connection to our directory
            self.directory.insert(client_id, conn);
            // bump the message id for the introduction we just sent
            self.msg_id += 1;

            Ok(stream)
        }
    }

    /// Send the given `message` to the `Client` with the given `target_id`.
    /// Ids are automatically assigned by a [`Server`] during the registration
    /// period based on the order of connections.
    ///
    /// [`Server`]: struct.Server.html
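    ///
    /// A minimal sketch, assuming `client` is the `Arc<Mutex<Client<MyMsg>>>`
    /// returned by `Client::new` (the `MyMsg` type is illustrative):
    ///
    /// ```ignore
    /// client
    ///     .lock()
    ///     .await
    ///     .send_msg(2, MyMsg::Get { key: "users".to_string() })
    ///     .await?;
    /// ```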
    pub async fn send_msg(
        &mut self,
        target_id: usize,
        message: RT,
    ) -> Result<(), LiquidError> {
        let m = Message::new(self.msg_id, self.id, target_id, message);
        message::send_msg(target_id, m, &mut self.directory).await?;
        debug!("sent a message with id {}", self.msg_id);
        self.msg_id += 1;
        Ok(())
    }

    /// Broadcast the given `message` to all currently connected clients
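    ///
    /// A minimal sketch, assuming the same `client` and illustrative `MyMsg`
    /// type as in the `send_msg` example:
    ///
    /// ```ignore
    /// client
    ///     .lock()
    ///     .await
    ///     .broadcast(MyMsg::Get { key: "users".to_string() })
    ///     .await?;
    /// ```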
    pub async fn broadcast(&mut self, message: RT) -> Result<(), LiquidError> {
        let d: Vec<usize> = self.directory.keys().copied().collect();
        for k in d {
            self.send_msg(k, message.clone()).await?;
        }
        Ok(())
    }

    /// Spawns a `tokio` task that will handle receiving [`ControlMsg::Kill`]
    /// messages from the [`Server`]
    ///
    /// [`Server`]: struct.Server.html
    /// [`ControlMsg::Kill`]: enum.ControlMsg.html#variant.Kill
    fn recv_server_msg(
        mut reader: FramedStream<ControlMsg>,
        notifier: Arc<Notify>,
    ) {
        tokio::spawn(async move {
            let kill_msg: Message<ControlMsg> =
                message::read_msg(&mut reader).await.unwrap();
            match &kill_msg.msg {
                ControlMsg::Kill => {
                    notifier.notify();
                    Ok(())
                }
                _ => Err(LiquidError::UnexpectedMessage),
            }
        });
    }
}
491}