use crate::InvocationError;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
pub struct IntermediateTransport {
stream: TcpStream,
init_sent: bool,
}
impl IntermediateTransport {
pub async fn connect(addr: &str) -> Result<Self, InvocationError> {
let stream = TcpStream::connect(addr).await?;
Ok(Self {
stream,
init_sent: false,
})
}
pub fn from_stream(stream: TcpStream) -> Self {
Self {
stream,
init_sent: false,
}
}
pub async fn send(&mut self, data: &[u8]) -> Result<(), InvocationError> {
if !self.init_sent {
self.stream.write_all(&[0xee, 0xee, 0xee, 0xee]).await?;
self.init_sent = true;
}
let len = (data.len() as u32).to_le_bytes();
self.stream.write_all(&len).await?;
self.stream.write_all(data).await?;
Ok(())
}
pub async fn recv(&mut self) -> Result<Vec<u8>, InvocationError> {
let mut len_buf = [0u8; 4];
self.stream.read_exact(&mut len_buf).await?;
let raw = i32::from_le_bytes(len_buf);
if raw < 0 {
return Err(InvocationError::Io(std::io::Error::new(
std::io::ErrorKind::ConnectionRefused,
format!("transport error: {raw}"),
)));
}
let len = raw as usize;
let mut buf = vec![0u8; len];
self.stream.read_exact(&mut buf).await?;
Ok(buf)
}
pub fn into_inner(self) -> TcpStream {
self.stream
}
}
pub struct PaddedIntermediateTransport {
stream: TcpStream,
init_sent: bool,
}
impl PaddedIntermediateTransport {
pub async fn connect(addr: &str) -> Result<Self, InvocationError> {
let stream = TcpStream::connect(addr).await?;
Ok(Self {
stream,
init_sent: false,
})
}
pub fn from_stream(stream: TcpStream) -> Self {
Self {
stream,
init_sent: false,
}
}
pub async fn send(&mut self, data: &[u8]) -> Result<(), InvocationError> {
if !self.init_sent {
self.stream.write_all(&[0xdd, 0xdd, 0xdd, 0xdd]).await?;
self.init_sent = true;
}
let mut pad_len_buf = [0u8; 1];
getrandom::getrandom(&mut pad_len_buf)
.map_err(|_| InvocationError::Deserialize("getrandom failed".into()))?;
let pad_len = (pad_len_buf[0] & 0x0f) as usize;
let total_len = (data.len() + pad_len) as u32;
self.stream.write_all(&total_len.to_le_bytes()).await?;
self.stream.write_all(data).await?;
if pad_len > 0 {
let mut pad = vec![0u8; pad_len];
getrandom::getrandom(&mut pad)
.map_err(|_| InvocationError::Deserialize("getrandom failed".into()))?;
self.stream.write_all(&pad).await?;
}
Ok(())
}
pub async fn recv(&mut self) -> Result<Vec<u8>, InvocationError> {
let mut len_buf = [0u8; 4];
self.stream.read_exact(&mut len_buf).await?;
let raw = i32::from_le_bytes(len_buf);
if raw < 0 {
return Err(InvocationError::Io(std::io::Error::new(
std::io::ErrorKind::ConnectionRefused,
format!("transport error: {raw}"),
)));
}
let total_len = raw as usize;
let mut buf = vec![0u8; total_len];
self.stream.read_exact(&mut buf).await?;
if buf.len() >= 24 {
let pad = (buf.len() - 24) % 16;
buf.truncate(buf.len() - pad);
}
Ok(buf)
}
pub fn into_inner(self) -> TcpStream {
self.stream
}
}
pub struct FullTransport {
stream: TcpStream,
send_seqno: u32,
recv_seqno: u32,
}
impl FullTransport {
pub async fn connect(addr: &str) -> Result<Self, InvocationError> {
let stream = TcpStream::connect(addr).await?;
Ok(Self {
stream,
send_seqno: 0,
recv_seqno: 0,
})
}
pub fn from_stream(stream: TcpStream) -> Self {
Self {
stream,
send_seqno: 0,
recv_seqno: 0,
}
}
pub async fn send(&mut self, data: &[u8]) -> Result<(), InvocationError> {
let total_len = (data.len() + 12) as u32; let seq = self.send_seqno;
self.send_seqno = self.send_seqno.wrapping_add(1);
let mut packet = Vec::with_capacity(total_len as usize);
packet.extend_from_slice(&total_len.to_le_bytes());
packet.extend_from_slice(&seq.to_le_bytes());
packet.extend_from_slice(data);
let crc = crc32_ieee(&packet);
packet.extend_from_slice(&crc.to_le_bytes());
self.stream.write_all(&packet).await?;
Ok(())
}
pub async fn recv(&mut self) -> Result<Vec<u8>, InvocationError> {
let mut len_buf = [0u8; 4];
self.stream.read_exact(&mut len_buf).await?;
let raw = i32::from_le_bytes(len_buf);
if raw < 0 {
return Err(InvocationError::Io(std::io::Error::new(
std::io::ErrorKind::ConnectionRefused,
format!("transport error: {raw}"),
)));
}
let total_len = raw as usize;
if total_len < 12 {
return Err(InvocationError::Deserialize(
"Full transport: packet too short".into(),
));
}
let mut rest = vec![0u8; total_len - 4];
self.stream.read_exact(&mut rest).await?;
let (body, crc_bytes) = rest.split_at(rest.len() - 4);
let expected_crc = u32::from_le_bytes(crc_bytes.try_into().unwrap());
let mut check_input = len_buf.to_vec();
check_input.extend_from_slice(body);
let actual_crc = crc32_ieee(&check_input);
if actual_crc != expected_crc {
return Err(InvocationError::Deserialize(format!(
"Full transport: CRC mismatch (got {actual_crc:#010x}, expected {expected_crc:#010x})"
)));
}
let recv_seq = u32::from_le_bytes(body[..4].try_into().unwrap());
if recv_seq != self.recv_seqno {
return Err(InvocationError::Deserialize(format!(
"Full transport: seq_no mismatch (got {recv_seq}, expected {})",
self.recv_seqno
)));
}
self.recv_seqno = self.recv_seqno.wrapping_add(1);
Ok(body[4..].to_vec())
}
pub fn into_inner(self) -> TcpStream {
self.stream
}
}
/// Computes the standard CRC-32 of `data`.
///
/// The variant is the reflected polynomial `0xEDB88320`, with initial value
/// and final XOR both `0xFFFFFFFF` (the same checksum as zlib's `crc32`).
///
/// A 256-entry lookup table is built at compile time, so each input byte
/// costs one table lookup instead of eight bit-by-bit iterations.
pub(crate) fn crc32_ieee(data: &[u8]) -> u32 {
    const POLY: u32 = 0xedb8_8320;
    const TABLE: [u32; 256] = {
        let mut table = [0u32; 256];
        let mut i = 0;
        while i < 256 {
            let mut crc = i as u32;
            let mut bit = 0;
            while bit < 8 {
                crc = if crc & 1 != 0 { (crc >> 1) ^ POLY } else { crc >> 1 };
                bit += 1;
            }
            table[i] = crc;
            i += 1;
        }
        table
    };

    !data.iter().fold(!0u32, |crc, &byte| {
        TABLE[((crc as u8) ^ byte) as usize] ^ (crc >> 8)
    })
}