use std::io::Cursor;
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use ciborium::{from_reader, into_writer};
use serde::{Serialize, de::DeserializeOwned};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
pub use super::socket::*;
use crate::{
error::Error,
internal_prelude::*,
message::{request::Request, response::Response},
};
/// Maximum number of bytes handled per chunk when transferring a payload.
///
/// `send_bytes` writes the payload in slices of at most this size, and
/// `receive_bytes_with_max_size` reads at most this many bytes per `read` call.
pub const PACKET_SIZE: usize = 1280;
/// Convert `message` into a [`Request`] and send it over `stream`.
///
/// Convenience wrapper around [`send_message`] that fixes the wire type to
/// [`Request`].
///
/// # Errors
///
/// Propagates any serialization or I/O error reported by [`send_message`].
pub async fn send_request<T>(message: T, stream: &mut GenericStream) -> Result<(), Error>
where
    T: Into<Request> + Serialize + std::fmt::Debug,
{
    let sending = send_message::<T, Request>(message, stream);
    sending.await
}
/// Convert `message` into a [`Response`] and send it over `stream`.
///
/// Convenience wrapper around [`send_message`] that fixes the wire type to
/// [`Response`].
///
/// # Errors
///
/// Propagates any serialization or I/O error reported by [`send_message`].
pub async fn send_response<T>(message: T, stream: &mut GenericStream) -> Result<(), Error>
where
    T: Into<Response> + Serialize + std::fmt::Debug,
{
    let sending = send_message::<T, Response>(message, stream);
    sending.await
}
pub async fn send_message<O, T>(message: O, stream: &mut GenericStream) -> Result<(), Error>
where
O: Into<T>,
T: Serialize + std::fmt::Debug,
{
let message: T = message.into();
debug!("Sending message: {message:#?}",);
let mut payload = Vec::new();
into_writer(&message, &mut payload)
.map_err(|err| Error::MessageSerialization(err.to_string()))?;
send_bytes(&payload, stream).await
}
/// Write `payload` to `stream` as one length-prefixed frame.
///
/// The frame is an 8-byte big-endian `u64` holding the payload length,
/// followed by the payload itself. The payload is written in slices of at
/// most [`PACKET_SIZE`] bytes, and the stream is flushed at the end.
///
/// # Errors
///
/// - [`Error::IoError`] if writing the header or any payload slice fails.
/// - Any error produced by the final flush, converted into [`Error`].
pub async fn send_bytes(payload: &[u8], stream: &mut GenericStream) -> Result<(), Error> {
    // Build the 8-byte big-endian length prefix. Writing into a `Vec` cannot fail.
    // The call is fully qualified because tokio's `AsyncWriteExt::write_u64`
    // is also in scope for `Vec<u8>`.
    let mut size_header = Vec::with_capacity(8);
    WriteBytesExt::write_u64::<BigEndian>(&mut size_header, payload.len() as u64).unwrap();

    stream
        .write_all(&size_header)
        .await
        .map_err(|err| Error::IoError("sending request size header".to_string(), err))?;

    // Walk the payload front to back, peeling off at most PACKET_SIZE bytes at a time.
    let mut unsent = payload;
    while !unsent.is_empty() {
        let (chunk, rest) = unsent.split_at(unsent.len().min(PACKET_SIZE));
        stream
            .write_all(chunk)
            .await
            .map_err(|err| Error::IoError("sending payload chunk".to_string(), err))?;
        unsent = rest;
    }

    stream.flush().await?;
    Ok(())
}
/// Receive one length-prefixed payload from `stream` without a size limit.
///
/// Equivalent to [`receive_bytes_with_max_size`] called with `None`.
///
/// # Errors
///
/// Any error returned by [`receive_bytes_with_max_size`].
pub async fn receive_bytes(stream: &mut GenericStream) -> Result<Vec<u8>, Error> {
    let no_limit: Option<usize> = None;
    receive_bytes_with_max_size(stream, no_limit).await
}
pub async fn receive_bytes_with_max_size(
stream: &mut GenericStream,
max_size: Option<usize>,
) -> Result<Vec<u8>, Error> {
let mut header = vec![0; 8];
stream
.read_exact(&mut header)
.await
.map_err(|err| Error::IoError("reading request size header".to_string(), err))?;
let mut header = Cursor::new(header);
let message_size_u64 = ReadBytesExt::read_u64::<BigEndian>(&mut header)?;
if let Some(max_size) = max_size
&& message_size_u64 > max_size as u64
{
error!(
"Client requested message size of {message_size_u64}, but only {max_size} is allowed."
);
return Err(Error::MessageTooBig(message_size_u64 as usize, max_size));
}
if message_size_u64 > (20 * (2u64.pow(20))) {
warn!("Client is sending a large payload: {message_size_u64} bytes.");
}
let message_size = usize::try_from(message_size_u64)
.map_err(|_| Error::MessageTooBig(usize::MAX, usize::MAX))?;
let mut payload_bytes = Vec::with_capacity(message_size);
while payload_bytes.len() < message_size {
let remaining_bytes = message_size - payload_bytes.len();
let mut chunk_buffer: Vec<u8> = if remaining_bytes < PACKET_SIZE {
vec![0; remaining_bytes]
} else {
vec![0; PACKET_SIZE]
};
let received_bytes = stream
.read(&mut chunk_buffer)
.await
.map_err(|err| Error::IoError("reading next chunk".to_string(), err))?;
if received_bytes == 0 {
return Err(Error::Connection(
"Connection went away while receiving payload.".into(),
));
}
payload_bytes.extend_from_slice(&chunk_buffer[0..received_bytes]);
}
Ok(payload_bytes)
}
/// Receive one frame from `stream` and decode it as a [`Request`].
///
/// # Errors
///
/// Any error returned by [`receive_message`].
pub async fn receive_request(stream: &mut GenericStream) -> Result<Request, Error> {
    let request: Request = receive_message(stream).await?;
    Ok(request)
}
/// Receive one frame from `stream` and decode it as a [`Response`].
///
/// # Errors
///
/// Any error returned by [`receive_message`].
pub async fn receive_response(stream: &mut GenericStream) -> Result<Response, Error> {
    let response: Response = receive_message(stream).await?;
    Ok(response)
}
pub async fn receive_message<T: DeserializeOwned + std::fmt::Debug>(
stream: &mut GenericStream,
) -> Result<T, Error> {
let payload_bytes = receive_bytes(stream).await?;
if payload_bytes.is_empty() {
return Err(Error::EmptyPayload);
}
let message: T = from_reader(payload_bytes.as_slice()).map_err(|err| {
if let Ok(value) = from_reader::<ciborium::Value, _>(payload_bytes.as_slice()) {
Error::UnexpectedPayload(value)
} else {
Error::MessageDeserialization(err.to_string())
}
})?;
debug!("Received message: {message:#?}");
Ok(message)
}
#[cfg(test)]
mod test {
    use async_trait::async_trait;
    use pretty_assertions::assert_eq;
    use tokio::{
        net::{TcpListener, TcpStream},
        task,
    };

