liquid_ml/network/
message.rs

1//! Defines messages and codecs used to communicate with the network of nodes
2//! over `TCP`.
3use crate::error::LiquidError;
4use crate::network::Connection;
5use crate::{BYTES_PER_KIB, MAX_FRAME_LEN_FRACTION};
6use bincode::{deserialize, serialize};
7use bytes::{Bytes, BytesMut};
8use futures::SinkExt;
9use serde::de::DeserializeOwned;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::SocketAddr;
13use sysinfo::{RefreshKind, System, SystemExt};
14use tokio::io::{ReadHalf, WriteHalf};
15use tokio::net::TcpStream;
16use tokio::stream::StreamExt;
17use tokio_util::codec::{
18    Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec,
19};
20
21/// A buffered and framed message codec for reading messages of type `T`
22pub(crate) type FramedStream<T> =
23    FramedRead<ReadHalf<TcpStream>, MessageCodec<T>>;
24/// A buffered and framed message codec for sending messages of type `T`
25pub(crate) type FramedSink<T> =
26    FramedWrite<WriteHalf<TcpStream>, MessageCodec<T>>;
27
28/// A message that can sent between nodes for communication. The message
29/// is generic for type `T`
30#[derive(Serialize, Deserialize, Debug)]
31pub struct Message<T> {
32    /// The id of this message
33    pub msg_id: usize,
34    /// The id of the sender
35    pub sender_id: usize,
36    /// The id of the node this message is being sent to
37    pub target_id: usize,
38    /// The body of the message
39    pub msg: T,
40}
41
42/// Control messages to facilitate communication with the registration
43/// [`Server`] and other [`Client`]s
44///
45/// [`Server`]: struct.Server.html
46/// [`Client`]: struct.Client.html
47#[derive(Serialize, Deserialize, Debug, Clone)]
48pub enum ControlMsg {
49    /// A directory message sent by the [`Server`] to new [`Client`]s once they
50    /// connect so that they know which other [`Client`]s of that type are
51    /// currently connected
52    ///
53    /// [`Server`]: struct.Server.html
54    /// [`Client`]: struct.Client.html
55    Directory { dir: Vec<(usize, SocketAddr)> },
56    /// An introduction that a new [`Client`] sends to all other existing
57    /// [`Client`]s and the [`Server`]
58    Introduction {
59        address: SocketAddr,
60        network_name: String,
61    },
62    /// A message the [`Server`] sends to [`Client`]s to inform them to shut
63    /// down
64    ///
65    /// [`Server`]: struct.Server.html
66    /// [`Client`]: struct.Client.html
67    Kill,
68    /// A message to notify other [`Client`]s when they are ready to register
69    /// a new [`Client`] type
70    Ready,
71}
72
73impl<T> Message<T> {
74    /// Creates a new `Message`.
75    pub fn new(
76        msg_id: usize,
77        sender_id: usize,
78        target_id: usize,
79        msg: T,
80    ) -> Self {
81        Message {
82            msg_id,
83            sender_id,
84            target_id,
85            msg,
86        }
87    }
88}
89
90/// A message encoder/decoder to help frame messages sent over `TCP`,
91/// particularly in the case of very large messages. Uses a very simple method
92/// of writing the length of the serialized message at the very start of
93/// a frame, followed by the serialized message. When decoding, this length
94/// is used to determine if a full frame has been read.
95#[derive(Debug)]
96pub struct MessageCodec<T> {
97    phantom: std::marker::PhantomData<T>,
98    pub(crate) codec: LengthDelimitedCodec,
99}
100
101impl<T> MessageCodec<T> {
102    /// Creates a new `MessageCodec` with a maximum frame length that is 80%
103    /// of the total memory on this machine.
104    pub(crate) fn new() -> Self {
105        let memo_info_kind = RefreshKind::new().with_memory();
106        let sys = System::new_with_specifics(memo_info_kind);
107        let total_memory = sys.get_total_memory() as f64;
108        let max_frame_len =
109            (total_memory * BYTES_PER_KIB * MAX_FRAME_LEN_FRACTION) as usize;
110        let codec = LengthDelimitedCodec::builder()
111            .max_frame_length(max_frame_len)
112            .new_codec();
113        MessageCodec {
114            phantom: std::marker::PhantomData,
115            codec,
116        }
117    }
118}
119
120impl<T: DeserializeOwned> Decoder for MessageCodec<T> {
121    type Item = Message<T>;
122    type Error = LiquidError;
123    /// Decodes a message by reading the length of the message (at the start of
124    /// a frame) and then reading that many bytes from a buffer to complete the
125    /// frame.
126    fn decode(
127        &mut self,
128        src: &mut BytesMut,
129    ) -> Result<Option<Self::Item>, Self::Error> {
130        match self.codec.decode(src)? {
131            Some(data) => Ok(Some(deserialize(&data)?)),
132            None => Ok(None),
133        }
134    }
135}
136
137impl<T: Serialize> Encoder<Message<T>> for MessageCodec<T> {
138    type Error = LiquidError;
139    /// Encodes a message by writing the length of the serialized message at
140    /// the start of a frame, and then writing that many bytes into a buffer
141    /// to be sent.
142    fn encode(
143        &mut self,
144        item: Message<T>,
145        dst: &mut BytesMut,
146    ) -> Result<(), Self::Error> {
147        let serialized = serialize(&item)?;
148        Ok(self.codec.encode(Bytes::from(serialized), dst)?)
149    }
150}
151
152/// Asynchronously waits to read the next message from the given `reader`
153pub(crate) async fn read_msg<T: DeserializeOwned>(
154    reader: &mut FramedStream<T>,
155) -> Result<Message<T>, LiquidError> {
156    match reader.next().await {
157        None => Err(LiquidError::StreamClosed),
158        Some(x) => Ok(x?),
159    }
160}
161
162/// Send the given `message` to the node with the given `target_id` using
163/// the given `directory`
164pub(crate) async fn send_msg<T: Serialize>(
165    target_id: usize,
166    message: Message<T>,
167    directory: &mut HashMap<usize, Connection<T>>,
168) -> Result<(), LiquidError> {
169    match directory.get_mut(&target_id) {
170        None => Err(LiquidError::UnknownId),
171        Some(conn) => {
172            conn.sink.send(message).await?;
173            Ok(())
174        }
175    }
176}