use crate::error::LiquidError;
use crate::network::Connection;
use crate::{BYTES_PER_KIB, MAX_FRAME_LEN_FRACTION};
use bincode::{deserialize, serialize};
use bytes::{Bytes, BytesMut};
use futures::SinkExt;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use sysinfo::{RefreshKind, System, SystemExt};
use tokio::io::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::stream::StreamExt;
use tokio_util::codec::{
Decoder, Encoder, FramedRead, FramedWrite, LengthDelimitedCodec,
};
/// Read half of a TCP connection, decoded frame by frame into
/// `Message<T>` values by [`MessageCodec`].
pub(crate) type FramedStream<T> =
FramedRead<ReadHalf<TcpStream>, MessageCodec<T>>;
/// Write half of a TCP connection; `Message<T>` values sent into it are
/// encoded by [`MessageCodec`].
pub(crate) type FramedSink<T> =
FramedWrite<WriteHalf<TcpStream>, MessageCodec<T>>;
/// A payload of type `T` plus its routing metadata, as sent over the wire.
///
/// Encoded with bincode by `MessageCodec`. bincode writes struct fields in
/// declaration order, so reordering fields changes the wire format.
#[derive(Serialize, Deserialize, Debug)]
pub struct Message<T> {
/// Identifier of this message (presumably a per-sender sequence number —
/// TODO confirm with callers).
pub msg_id: usize,
/// Id of the node that sent this message.
pub sender_id: usize,
/// Id of the node this message is addressed to.
pub target_id: usize,
/// The application-level payload.
pub msg: T,
}
/// Control-plane messages exchanged between nodes.
///
/// NOTE(review): if this is used as the `T` in `Message<T>`, it is encoded
/// with bincode. bincode encodes the variant index, so reordering variants
/// (or fields inside them) changes the wire format.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum ControlMsg {
/// A list of `(id, address)` pairs (presumably node id -> listening
/// address — confirm against the sender).
Directory { dir: Vec<(usize, SocketAddr)> },
/// Carries an address and a network name; presumably sent by a node
/// announcing itself — confirm against the sender.
Introduction {
address: SocketAddr,
network_name: String,
},
/// Presumably a shutdown request — confirm with handlers.
Kill,
/// Presumably a readiness signal — confirm with handlers.
Ready,
}
impl<T> Message<T> {
pub fn new(
msg_id: usize,
sender_id: usize,
target_id: usize,
msg: T,
) -> Self {
Message {
msg_id,
sender_id,
target_id,
msg,
}
}
}
/// Codec that frames `Message<T>` values with a length prefix and
/// serializes/deserializes each frame's contents with bincode.
#[derive(Debug)]
pub struct MessageCodec<T> {
// Ties the codec to payload type `T` without storing a `T`.
phantom: std::marker::PhantomData<T>,
/// Underlying length-delimited framing codec that splits the byte stream
/// into frames.
pub(crate) codec: LengthDelimitedCodec,
}
impl<T> MessageCodec<T> {
    /// Creates a codec whose maximum frame length is a fraction of the
    /// machine's total memory.
    ///
    /// The limit is `total * BYTES_PER_KIB * MAX_FRAME_LEN_FRACTION` bytes.
    /// `total` is the value reported by `sysinfo`, which this code treats as
    /// KiB (hence the `BYTES_PER_KIB` factor).
    ///
    /// If the computed limit is zero (for example, `sysinfo` reports no
    /// memory on an unsupported platform), no custom limit is set. A zero
    /// limit would reject every non-empty frame, so `LengthDelimitedCodec`'s
    /// default maximum frame length is kept instead.
    pub(crate) fn new() -> Self {
        let memo_info_kind = RefreshKind::new().with_memory();
        let sys = System::new_with_specifics(memo_info_kind);
        let total_memory = sys.get_total_memory() as f64;
        // A float-to-int `as` cast saturates (and maps NaN to 0), so this
        // cannot wrap around.
        let max_frame_len =
            (total_memory * BYTES_PER_KIB * MAX_FRAME_LEN_FRACTION) as usize;

        let mut builder = LengthDelimitedCodec::builder();
        if max_frame_len > 0 {
            builder.max_frame_length(max_frame_len);
        }

        MessageCodec {
            phantom: std::marker::PhantomData,
            codec: builder.new_codec(),
        }
    }
}
impl<T: DeserializeOwned> Decoder for MessageCodec<T> {
type Item = Message<T>;
type Error = LiquidError;
fn decode(
&mut self,
src: &mut BytesMut,
) -> Result<Option<Self::Item>, Self::Error> {
match self.codec.decode(src)? {
Some(data) => Ok(Some(deserialize(&data)?)),
None => Ok(None),
}
}
}
impl<T: Serialize> Encoder<Message<T>> for MessageCodec<T> {
    type Error = LiquidError;

    /// Serializes `item` with bincode and appends it to `dst` as a single
    /// length-prefixed frame.
    ///
    /// Serialization errors and framing errors (for example, a payload
    /// larger than the maximum frame length) are converted into
    /// `LiquidError`.
    fn encode(
        &mut self,
        item: Message<T>,
        dst: &mut BytesMut,
    ) -> Result<(), Self::Error> {
        let payload = Bytes::from(serialize(&item)?);
        self.codec.encode(payload, dst).map_err(Into::into)
    }
}
/// Waits for the next message on `reader`.
///
/// # Errors
/// Returns `LiquidError::StreamClosed` if the stream has ended. Any decoding
/// error produced by the underlying codec is passed through unchanged.
pub(crate) async fn read_msg<T: DeserializeOwned>(
    reader: &mut FramedStream<T>,
) -> Result<Message<T>, LiquidError> {
    reader
        .next()
        .await
        .unwrap_or_else(|| Err(LiquidError::StreamClosed))
}
/// Sends `message` over the connection registered under `target_id` in
/// `directory`.
///
/// The lookup uses the `target_id` argument, not `message.target_id`.
///
/// # Errors
/// Returns `LiquidError::UnknownId` if `directory` has no entry for
/// `target_id`. Otherwise, any error from sending on the connection's sink is
/// propagated.
pub(crate) async fn send_msg<T: Serialize>(
    target_id: usize,
    message: Message<T>,
    directory: &mut HashMap<usize, Connection<T>>,
) -> Result<(), LiquidError> {
    let conn = directory
        .get_mut(&target_id)
        .ok_or(LiquidError::UnknownId)?;
    conn.sink.send(message).await?;
    Ok(())
}