    use super::*;
    use crate::{
        message::request::{Request, SendRequest},
        network::socket::Stream as PueueStream,
    };

    // Let plain localhost TCP sockets stand in for the generic listener/stream
    // abstractions so the protocol functions can be exercised directly.
    #[async_trait]
    impl Listener for TcpListener {
        async fn accept<'a>(&'a self) -> Result<GenericStream, Error> {
            let (stream, _) = self.accept().await?;
            Ok(Box::new(stream))
        }
    }

    impl PueueStream for TcpStream {}

    /// A payload much larger than `PACKET_SIZE` must survive a full round trip
    /// byte-for-byte. The server echoes the request back, so both the chunked
    /// send and the chunked receive path are covered.
    #[tokio::test]
    async fn test_single_huge_payload() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let payload = "a".repeat(100_000);
        let request: Request = SendRequest {
            task_id: 0,
            input: payload,
        }
        .into();
        let mut original_bytes = Vec::new();
        into_writer(&request, &mut original_bytes).expect("Failed to serialize message.");

        let listener: GenericListener = Box::new(listener);
        task::spawn(async move {
            let mut stream = listener.accept().await.unwrap();
            let message_bytes = receive_bytes(&mut stream).await.unwrap();
            let message: Request = from_reader(message_bytes.as_slice()).unwrap();
            send_request(message, &mut stream).await.unwrap();
        });

        let mut client: GenericStream = Box::new(TcpStream::connect(&addr).await?);
        send_request(request, &mut client).await?;
        let response_bytes = receive_bytes(&mut client).await?;
        let _message: Request = from_reader(response_bytes.as_slice())
            .map_err(|err| Error::MessageDeserialization(err.to_string()))?;

        assert_eq!(response_bytes, original_bytes);
        Ok(())
    }

    /// Two frames written back-to-back on one stream must be decoded as two
    /// separate messages, in order.
    ///
    /// No delay is needed before reading: `receive_message` awaits the data.
    #[tokio::test]
    async fn test_successive_messages() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let listener: GenericListener = Box::new(listener);
        task::spawn(async move {
            let mut stream = listener.accept().await.unwrap();
            send_request(Request::Status, &mut stream).await.unwrap();
            send_request(Request::Remove(vec![0, 2, 3]), &mut stream)
                .await
                .unwrap();
        });

        let mut client: GenericStream = Box::new(TcpStream::connect(&addr).await?);
        let message_a = receive_message(&mut client).await.expect("First message");
        let message_b = receive_message(&mut client).await.expect("Second message");

        assert_eq!(Request::Status, message_a);
        assert_eq!(Request::Remove(vec![0, 2, 3]), message_b);
        Ok(())
    }

    /// A size header announcing more than `max_size` bytes must be rejected
    /// with `Error::MessageTooBig` before any payload is read.
    #[tokio::test]
    async fn test_restricted_payload_size() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let listener: GenericListener = Box::new(listener);
        task::spawn(async move {
            let mut stream = listener.accept().await.unwrap();
            // Big-endian size header announcing 2^63 bytes, followed by one payload byte.
            stream
                .write_all(&[128, 0, 0, 0, 0, 0, 0, 0, 0])
                .await
                .unwrap();
        });

        let mut client: GenericStream = Box::new(TcpStream::connect(&addr).await?);
        let result = receive_bytes_with_max_size(&mut client, Some(4 * 2usize.pow(20))).await;

        assert!(
            matches!(result, Err(Error::MessageTooBig(..))),
            "The payload should be rejected due to large size, got {result:?}"
        );
        Ok(())
    }
}