use crate::InvocationError;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
/// TCP transport that frames every packet with a 4-byte little-endian length
/// prefix.
///
/// Before the very first packet, the four tag bytes `0xee 0xee 0xee 0xee` are
/// written once (presumably the MTProto "intermediate" transport tag — confirm
/// if this type is reused for another protocol).
pub struct IntermediateTransport {
    stream: TcpStream,
    /// Whether the four-byte `0xee` tag has already been written to `stream`.
    init_sent: bool,
}

impl IntermediateTransport {
    /// Tag written once, in front of the first outgoing packet.
    const INIT_TAG: [u8; 4] = [0xee, 0xee, 0xee, 0xee];

    /// Upper bound on memory reserved up front for an incoming packet. The
    /// length prefix comes from the peer, so it is not trusted for allocation;
    /// the receive buffer only grows as bytes actually arrive.
    const MAX_PREALLOC: usize = 64 * 1024;

    /// Opens a TCP connection to `addr` and wraps it. The tag is not sent
    /// until the first call to [`IntermediateTransport::send`].
    ///
    /// # Errors
    /// Returns an error if the TCP connection cannot be established.
    pub async fn connect(addr: &str) -> Result<Self, InvocationError> {
        let stream = TcpStream::connect(addr).await?;
        Ok(Self::from_stream(stream))
    }

    /// Wraps an already-connected stream. The tag will be written in front of
    /// the first packet sent through the returned transport.
    pub fn from_stream(stream: TcpStream) -> Self {
        Self {
            stream,
            init_sent: false,
        }
    }

    /// Sends one packet: `[tag, first call only] ++ len (u32 LE) ++ data`.
    ///
    /// The whole frame is assembled first and written with a single
    /// `write_all`. This keeps the tiny header from going out as its own TCP
    /// segment, which with Nagle's algorithm enabled can stall the payload
    /// behind an unacknowledged header.
    ///
    /// # Errors
    /// Returns an `InvalidInput` I/O error (converted into `InvocationError`)
    /// if `data` is longer than `u32::MAX` bytes, or any error from writing.
    pub async fn send(&mut self, data: &[u8]) -> Result<(), InvocationError> {
        let len = u32::try_from(data.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Intermediate transport: packet too large",
            )
        })?;

        let mut frame = Vec::with_capacity(Self::INIT_TAG.len() + 4 + data.len());
        if !self.init_sent {
            frame.extend_from_slice(&Self::INIT_TAG);
        }
        frame.extend_from_slice(&len.to_le_bytes());
        frame.extend_from_slice(data);

        self.stream.write_all(&frame).await?;
        self.init_sent = true;
        Ok(())
    }

    /// Receives one packet and returns its payload (the bytes following the
    /// 4-byte length prefix).
    ///
    /// # Errors
    /// Returns an `UnexpectedEof` I/O error if the stream ends before the
    /// whole packet has arrived, or any other read error.
    pub async fn recv(&mut self) -> Result<Vec<u8>, InvocationError> {
        let mut len_buf = [0u8; 4];
        self.stream.read_exact(&mut len_buf).await?;
        let len = u32::from_le_bytes(len_buf) as usize;

        // Reserve a bounded amount and let the buffer grow as data arrives, so
        // a bogus length prefix cannot force a multi-gigabyte allocation.
        let mut payload = Vec::with_capacity(len.min(Self::MAX_PREALLOC));
        (&mut self.stream)
            .take(len as u64)
            .read_to_end(&mut payload)
            .await?;
        if payload.len() != len {
            return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into());
        }
        Ok(payload)
    }

    /// Consumes the transport and returns the underlying stream.
    pub fn into_inner(self) -> TcpStream {
        self.stream
    }
}
pub struct FullTransport {
stream: TcpStream,
send_seqno: u32,
recv_seqno: u32,
}
impl FullTransport {
pub async fn connect(addr: &str) -> Result<Self, InvocationError> {
let stream = TcpStream::connect(addr).await?;
Ok(Self {
stream,
send_seqno: 0,
recv_seqno: 0,
})
}
pub fn from_stream(stream: TcpStream) -> Self {
Self {
stream,
send_seqno: 0,
recv_seqno: 0,
}
}
pub async fn send(&mut self, data: &[u8]) -> Result<(), InvocationError> {
let total_len = (data.len() + 12) as u32; let seq = self.send_seqno;
self.send_seqno = self.send_seqno.wrapping_add(1);
let mut packet = Vec::with_capacity(total_len as usize);
packet.extend_from_slice(&total_len.to_le_bytes());
packet.extend_from_slice(&seq.to_le_bytes());
packet.extend_from_slice(data);
let crc = crc32_ieee(&packet);
packet.extend_from_slice(&crc.to_le_bytes());
self.stream.write_all(&packet).await?;
Ok(())
}
pub async fn recv(&mut self) -> Result<Vec<u8>, InvocationError> {
let mut len_buf = [0u8; 4];
self.stream.read_exact(&mut len_buf).await?;
let total_len = u32::from_le_bytes(len_buf) as usize;
if total_len < 12 {
return Err(InvocationError::Deserialize(
"Full transport: packet too short".into(),
));
}
let mut rest = vec![0u8; total_len - 4];
self.stream.read_exact(&mut rest).await?;
let (body, crc_bytes) = rest.split_at(rest.len() - 4);
let expected_crc = u32::from_le_bytes(crc_bytes.try_into().unwrap());
let mut check_input = len_buf.to_vec();
check_input.extend_from_slice(body);
let actual_crc = crc32_ieee(&check_input);
if actual_crc != expected_crc {
return Err(InvocationError::Deserialize(format!(
"Full transport: CRC mismatch (got {actual_crc:#010x}, expected {expected_crc:#010x})"
)));
}
let _recv_seq = u32::from_le_bytes(body[..4].try_into().unwrap());
self.recv_seqno = self.recv_seqno.wrapping_add(1);
Ok(body[4..].to_vec())
}
pub fn into_inner(self) -> TcpStream {
self.stream
}
}
/// Computes the standard CRC-32 (IEEE 802.3 polynomial, reflected form) of
/// `data`.
///
/// The register starts at all ones, each input byte is processed LSB-first
/// with the reversed polynomial `0xedb88320`, and the final register is
/// bit-inverted. An empty input yields `0`.
pub(crate) fn crc32_ieee(data: &[u8]) -> u32 {
    const REVERSED_POLY: u32 = 0xedb8_8320;

    let register = data.iter().fold(!0u32, |mut acc, &byte| {
        // Feed the whole byte into the low bits, then shift it out one bit at
        // a time, folding the polynomial back in whenever a 1 drops off.
        acc ^= u32::from(byte);
        for _ in 0..8 {
            let low_bit_set = acc & 1 == 1;
            acc >>= 1;
            if low_bit_set {
                acc ^= REVERSED_POLY;
            }
        }
        acc
    });

    !register
}