liquid_ml/network/
message.rs1use crate::error::LiquidError;
4use crate::network::Connection;
5use crate::{BYTES_PER_KIB, MAX_FRAME_LEN_FRACTION};
6use bincode::{deserialize, serialize};
7use bytes::{Bytes, BytesMut};
8use futures::SinkExt;
9use serde::de::DeserializeOwned;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::SocketAddr;
13use sysinfo::{RefreshKind, System, SystemExt};
14use tokio::io::{ReadHalf, WriteHalf};
15use tokio::net::TcpStream;
16use tokio::stream::StreamExt;
17use tokio_util::codec::{
18 Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec,
19};
20
21pub(crate) type FramedStream<T> =
23 FramedRead<ReadHalf<TcpStream>, MessageCodec<T>>;
24pub(crate) type FramedSink<T> =
26 FramedWrite<WriteHalf<TcpStream>, MessageCodec<T>>;
27
28#[derive(Serialize, Deserialize, Debug)]
31pub struct Message<T> {
32 pub msg_id: usize,
34 pub sender_id: usize,
36 pub target_id: usize,
38 pub msg: T,
40}
41
42#[derive(Serialize, Deserialize, Debug, Clone)]
48pub enum ControlMsg {
49 Directory { dir: Vec<(usize, SocketAddr)> },
56 Introduction {
59 address: SocketAddr,
60 network_name: String,
61 },
62 Kill,
68 Ready,
71}
72
73impl<T> Message<T> {
74 pub fn new(
76 msg_id: usize,
77 sender_id: usize,
78 target_id: usize,
79 msg: T,
80 ) -> Self {
81 Message {
82 msg_id,
83 sender_id,
84 target_id,
85 msg,
86 }
87 }
88}
89
90#[derive(Debug)]
96pub struct MessageCodec<T> {
97 phantom: std::marker::PhantomData<T>,
98 pub(crate) codec: LengthDelimitedCodec,
99}
100
101impl<T> MessageCodec<T> {
102 pub(crate) fn new() -> Self {
105 let memo_info_kind = RefreshKind::new().with_memory();
106 let sys = System::new_with_specifics(memo_info_kind);
107 let total_memory = sys.get_total_memory() as f64;
108 let max_frame_len =
109 (total_memory * BYTES_PER_KIB * MAX_FRAME_LEN_FRACTION) as usize;
110 let codec = LengthDelimitedCodec::builder()
111 .max_frame_length(max_frame_len)
112 .new_codec();
113 MessageCodec {
114 phantom: std::marker::PhantomData,
115 codec,
116 }
117 }
118}
119
120impl<T: DeserializeOwned> Decoder for MessageCodec<T> {
121 type Item = Message<T>;
122 type Error = LiquidError;
123 fn decode(
127 &mut self,
128 src: &mut BytesMut,
129 ) -> Result<Option<Self::Item>, Self::Error> {
130 match self.codec.decode(src)? {
131 Some(data) => Ok(Some(deserialize(&data)?)),
132 None => Ok(None),
133 }
134 }
135}
136
137impl<T: Serialize> Encoder<Message<T>> for MessageCodec<T> {
138 type Error = LiquidError;
139 fn encode(
143 &mut self,
144 item: Message<T>,
145 dst: &mut BytesMut,
146 ) -> Result<(), Self::Error> {
147 let serialized = serialize(&item)?;
148 Ok(self.codec.encode(Bytes::from(serialized), dst)?)
149 }
150}
151
152pub(crate) async fn read_msg<T: DeserializeOwned>(
154 reader: &mut FramedStream<T>,
155) -> Result<Message<T>, LiquidError> {
156 match reader.next().await {
157 None => Err(LiquidError::StreamClosed),
158 Some(x) => Ok(x?),
159 }
160}
161
162pub(crate) async fn send_msg<T: Serialize>(
165 target_id: usize,
166 message: Message<T>,
167 directory: &mut HashMap<usize, Connection<T>>,
168) -> Result<(), LiquidError> {
169 match directory.get_mut(&target_id) {
170 None => Err(LiquidError::UnknownId),
171 Some(conn) => {
172 conn.sink.send(message).await?;
173 Ok(())
174 }
175 }
176}