#![allow(unsafe_code)]
use std::net::SocketAddr;
use std::sync::Mutex;
use bytes::{Bytes, BytesMut};
use wasi::io::poll;
use wasi::sockets::instance_network::instance_network;
use wasi::sockets::network::{
IpAddressFamily, IpSocketAddress, Ipv4SocketAddress, Ipv6SocketAddress,
};
use wasi::sockets::tcp::{InputStream, OutputStream, TcpSocket};
use wasi::sockets::tcp_create_socket::create_tcp_socket;
use crate::errors::CoreError;
use crate::transport::session_transport::SessionTransport;
const MAX_FRAME_BYTES: usize = 4 * 1024 * 1024;
const RECV_BUF_INITIAL_CAPACITY: usize = 64 * 1024;
const RECV_CHUNK: usize = 64 * 1024;
const SHRINK_SLACK_MULT: usize = 4;
/// One TCP connection ("leg") of a session transport, backed by a WASI `tcp-socket`.
///
/// Frames on the wire are a 4-byte big-endian length followed by the payload.
pub struct WasiLeg {
    // Write half of the connection. The lock is held for an entire frame, so a length
    // prefix and its payload from one sender cannot interleave with another sender's.
    output: Mutex<OutputStream>,
    // Read half plus a reusable receive buffer. They are locked together, so one receiver
    // consumes a whole frame before another can start.
    read: Mutex<(InputStream, BytesMut)>,
    // Keeps the socket alive while the streams are in use. Declared LAST on purpose: Rust
    // drops fields in declaration order, and in WASI the streams returned by
    // `finish-connect` are child resources of the socket, so they must be dropped first.
    // Do not reorder these fields.
    _socket: TcpSocket,
}
// SAFETY: NOTE(review): these impls presumably exist because the WASI binding handle types
// are not Send/Sync themselves. Soundness relies on the component running single-threaded
// (no wasm threads), so a `WasiLeg` is never actually touched from two threads at once.
// Re-evaluate before enabling any threads proposal on this target.
unsafe impl Send for WasiLeg {}
// SAFETY: see the Send impl directly above; all interior access also goes through `Mutex`.
unsafe impl Sync for WasiLeg {}
impl WasiLeg {
pub fn connect(remote: SocketAddr) -> Result<Self, CoreError> {
let (family, addr) = ip_socket_address_from_std(remote);
let network = instance_network();
let socket = create_tcp_socket(family)
.map_err(|e| CoreError::NetworkError(format!("create_tcp_socket: {:?}", e)))?;
socket
.start_connect(&network, addr)
.map_err(|e| CoreError::NetworkError(format!("start_connect: {:?}", e)))?;
let pollable = socket.subscribe();
let _ready = poll::poll(&[&pollable]);
let (input, output) = socket
.finish_connect()
.map_err(|e| CoreError::NetworkError(format!("finish_connect: {:?}", e)))?;
Ok(Self {
output: Mutex::new(output),
read: Mutex::new((input, BytesMut::with_capacity(RECV_BUF_INITIAL_CAPACITY))),
_socket: socket,
})
}
}
impl SessionTransport for WasiLeg {
    /// Sends `data` as one frame: a 4-byte big-endian length prefix, then the payload.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::NetworkError`] if `data` is larger than `MAX_FRAME_BYTES` or if
    /// writing to the stream fails.
    async fn send_bytes(&self, data: &[u8]) -> Result<(), CoreError> {
        // wasi:io/streams `blocking-write-and-flush` accepts at most 4096 bytes per call.
        // wasmtime traps on larger buffers, so the payload is written in pieces of this size.
        const MAX_BLOCKING_WRITE: usize = 4096;

        if data.len() > MAX_FRAME_BYTES {
            return Err(CoreError::NetworkError(format!(
                "frame too large: {} > {}",
                data.len(),
                MAX_FRAME_BYTES
            )));
        }
        // Poisoning only happens after a panic while the lock was held, which is a bug.
        #[allow(clippy::expect_used)]
        let out = self.output.lock().expect("WasiLeg output mutex poisoned");
        // Cannot truncate: the guard above caps the length at MAX_FRAME_BYTES (< u32::MAX).
        let len = (data.len() as u32).to_be_bytes();
        out.blocking_write_and_flush(&len)
            .map_err(|e| CoreError::NetworkError(format!("write length: {:?}", e)))?;
        for piece in data.chunks(MAX_BLOCKING_WRITE) {
            out.blocking_write_and_flush(piece)
                .map_err(|e| CoreError::NetworkError(format!("write payload: {:?}", e)))?;
        }
        Ok(())
    }

