use crate::shared::{
protocol::{self, client_hello::MonitorInfo, ClientHello, ServerHelloAck},
LengthType, LENGTH_SIZE, PROTOCOL_VERSION,
};
use prost::Message;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::time::{timeout, Duration};
#[derive(Debug)]
pub struct AsyncMessageCodec<S: AsyncRead + AsyncWrite + Send + Unpin> {
stream: S,
buf: Vec<u8>,
length: usize,
partial_read: bool,
}
impl<S: AsyncRead + AsyncWrite + Send + Unpin> AsyncMessageCodec<S> {
pub fn new(stream: S) -> Self {
Self {
stream,
buf: Vec::new(),
length: 0,
partial_read: false,
}
}
pub fn get_stream(&mut self) -> &mut S {
&mut self.stream
}
pub async fn read_message(&mut self) -> std::io::Result<prost::bytes::Bytes> {
let read_timeout = Duration::from_millis(10);
if !self.partial_read {
let mut length_buf = [0; LENGTH_SIZE];
timeout(read_timeout, self.stream.read_exact(&mut length_buf)).await??;
self.length = LengthType::from_be_bytes(length_buf) as usize;
self.buf.resize(self.length, 0);
}
self.partial_read = true;
timeout(read_timeout, self.stream.read_exact(&mut self.buf)).await??;
let bytes = prost::bytes::Bytes::from(std::mem::replace(
&mut self.buf,
Vec::with_capacity(self.length),
));
self.partial_read = false;
Ok(bytes)
}
pub async fn write_message<T: Message>(&mut self, message: T) -> std::io::Result<()> {
let message = message.encode_to_vec();
let mut buf: Vec<u8> = Vec::new(); let length = message.len() as LengthType;
let length_buf = length.to_be_bytes();
assert_eq!(length_buf.len(), LENGTH_SIZE);
buf.extend_from_slice(&length_buf);
buf.extend_from_slice(&message);
self.stream.write_all(&buf).await?;
self.stream.flush().await?;
Ok(())
}
}
pub async fn handshake_client<S>(
messages: &mut AsyncMessageCodec<S>,
monitors: Vec<MonitorInfo>,
) -> std::io::Result<ServerHelloAck>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
let os = match std::env::consts::OS {
"linux" => protocol::client_hello::Os::Linux,
"windows" => protocol::client_hello::Os::Windows,
"macos" => protocol::client_hello::Os::Macos,
_ => protocol::client_hello::Os::Unknown,
} as i32;
let os_version = os_info::get().version().to_string();
messages
.write_message(protocol::ClientHello {
protocol_version: PROTOCOL_VERSION,
os,
os_version,
monitors,
})
.await?;
let server_hello = protocol::ServerHelloAck::decode(messages.read_message().await?)?;
Ok(server_hello)
}
pub async fn handshake_server<S>(
messages: &mut AsyncMessageCodec<S>,
supported_protocol_versions: &[u32],
server_hello: ServerHelloAck,
) -> std::io::Result<ClientHello>
where
S: AsyncRead + AsyncWrite + Send + Unpin,
{
let client_hello = protocol::ClientHello::decode(messages.read_message().await?)?;
if !supported_protocol_versions.contains(&client_hello.protocol_version) {
let msg = format!(
"Unsupported client protocol version: {}. Supported versions: {:?}",
client_hello.protocol_version, supported_protocol_versions
);
messages
.write_message(protocol::StatusUpdate {
kind: protocol::status_update::StatusType::Exit as i32,
details: Some(protocol::status_update::Details::Exit(
protocol::status_update::Exit {},
)),
})
.await?;
return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, msg));
}
messages.write_message(server_hello).await?;
Ok(client_hello)
}