use bincode::{Decode, Encode};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::{
auth::{AuthMessage, AuthVerification},
protocol::{MAX_MESSAGE_SIZE, WireError},
session::{Conversation, Message},
};
/// One-byte tag placed at the start of every wire frame to identify the
/// kind of payload that follows.
///
/// The discriminants are the exact byte values written to and read from the
/// stream. `PartialEq`/`Eq` are derived so tags can be compared directly
/// rather than through `as u8` casts.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Frame carrying an authentication message.
    Auth = 0x01,
    /// Frame carrying an authentication verification.
    AuthVerification = 0x02,
    /// Frame carrying a chat message.
    Chat = 0x03,
    /// Frame carrying a timestamp.
    Timestamp = 0x04,
}
impl TryFrom<u8> for MessageType {
    type Error = WireError;

    /// Converts a raw tag byte read from the wire into a [`MessageType`].
    ///
    /// # Errors
    ///
    /// Returns [`WireError::InvalidFormat`] for any byte that is not one of
    /// the known tags (`0x01`..=`0x04`).
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let kind = match value {
            0x01 => Self::Auth,
            0x02 => Self::AuthVerification,
            0x03 => Self::Chat,
            0x04 => Self::Timestamp,
            // Unknown tag: reject instead of guessing.
            _ => return Err(WireError::InvalidFormat),
        };
        Ok(kind)
    }
}
/// Framed message transport over a stream of type `S`.
///
/// Holds the underlying stream and, optionally, the [`Conversation`] that is
/// used to build outgoing chat messages and to decrypt incoming ones.
pub struct WireProtocol<S> {
    /// Underlying byte stream that frames are written to and read from.
    stream: S,
    /// Conversation used by the chat-message methods; `None` until
    /// `set_conversation` is called.
    conversation: Option<Conversation>,
}
impl<S> WireProtocol<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Creates a protocol wrapper around `stream` with no conversation attached.
    pub fn new(stream: S) -> Self {
        Self {
            stream,
            conversation: None,
        }
    }

    /// Attaches `conversation`, replacing any conversation set previously.
    pub fn set_conversation(&mut self, conversation: Conversation) {
        self.conversation = Some(conversation);
    }

    /// Encodes `data` with bincode's standard configuration and sends it as a
    /// single frame tagged with `msg_type`.
    ///
    /// # Errors
    ///
    /// Returns [`WireError::InvalidFormat`] if encoding fails, or whatever
    /// error sending the raw frame produces.
    async fn send_message<T: Encode>(
        &mut self,
        msg_type: MessageType,
        data: &T,
    ) -> Result<(), WireError> {
        let payload = bincode::encode_to_vec(data, bincode::config::standard())
            .map_err(|_| WireError::InvalidFormat)?;
        self.send_raw_message(msg_type, &payload).await
    }

    /// Reads one frame, checks that its tag equals `expected_type`, and decodes
    /// the payload with bincode's standard configuration.
    ///
    /// # Errors
    ///
    /// Returns [`WireError::InvalidFormat`] if the tag does not match
    /// `expected_type` or the payload cannot be decoded, or whatever error
    /// reading the raw frame produces.
    async fn receive_message<T: Decode<()>>(
        &mut self,
        expected_type: MessageType,
    ) -> Result<T, WireError> {
        let (msg_type, payload) = self.receive_raw_message().await?;
        // Tags are compared by their wire byte value.
        if msg_type as u8 != expected_type as u8 {
            return Err(WireError::InvalidFormat);
        }
        bincode::decode_from_slice(&payload, bincode::config::standard())
            // The second tuple element is the number of bytes consumed; unused here.
            .map(|(decoded, _)| decoded)
            .map_err(|_| WireError::InvalidFormat)
    }

    /// Sends `message` as a [`MessageType::Auth`] frame.
    ///
    /// # Errors
    ///
    /// Fails if encoding or writing the frame fails.
    pub async fn send_auth_message(&mut self, message: &AuthMessage) -> Result<(), WireError> {
        self.send_message(MessageType::Auth, message).await
    }

    /// Receives a [`MessageType::Auth`] frame and decodes it.
    ///
    /// # Errors
    ///
    /// Fails if reading fails, the frame has a different tag, or decoding fails.
    pub async fn receive_auth_message(&mut self) -> Result<AuthMessage, WireError> {
        self.receive_message(MessageType::Auth).await
    }

    /// Sends `verification` as a [`MessageType::AuthVerification`] frame.
    ///
    /// # Errors
    ///
    /// Fails if encoding or writing the frame fails.
    pub async fn send_auth_verification(
        &mut self,
        verification: &AuthVerification,
    ) -> Result<(), WireError> {
        self.send_message(MessageType::AuthVerification, verification)
            .await
    }

    /// Receives a [`MessageType::AuthVerification`] frame and decodes it.
    ///
    /// # Errors
    ///
    /// Fails if reading fails, the frame has a different tag, or decoding fails.
    pub async fn receive_auth_verification(&mut self) -> Result<AuthVerification, WireError> {
        self.receive_message(MessageType::AuthVerification).await
    }

    /// Sends `timestamp` as a [`MessageType::Timestamp`] frame.
    ///
    /// # Errors
    ///
    /// Fails if encoding or writing the frame fails.
    pub async fn send_timestamp(&mut self, timestamp: u64) -> Result<(), WireError> {
        self.send_message(MessageType::Timestamp, &timestamp).await
    }

    /// Receives a [`MessageType::Timestamp`] frame and decodes it as a `u64`.
    ///
    /// # Errors
    ///
    /// Fails if reading fails, the frame has a different tag, or decoding fails.
    pub async fn receive_timestamp(&mut self) -> Result<u64, WireError> {
        self.receive_message(MessageType::Timestamp).await
    }

    /// Builds a text message from `content` via the attached conversation and
    /// sends it as a [`MessageType::Chat`] frame.
    ///
    /// # Errors
    ///
    /// Returns [`WireError::InvalidFormat`] if no conversation has been set;
    /// otherwise fails if encoding or writing the frame fails.
    pub async fn send_text_message(&mut self, content: &str) -> Result<(), WireError> {
        let conversation = self.conversation.as_mut().ok_or(WireError::InvalidFormat)?;
        let message = conversation.create_text_message(content);
        self.send_message(MessageType::Chat, &message).await
    }

    /// Builds an image message from `image_data` via the attached conversation
    /// and sends it as a [`MessageType::Chat`] frame.
    ///
    /// # Errors
    ///
    /// Returns [`WireError::InvalidFormat`] if no conversation has been set;
    /// otherwise fails if encoding or writing the frame fails.
    pub async fn send_image_message(&mut self, image_data: &[u8]) -> Result<(), WireError> {
        let conversation = self.conversation.as_mut().ok_or(WireError::InvalidFormat)?;
        let message = conversation.create_image_message(image_data);
        self.send_message(MessageType::Chat, &message).await
    }

    /// Receives a [`MessageType::Chat`] frame and passes it to the attached
    /// conversation's `decrypt_message`.
    ///
    /// Returns the bytes produced by `decrypt_message` together with the
    /// message's `content_type` byte.
    ///
    /// # Errors
    ///
    /// Fails if reading or decoding the frame fails, returns
    /// [`WireError::InvalidFormat`] if no conversation has been set (the frame
    /// has already been read from the stream at that point), or propagates the
    /// error from `decrypt_message`.
    pub async fn receive_chat_message(&mut self) -> Result<(Vec<u8>, u8), WireError> {
        let message: Message = self.receive_message(MessageType::Chat).await?;
        let conversation = self.conversation.as_ref().ok_or(WireError::InvalidFormat)?;
        let content = conversation.decrypt_message(&message)?;
        Ok((content, message.content_type))
    }

    /// Writes one frame to the stream and flushes it.
    ///
    /// Frame layout: 1 tag byte, a 4-byte big-endian payload length, then the
    /// payload bytes.
    ///
    /// # Errors
    ///
    /// Returns [`WireError::MessageTooLarge`] if `payload` is longer than
    /// `MAX_MESSAGE_SIZE` or its length does not fit in the `u32` length
    /// prefix; propagates stream I/O errors.
    async fn send_raw_message(
        &mut self,
        msg_type: MessageType,
        payload: &[u8],
    ) -> Result<(), WireError> {
        let payload_len = payload.len();
        if payload_len > MAX_MESSAGE_SIZE {
            return Err(WireError::MessageTooLarge(payload_len));
        }
        // The length prefix is a u32: refuse an oversized payload instead of
        // letting an `as` cast silently truncate the length and corrupt framing.
        let frame_len =
            u32::try_from(payload_len).map_err(|_| WireError::MessageTooLarge(payload_len))?;

        self.stream.write_all(&[msg_type as u8]).await?;
        self.stream.write_all(&frame_len.to_be_bytes()).await?;
        self.stream.write_all(payload).await?;
        self.stream.flush().await?;
        Ok(())
    }

    /// Reads one frame from the stream and returns its tag and payload.
    ///
    /// The declared length is validated against `MAX_MESSAGE_SIZE` before the
    /// payload buffer is allocated.
    ///
    /// # Errors
    ///
    /// Returns [`WireError::InvalidFormat`] for an unknown tag byte,
    /// [`WireError::MessageTooLarge`] if the declared length exceeds
    /// `MAX_MESSAGE_SIZE`; propagates stream I/O errors.
    async fn receive_raw_message(&mut self) -> Result<(MessageType, Vec<u8>), WireError> {
        let mut type_buf = [0u8; 1];
        self.stream.read_exact(&mut type_buf).await?;
        let msg_type = MessageType::try_from(type_buf[0])?;

        let mut len_buf = [0u8; 4];
        self.stream.read_exact(&mut len_buf).await?;
        let payload_len = u32::from_be_bytes(len_buf) as usize;
        // Check the declared size before allocating the payload buffer.
        if payload_len > MAX_MESSAGE_SIZE {
            return Err(WireError::MessageTooLarge(payload_len));
        }

        let mut payload = vec![0u8; payload_len];
        self.stream.read_exact(&mut payload).await?;
        Ok((msg_type, payload))
    }

    /// Returns a shared reference to the underlying stream.
    pub fn stream(&self) -> &S {
        &self.stream
    }

    /// Consumes the wrapper and returns the underlying stream.
    pub fn into_stream(self) -> S {
        self.stream
    }
}