    /// Receives one length-prefixed frame and returns its payload.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::NetworkError`] if the peer announces a frame larger than
    /// `MAX_FRAME_BYTES`, if the stream reports an error, or if it ends mid-frame.
    async fn recv_bytes(&self) -> Result<Bytes, CoreError> {
        // Poisoning only happens after a panic while the lock was held, which is a bug.
        #[allow(clippy::expect_used)]
        let mut guard = self.read.lock().expect("WasiLeg read mutex poisoned");
        let (input, accum) = &mut *guard;

        let mut len_buf = [0u8; 4];
        read_exact(input, &mut len_buf)?;
        let len = u32::from_be_bytes(len_buf) as usize;
        if len > MAX_FRAME_BYTES {
            return Err(CoreError::NetworkError(format!(
                "oversized frame from peer: {} > {}",
                len, MAX_FRAME_BYTES
            )));
        }

        // Grow the buffer at most RECV_CHUNK bytes at a time, as data actually arrives.
        // A peer that announces a large frame and then stalls therefore cannot force the
        // full size to be allocated up front.
        accum.clear();
        let mut filled = 0usize;
        while filled < len {
            let chunk = (len - filled).min(RECV_CHUNK);
            accum.resize(filled + chunk, 0);
            read_exact(input, &mut accum[filled..filled + chunk])?;
            filled += chunk;
        }
        let frame = accum.split_to(len).freeze();

        // After an unusually large frame, switch to a fresh default-sized buffer. The large
        // allocation is then owned only by `frame` and is freed when the caller drops it.
        if len > RECV_BUF_INITIAL_CAPACITY * SHRINK_SLACK_MULT {
            *accum = BytesMut::with_capacity(RECV_BUF_INITIAL_CAPACITY);
        }
        Ok(frame)
    }
}
fn read_exact(input: &InputStream, dest: &mut [u8]) -> Result<(), CoreError> {
let mut filled = 0;
while filled < dest.len() {
let want = (dest.len() - filled) as u64;
let chunk = input
.blocking_read(want)
.map_err(|e| CoreError::NetworkError(format!("blocking_read: {:?}", e)))?;
if chunk.is_empty() {
return Err(CoreError::NetworkError(
"peer closed the WASI TCP stream (EOF)".into(),
));
}
let take = chunk.len().min(dest.len() - filled);
dest[filled..filled + take].copy_from_slice(&chunk[..take]);
filled += take;
}
Ok(())
}
/// Converts a standard-library socket address into the WASI address family and socket
/// address passed to `create_tcp_socket` and `start_connect`.
///
/// IPv4 octets and IPv6 16-bit segments are copied in order. The port and the IPv6 flow
/// info and scope id are carried over unchanged.
fn ip_socket_address_from_std(addr: SocketAddr) -> (IpAddressFamily, IpSocketAddress) {
    match addr {
        SocketAddr::V4(v4) => {
            let [a, b, c, d] = v4.ip().octets();
            let wasi_v4 = Ipv4SocketAddress {
                port: v4.port(),
                address: (a, b, c, d),
            };
            (IpAddressFamily::Ipv4, IpSocketAddress::Ipv4(wasi_v4))
        }
        SocketAddr::V6(v6) => {
            let [s0, s1, s2, s3, s4, s5, s6, s7] = v6.ip().segments();
            let wasi_v6 = Ipv6SocketAddress {
                port: v6.port(),
                flow_info: v6.flowinfo(),
                address: (s0, s1, s2, s3, s4, s5, s6, s7),
                scope_id: v6.scope_id(),
            };
            (IpAddressFamily::Ipv6, IpSocketAddress::Ipv6(wasi_v6))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv6Addr, SocketAddrV6};

    #[test]
    fn ipv4_addr_conversion_round_trips() {
        let addr: SocketAddr = "127.0.0.1:4242".parse().unwrap();
        let (family, ip) = ip_socket_address_from_std(addr);
        assert!(matches!(family, IpAddressFamily::Ipv4));
        let IpSocketAddress::Ipv4(v4) = ip else {
            panic!("expected ipv4 variant");
        };
        assert_eq!(v4.port, 4242);
        assert_eq!(v4.address, (127, 0, 0, 1));
    }

    /// Uses four distinct octets so that a swapped or reversed copy would be detected.
    #[test]
    fn ipv4_octets_keep_their_order() {
        let addr: SocketAddr = "192.168.1.20:80".parse().unwrap();
        let (family, ip) = ip_socket_address_from_std(addr);
        assert!(matches!(family, IpAddressFamily::Ipv4));
        let IpSocketAddress::Ipv4(v4) = ip else {
            panic!("expected ipv4 variant");
        };
        assert_eq!(v4.port, 80);
        assert_eq!(v4.address, (192, 168, 1, 20));
    }

    #[test]
    fn ipv6_addr_conversion_round_trips() {
        let addr: SocketAddr = "[::1]:4242".parse().unwrap();
        let (family, ip) = ip_socket_address_from_std(addr);
        assert!(matches!(family, IpAddressFamily::Ipv6));
        let IpSocketAddress::Ipv6(v6) = ip else {
            panic!("expected ipv6 variant");
        };
        assert_eq!(v6.port, 4242);
        assert_eq!(v6.address, (0, 0, 0, 0, 0, 0, 0, 1));
    }

    /// Flow info and scope id (needed for link-local addresses) must survive the conversion.
    #[test]
    fn ipv6_flow_info_and_scope_id_are_preserved() {
        let addr = SocketAddr::V6(SocketAddrV6::new(
            Ipv6Addr::new(0xfe80, 0, 0, 0, 1, 2, 3, 4),
            443,
            7,
            3,
        ));
        let (family, ip) = ip_socket_address_from_std(addr);
        assert!(matches!(family, IpAddressFamily::Ipv6));
        let IpSocketAddress::Ipv6(v6) = ip else {
            panic!("expected ipv6 variant");
        };
        assert_eq!(v6.port, 443);
        assert_eq!(v6.flow_info, 7);
        assert_eq!(v6.scope_id, 3);
        assert_eq!(v6.address, (0xfe80, 0, 0, 0, 1, 2, 3, 4));
    }
}