use std::collections::VecDeque;
use std::io::{Read as IoRead, Write as IoWrite};
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use noxu_sync::{Condvar, Mutex};
use crate::error::{RepError, Result};
/// Largest frame payload, in bytes, that a receiver accepts (64 MiB).
///
/// [`TcpChannel`] and the TLS channel compare each incoming length header
/// against this limit and reject larger frames before allocating a buffer.
pub const MAX_FRAME_PAYLOAD: usize = 64 << 20;
/// A bidirectional, message-oriented transport.
///
/// Implementations in this module: an in-process pair ([`LocalChannelPair`]),
/// plain TCP ([`TcpChannel`]) and, behind the TLS features, TLS over TCP.
pub trait Channel: Send + Sync {
    /// Sends one message, `data`, to the peer.
    ///
    /// The implementations in this module return `RepError::ChannelClosed`
    /// once this endpoint has been closed.
    fn send(&self, data: &[u8]) -> Result<()>;
    /// Waits up to `timeout` for the next message from the peer.
    ///
    /// Returns `Ok(None)` when nothing arrived in time, and an error when this
    /// endpoint is closed or the peer has gone away.
    fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>>;
    /// Closes this endpoint; subsequent `send`/`receive` calls fail.
    fn close(&self) -> Result<()>;
    /// Returns `false` once `close` has been called on this endpoint.
    fn is_open(&self) -> bool;
}
/// One direction of a [`LocalChannelPair`]: a FIFO of messages behind a
/// mutex, plus a condition variable that wakes receivers when a message is
/// pushed or the writing side closes.
struct ChannelQueue {
    queue: Mutex<VecDeque<Vec<u8>>>,
    condvar: Condvar,
    /// Set once the writing endpoint has closed. Receivers still drain any
    /// queued messages, then report the channel as closed.
    writer_closed: AtomicBool,
}

impl ChannelQueue {
    /// Creates an empty queue whose writer is still open.
    fn new() -> Self {
        Self {
            queue: Mutex::new(VecDeque::new()),
            condvar: Condvar::new(),
            writer_closed: AtomicBool::new(false),
        }
    }

    /// Appends `data` and wakes one waiting receiver.
    fn push(&self, data: Vec<u8>) {
        let mut q = self.queue.lock();
        q.push_back(data);
        self.condvar.notify_one();
    }

    /// Marks the writing side closed and wakes every waiting receiver.
    fn close_writer(&self) {
        // Publish the flag while holding the queue lock. `pop` checks the flag
        // and then waits with the lock held; storing and notifying without the
        // lock could happen between that check and the wait, losing the
        // wake-up and leaving the receiver blocked for its whole timeout.
        let _guard = self.queue.lock();
        self.writer_closed.store(true, Ordering::SeqCst);
        self.condvar.notify_all();
    }

    /// Removes the oldest message, waiting at most once (up to `timeout`) if
    /// the queue is empty.
    ///
    /// Returns `Ok(Some(msg))` when a message is available, `Err(())` when the
    /// queue is empty and the writer has closed, and `Ok(None)` otherwise.
    /// `Ok(None)` can also be returned before `timeout` elapses if the wait is
    /// woken with nothing queued (e.g. by `LocalChannel::close` notifying its
    /// own incoming queue).
    fn pop(
        &self,
        timeout: Duration,
    ) -> std::result::Result<Option<Vec<u8>>, ()> {
        let mut q = self.queue.lock();
        if q.is_empty() {
            if self.writer_closed.load(Ordering::SeqCst) {
                return Err(());
            }
            self.condvar.wait_for(&mut q, timeout);
        }
        if let Some(data) = q.pop_front() {
            Ok(Some(data))
        } else if self.writer_closed.load(Ordering::SeqCst) {
            Err(())
        } else {
            Ok(None)
        }
    }
}
/// One endpoint of a [`LocalChannelPair`].
///
/// Outgoing messages go into `send_queue`. Incoming messages are read from
/// `recv_queue`, which is the peer endpoint's `send_queue`.
pub struct LocalChannel {
    send_queue: Arc<ChannelQueue>,
    recv_queue: Arc<ChannelQueue>,
    /// Cleared by `close`; checked by `send` and `receive`.
    open: AtomicBool,
}

impl LocalChannel {
    /// Builds an open endpoint that writes to `outgoing` and reads from
    /// `incoming`.
    fn new(outgoing: Arc<ChannelQueue>, incoming: Arc<ChannelQueue>) -> Self {
        let open = AtomicBool::new(true);
        LocalChannel {
            send_queue: outgoing,
            recv_queue: incoming,
            open,
        }
    }
}
impl Channel for LocalChannel {
    /// Queues a copy of `data` for the peer; fails once this endpoint is
    /// closed.
    fn send(&self, data: &[u8]) -> Result<()> {
        if self.is_open() {
            self.send_queue.push(data.to_vec());
            Ok(())
        } else {
            Err(RepError::ChannelClosed("channel is closed".into()))
        }
    }

    /// Waits up to `timeout` for the peer's next message.
    ///
    /// Fails if this endpoint is closed, or if the peer has closed and every
    /// message it sent has already been received.
    fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>> {
        if !self.is_open() {
            return Err(RepError::ChannelClosed("channel is closed".into()));
        }
        match self.recv_queue.pop(timeout) {
            Ok(message) => Ok(message),
            Err(()) => Err(RepError::ChannelClosed(
                "peer closed the channel".into(),
            )),
        }
    }

    /// Marks this endpoint closed, tells the peer that no more messages will
    /// arrive, and wakes any local receiver blocked on the incoming queue.
    fn close(&self) -> Result<()> {
        self.open.store(false, Ordering::SeqCst);
        self.send_queue.close_writer();
        self.recv_queue.condvar.notify_all();
        Ok(())
    }

    /// Whether `close` has not yet been called on this endpoint.
    fn is_open(&self) -> bool {
        self.open.load(Ordering::SeqCst)
    }
}
/// Two connected in-process endpoints: whatever `channel_a` sends,
/// `channel_b` receives, and vice versa.
pub struct LocalChannelPair {
    pub channel_a: LocalChannel,
    pub channel_b: LocalChannel,
}

impl LocalChannelPair {
    /// Creates a fresh, open pair backed by two independent queues, one per
    /// direction.
    pub fn new() -> Self {
        let (a_to_b, b_to_a) =
            (Arc::new(ChannelQueue::new()), Arc::new(ChannelQueue::new()));
        Self {
            channel_a: LocalChannel::new(
                Arc::clone(&a_to_b),
                Arc::clone(&b_to_a),
            ),
            channel_b: LocalChannel::new(b_to_a, a_to_b),
        }
    }
}

impl Default for LocalChannelPair {
    /// Same as [`LocalChannelPair::new`].
    fn default() -> Self {
        LocalChannelPair::new()
    }
}
/// A [`Channel`] over a plain TCP connection.
///
/// Each message is framed as a 4-byte little-endian length followed by the
/// payload. The stream sits behind a single mutex, so `send`, `receive` and
/// `close` on the same channel run one at a time.
pub struct TcpChannel {
    stream: Arc<Mutex<TcpStream>>,
    open: AtomicBool,
}

impl TcpChannel {
    /// Wraps an already-connected stream; the channel starts open.
    pub fn new(stream: TcpStream) -> Self {
        let stream = Arc::new(Mutex::new(stream));
        Self {
            stream,
            open: AtomicBool::new(true),
        }
    }

    /// Connects to `addr`, giving up after 30 seconds.
    ///
    /// # Errors
    /// `RepError::NetworkError` if the connection cannot be established.
    pub fn connect(addr: SocketAddr) -> Result<Self> {
        match TcpStream::connect_timeout(&addr, Duration::from_secs(30)) {
            Ok(stream) => Ok(Self::new(stream)),
            Err(e) => Err(RepError::NetworkError(e.to_string())),
        }
    }

    /// Resolves `host:port` and connects to the first address that accepts.
    ///
    /// IPv6 addresses are tried before IPv4 ones, otherwise in resolver order,
    /// each with a 30-second connect timeout.
    ///
    /// # Errors
    /// `RepError::NetworkError` if resolution fails, yields no addresses, or
    /// every attempt fails (the message then carries the last attempt's error).
    pub fn connect_host(host: &str, port: u16) -> Result<Self> {
        let resolved = (host, port).to_socket_addrs().map_err(|e| {
            RepError::NetworkError(format!(
                "DNS resolution failed for {host}:{port}: {e}"
            ))
        })?;
        let mut candidates: Vec<SocketAddr> = resolved.collect();
        if candidates.is_empty() {
            return Err(RepError::NetworkError(format!(
                "no addresses resolved for {host}:{port}"
            )));
        }
        // Stable sort on a bool key: `false` (IPv6) orders before `true` (IPv4).
        candidates.sort_by_key(|candidate| !candidate.is_ipv6());

        let mut last_err = None;
        for addr in &candidates {
            match TcpStream::connect_timeout(addr, Duration::from_secs(30)) {
                Ok(stream) => return Ok(Self::new(stream)),
                Err(e) => last_err = Some(e),
            }
        }
        let last_err = last_err
            .expect("candidates is non-empty, so at least one attempt failed");
        Err(RepError::NetworkError(format!(
            "could not connect to {host}:{port}: {}",
            last_err
        )))
    }

    /// Binds a listener on `port`, preferring the IPv6 wildcard `[::]` and
    /// falling back to IPv4 `0.0.0.0` if that bind fails.
    ///
    /// NOTE(review): whether `[::]` also accepts IPv4 clients depends on the
    /// platform's `IPV6_V6ONLY` default; this function does not set it.
    pub fn bind_dual_stack(port: u16) -> Result<TcpChannelListener> {
        match TcpListener::bind(format!("[::]:{}", port)) {
            Ok(listener) => Ok(TcpChannelListener { listener }),
            Err(_) => {
                let addr: SocketAddr =
                    format!("0.0.0.0:{port}").parse().map_err(|e| {
                        RepError::NetworkError(format!("invalid bind addr: {e}"))
                    })?;
                TcpChannelListener::bind(addr)
            }
        }
    }
}
impl Channel for TcpChannel {
    /// Sends `data` as one length-prefixed frame.
    ///
    /// Header and payload are assembled into a single buffer and written with
    /// one `write_all`. A small message therefore does not leave as a 4-byte
    /// segment followed by the payload, a pattern that interacts badly with
    /// Nagle's algorithm and delayed ACKs.
    ///
    /// # Errors
    /// `ChannelClosed` if the channel is closed; `ProtocolError` if `data`
    /// does not fit the 32-bit length header; `NetworkError` on I/O failure.
    fn send(&self, data: &[u8]) -> Result<()> {
        if !self.is_open() {
            return Err(RepError::ChannelClosed("TcpChannel is closed".into()));
        }
        // `as u32` would silently wrap for payloads >= 4 GiB and desynchronise
        // the peer's framing; refuse such payloads instead.
        let len = u32::try_from(data.len()).map_err(|_| {
            RepError::ProtocolError(format!(
                "frame payload too large: {} > {}",
                data.len(),
                u32::MAX
            ))
        })?;
        let mut frame = Vec::with_capacity(4 + data.len());
        frame.extend_from_slice(&len.to_le_bytes());
        frame.extend_from_slice(data);

        let mut stream = self.stream.lock();
        // Best effort: if the socket is broken, the write below reports it.
        stream.set_write_timeout(Some(Duration::from_secs(30))).ok();
        stream
            .write_all(&frame)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        stream.flush().map_err(|e| RepError::NetworkError(e.to_string()))?;
        Ok(())
    }

    /// Waits up to `timeout` for the next frame.
    ///
    /// Returns `Ok(None)` only when no byte of a new frame arrived within
    /// `timeout` (a zero timeout is treated as 1 ms). Once any header byte has
    /// been read, the rest of the frame is read with a timeout of at least
    /// 30 seconds. This stops a short poll timeout from discarding bytes
    /// already taken off the socket.
    ///
    /// # Errors
    /// `ChannelClosed` if the channel is closed or the peer closed the
    /// connection mid-header; `ProtocolError` if the header exceeds
    /// [`MAX_FRAME_PAYLOAD`]; `NetworkError` for other I/O failures, including
    /// a timeout after a frame has started.
    fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>> {
        if !self.is_open() {
            return Err(RepError::ChannelClosed("TcpChannel is closed".into()));
        }
        // std rejects a zero read timeout, so poll for at least 1 ms instead
        // of failing with a NetworkError.
        let poll_timeout = timeout.max(Duration::from_millis(1));
        let payload_timeout = timeout.max(Duration::from_secs(30));

        let mut stream = self.stream.lock();
        stream
            .set_read_timeout(Some(poll_timeout))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;

        // Read the header with plain `read` so we always know how much has
        // been consumed. `read_exact` leaves the buffer unspecified on error,
        // so a timeout after a partial header would silently drop those bytes
        // and desynchronise the stream.
        let mut len_buf = [0u8; 4];
        let mut filled = 0;
        while filled < len_buf.len() {
            match stream.read(&mut len_buf[filled..]) {
                Ok(0) => {
                    return Err(RepError::ChannelClosed(
                        "connection closed by peer".into(),
                    ));
                }
                Ok(n) => {
                    if filled == 0 && n < len_buf.len() {
                        // A frame has started; wait for the rest of it under
                        // the longer payload timeout.
                        stream.set_read_timeout(Some(payload_timeout)).ok();
                    }
                    filled += n;
                }
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e)
                    if filled == 0
                        && matches!(
                            e.kind(),
                            std::io::ErrorKind::WouldBlock
                                | std::io::ErrorKind::TimedOut
                        ) =>
                {
                    return Ok(None);
                }
                Err(e) => return Err(RepError::NetworkError(e.to_string())),
            }
        }

        let payload_len = u32::from_le_bytes(len_buf) as usize;
        if payload_len > MAX_FRAME_PAYLOAD {
            return Err(RepError::ProtocolError(format!(
                "frame payload too large: {} > {}",
                payload_len, MAX_FRAME_PAYLOAD
            )));
        }
        stream.set_read_timeout(Some(payload_timeout)).ok();
        let mut payload = vec![0u8; payload_len];
        stream
            .read_exact(&mut payload)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        Ok(Some(payload))
    }

    /// Marks the channel closed and shuts down both directions of the socket.
    ///
    /// NOTE(review): this takes the same mutex that `receive` holds while
    /// waiting, so it cannot interrupt an in-progress `receive`; it proceeds
    /// once that call returns.
    fn close(&self) -> Result<()> {
        self.open.store(false, Ordering::SeqCst);
        let stream = self.stream.lock();
        stream
            .shutdown(std::net::Shutdown::Both)
            .map_err(|e| RepError::NetworkError(e.to_string()))
    }

    /// Whether `close` has not yet been called on this channel.
    fn is_open(&self) -> bool {
        self.open.load(Ordering::SeqCst)
    }
}
/// Accepts incoming TCP connections and wraps each one in a [`TcpChannel`].
pub struct TcpChannelListener {
    listener: TcpListener,
}

impl TcpChannelListener {
    /// Binds a listener to `addr`.
    ///
    /// # Errors
    /// `RepError::NetworkError` if the bind fails.
    pub fn bind(addr: SocketAddr) -> Result<Self> {
        let listener = TcpListener::bind(addr)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        Ok(Self { listener })
    }

    /// The address the listener is bound to (e.g. the port chosen for port 0).
    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.listener
            .local_addr()
            .map_err(|e| RepError::NetworkError(e.to_string()))
    }

    /// Waits for the next client and returns it as a [`TcpChannel`].
    ///
    /// # Errors
    /// `RepError::NetworkError` if `accept` fails, including when it times
    /// out after [`set_accept_timeout`](Self::set_accept_timeout).
    pub fn accept(&self) -> Result<TcpChannel> {
        let (stream, _peer) = self
            .listener
            .accept()
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        Ok(TcpChannel::new(stream))
    }

    /// Sets how long [`accept`](Self::accept) may block; `None` blocks forever.
    ///
    /// On Unix this sets `SO_RCVTIMEO` on the listening socket. NOTE(review):
    /// Linux honours this for accept(2); confirm on other Unixes. On non-Unix
    /// platforms the call succeeds but has no effect.
    ///
    /// `Some(d)` with `d` shorter than 1 µs (including zero) is rounded up to
    /// 1 µs. An all-zero `timeval` means "no timeout" to the kernel, so it
    /// would otherwise silently turn a tiny timeout into an infinite wait.
    /// Seconds that overflow `time_t` saturate.
    ///
    /// # Errors
    /// `RepError::NetworkError` if `setsockopt` fails.
    pub fn set_accept_timeout(&self, timeout: Option<Duration>) -> Result<()> {
        #[cfg(unix)]
        {
            use std::os::fd::AsRawFd;
            let fd = self.listener.as_raw_fd();
            let tv = match timeout {
                Some(d) => {
                    let secs = libc::time_t::try_from(d.as_secs())
                        .unwrap_or(libc::time_t::MAX);
                    let mut usecs = d.subsec_micros() as libc::suseconds_t;
                    if secs == 0 && usecs == 0 {
                        usecs = 1;
                    }
                    libc::timeval { tv_sec: secs, tv_usec: usecs }
                }
                None => libc::timeval { tv_sec: 0, tv_usec: 0 },
            };
            // SAFETY: `fd` is an open socket owned by `self.listener`, which
            // outlives this call; `tv` is a live, properly aligned `timeval`,
            // and the length passed is exactly `size_of::<timeval>()`.
            let rc = unsafe {
                libc::setsockopt(
                    fd,
                    libc::SOL_SOCKET,
                    libc::SO_RCVTIMEO,
                    &tv as *const _ as *const libc::c_void,
                    std::mem::size_of::<libc::timeval>() as libc::socklen_t,
                )
            };
            if rc != 0 {
                return Err(RepError::NetworkError(
                    std::io::Error::last_os_error().to_string(),
                ));
            }
        }
        #[cfg(not(unix))]
        {
            let _ = timeout;
        }
        Ok(())
    }
}
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
use crate::tls::TlsConfig;
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
/// Object-safe view of a blocking TLS stream over a `TcpStream`.
///
/// `TlsTcpChannel` stores client or server streams from either TLS backend
/// behind one `Box<dyn TlsStreamOps>`. The `_buf`/`_inner` suffixes keep
/// these names distinct from the `Read`/`Write` methods they forward to.
trait TlsStreamOps: Send + 'static {
    /// Reads exactly `buf.len()` bytes of application data.
    fn read_exact_buf(&mut self, buf: &mut [u8]) -> std::io::Result<()>;
    /// Writes all of `buf` as application data.
    fn write_all_buf(&mut self, buf: &[u8]) -> std::io::Result<()>;
    /// Flushes any buffered output.
    fn flush_buf(&mut self) -> std::io::Result<()>;
    /// Sets the read timeout of the underlying TCP socket.
    fn set_read_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()>;
    /// Sets the write timeout of the underlying TCP socket.
    fn set_write_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()>;
    /// Shuts down both directions of the underlying TCP socket.
    fn shutdown_inner(&self) -> std::io::Result<()>;
}
#[cfg(feature = "tls-rustls")]
/// Server-side rustls stream: data goes through the TLS session, and socket
/// options are applied to the wrapped `TcpStream` (`self.sock`).
impl TlsStreamOps for rustls::StreamOwned<rustls::ServerConnection, TcpStream> {
    fn read_exact_buf(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        <Self as IoRead>::read_exact(self, buf)
    }

    fn write_all_buf(&mut self, buf: &[u8]) -> std::io::Result<()> {
        <Self as IoWrite>::write_all(self, buf)
    }

    fn flush_buf(&mut self) -> std::io::Result<()> {
        <Self as IoWrite>::flush(self)
    }

    fn set_read_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_read_timeout(&self.sock, dur)
    }

    fn set_write_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_write_timeout(&self.sock, dur)
    }

    /// Shuts down the TCP socket only; this impl sends no TLS close_notify.
    fn shutdown_inner(&self) -> std::io::Result<()> {
        TcpStream::shutdown(&self.sock, std::net::Shutdown::Both)
    }
}
#[cfg(feature = "tls-rustls")]
/// Client-side rustls stream: data goes through the TLS session, and socket
/// options are applied to the wrapped `TcpStream` (`self.sock`).
impl TlsStreamOps for rustls::StreamOwned<rustls::ClientConnection, TcpStream> {
    fn read_exact_buf(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        <Self as IoRead>::read_exact(self, buf)
    }

    fn write_all_buf(&mut self, buf: &[u8]) -> std::io::Result<()> {
        <Self as IoWrite>::write_all(self, buf)
    }

    fn flush_buf(&mut self) -> std::io::Result<()> {
        <Self as IoWrite>::flush(self)
    }

    fn set_read_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_read_timeout(&self.sock, dur)
    }

    fn set_write_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_write_timeout(&self.sock, dur)
    }

    /// Shuts down the TCP socket only; this impl sends no TLS close_notify.
    fn shutdown_inner(&self) -> std::io::Result<()> {
        TcpStream::shutdown(&self.sock, std::net::Shutdown::Both)
    }
}
#[cfg(feature = "tls-native")]
/// native-tls stream: data goes through the TLS session, and socket options
/// are applied to the `TcpStream` returned by `get_ref()`.
impl TlsStreamOps for native_tls::TlsStream<TcpStream> {
    fn read_exact_buf(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        <Self as IoRead>::read_exact(self, buf)
    }

    fn write_all_buf(&mut self, buf: &[u8]) -> std::io::Result<()> {
        <Self as IoWrite>::write_all(self, buf)
    }

    fn flush_buf(&mut self) -> std::io::Result<()> {
        <Self as IoWrite>::flush(self)
    }

    fn set_read_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_read_timeout(self.get_ref(), dur)
    }

    fn set_write_timeout_inner(
        &mut self,
        dur: Option<Duration>,
    ) -> std::io::Result<()> {
        TcpStream::set_write_timeout(self.get_ref(), dur)
    }

    /// Shuts down the TCP socket only; this impl sends no TLS close_notify.
    fn shutdown_inner(&self) -> std::io::Result<()> {
        TcpStream::shutdown(self.get_ref(), std::net::Shutdown::Both)
    }
}
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
/// A [`Channel`] over TLS-on-TCP, using the same 4-byte little-endian length
/// framing as [`TcpChannel`].
///
/// The TLS stream (rustls or native-tls, client or server side) is
/// type-erased behind the private `TlsStreamOps` trait and guarded by one
/// mutex.
pub struct TlsTcpChannel {
    stream: Arc<std::sync::Mutex<Box<dyn TlsStreamOps>>>,
    open: AtomicBool,
}

#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
impl TlsTcpChannel {
    /// Wraps an established TLS stream in an open channel.
    fn wrap(stream: Box<dyn TlsStreamOps>) -> Self {
        Self {
            stream: Arc::new(std::sync::Mutex::new(stream)),
            open: AtomicBool::new(true),
        }
    }

    /// Connects to `addr` over TLS. rustls is used when the `tls-rustls`
    /// feature is enabled, native-tls otherwise.
    ///
    /// # Errors
    /// `RepError::NetworkError` if the TCP connect, TLS client setup, or (for
    /// native-tls) the handshake fails. Errors from `TlsConfig`'s conversion
    /// helpers are propagated unchanged.
    pub fn connect_with_tls(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        // Exactly one of these blocks survives cfg-stripping (this impl is
        // only compiled when at least one TLS feature is on) and becomes the
        // function's tail expression.
        #[cfg(feature = "tls-rustls")]
        {
            Self::connect_rustls(addr, tls)
        }
        #[cfg(all(feature = "tls-native", not(feature = "tls-rustls")))]
        {
            Self::connect_native(addr, tls)
        }
    }

    /// Opens a TCP connection to `addr` (30 s connect timeout) and gives the
    /// socket 30 s read and write timeouts.
    ///
    /// Without these, the TLS handshake has no read timeout and can block
    /// forever on an unresponsive server. native-tls runs the handshake in
    /// `connect`; with rustls, nothing here drives it, so it happens during the
    /// first `send`/`receive`. `send` and `receive` later set their own
    /// per-call timeouts.
    fn connect_tcp(addr: SocketAddr) -> Result<TcpStream> {
        let tcp = TcpStream::connect_timeout(&addr, Duration::from_secs(30))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        tcp.set_read_timeout(Some(Duration::from_secs(30)))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        tcp.set_write_timeout(Some(Duration::from_secs(30)))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        Ok(tcp)
    }

    /// rustls client connection to `addr`, validating against
    /// `tls.server_name`.
    #[cfg(feature = "tls-rustls")]
    fn connect_rustls(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        use rustls::pki_types::ServerName;
        let cfg = tls.to_rustls_client_config()?;
        let server_name = ServerName::try_from(tls.server_name.clone())
            .map_err(|e| {
                RepError::NetworkError(format!("invalid server name: {e}"))
            })?;
        let tcp = Self::connect_tcp(addr)?;
        let conn =
            rustls::ClientConnection::new(cfg, server_name).map_err(|e| {
                RepError::NetworkError(format!("TLS client init: {e}"))
            })?;
        let stream = rustls::StreamOwned::new(conn, tcp);
        Ok(Self::wrap(Box::new(stream)))
    }

    /// native-tls client connection to `addr`; the handshake runs here.
    #[cfg(feature = "tls-native")]
    fn connect_native(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        let connector = tls.to_native_connector()?;
        let tcp = Self::connect_tcp(addr)?;
        let stream = connector.connect(&tls.server_name, tcp).map_err(|e| {
            RepError::NetworkError(format!("TLS handshake: {e}"))
        })?;
        Ok(Self::wrap(Box::new(stream)))
    }
}
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
impl Channel for TlsTcpChannel {
    /// Sends `data` as one length-prefixed frame.
    ///
    /// Header and payload are assembled into one buffer and written with a
    /// single call, so they are not split into a separate 4-byte write.
    ///
    /// # Errors
    /// `ChannelClosed` if the channel is closed; `ProtocolError` if `data`
    /// does not fit the 32-bit length header; `NetworkError` on I/O failure
    /// or a poisoned stream lock.
    fn send(&self, data: &[u8]) -> Result<()> {
        if !self.is_open() {
            return Err(RepError::ChannelClosed(
                "TlsTcpChannel is closed".into(),
            ));
        }
        // `as u32` would silently wrap for payloads >= 4 GiB and desynchronise
        // the peer's framing; refuse such payloads instead.
        let len = u32::try_from(data.len()).map_err(|_| {
            RepError::ProtocolError(format!(
                "frame payload too large: {} > {}",
                data.len(),
                u32::MAX
            ))
        })?;
        let mut frame = Vec::with_capacity(4 + data.len());
        frame.extend_from_slice(&len.to_le_bytes());
        frame.extend_from_slice(data);

        let mut s = self.stream.lock().map_err(|_| {
            RepError::NetworkError("TLS stream lock poisoned".into())
        })?;
        s.set_write_timeout_inner(Some(Duration::from_secs(30)))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        s.write_all_buf(&frame)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        s.flush_buf().map_err(|e| RepError::NetworkError(e.to_string()))?;
        Ok(())
    }

    /// Waits up to `timeout` for the next frame.
    ///
    /// Returns `Ok(None)` only when the first header byte did not arrive
    /// within `timeout` (a zero timeout is treated as 1 ms). The remainder of
    /// the frame is then read with a timeout of at least 30 seconds.
    ///
    /// # Errors
    /// `ChannelClosed` if the channel is closed or the peer closed the
    /// connection mid-header; `ProtocolError` if the header exceeds
    /// [`MAX_FRAME_PAYLOAD`]; `NetworkError` for other I/O failures or a
    /// poisoned stream lock.
    fn receive(&self, timeout: Duration) -> Result<Option<Vec<u8>>> {
        if !self.is_open() {
            return Err(RepError::ChannelClosed(
                "TlsTcpChannel is closed".into(),
            ));
        }
        // std rejects a zero socket read timeout, so poll for at least 1 ms.
        let poll_timeout = timeout.max(Duration::from_millis(1));
        let payload_timeout = timeout.max(Duration::from_secs(30));

        let mut s = self.stream.lock().map_err(|_| {
            RepError::NetworkError("TLS stream lock poisoned".into())
        })?;
        s.set_read_timeout_inner(Some(poll_timeout))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;

        let mut len_buf = [0u8; 4];
        // Read only the first header byte under the caller's timeout. A
        // one-byte `read_exact` either fills that byte or consumes nothing, so
        // a timeout here can be reported as "no frame yet" without losing
        // data. A 4-byte `read_exact` could consume part of the header first.
        if let Err(e) = s.read_exact_buf(&mut len_buf[..1]) {
            return match e.kind() {
                std::io::ErrorKind::WouldBlock
                | std::io::ErrorKind::TimedOut => Ok(None),
                std::io::ErrorKind::UnexpectedEof => Err(
                    RepError::ChannelClosed("connection closed by peer".into()),
                ),
                _ => Err(RepError::NetworkError(e.to_string())),
            };
        }

        // A frame has started: read the rest under the longer timeout.
        s.set_read_timeout_inner(Some(payload_timeout))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        s.read_exact_buf(&mut len_buf[1..]).map_err(|e| {
            if e.kind() == std::io::ErrorKind::UnexpectedEof {
                RepError::ChannelClosed("connection closed by peer".into())
            } else {
                RepError::NetworkError(e.to_string())
            }
        })?;

        let payload_len = u32::from_le_bytes(len_buf) as usize;
        if payload_len > MAX_FRAME_PAYLOAD {
            return Err(RepError::ProtocolError(format!(
                "frame payload too large: {} > {}",
                payload_len, MAX_FRAME_PAYLOAD
            )));
        }
        let mut payload = vec![0u8; payload_len];
        s.read_exact_buf(&mut payload)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        Ok(Some(payload))
    }

    /// Marks the channel closed and shuts down the underlying TCP socket.
    ///
    /// NOTE(review): this takes the same mutex that `receive` holds while
    /// waiting, so it proceeds only after any in-progress `receive` returns.
    fn close(&self) -> Result<()> {
        self.open.store(false, Ordering::SeqCst);
        let s = self.stream.lock().map_err(|_| {
            RepError::NetworkError("TLS stream lock poisoned".into())
        })?;
        s.shutdown_inner().map_err(|e| RepError::NetworkError(e.to_string()))
    }

    /// Whether `close` has not yet been called on this channel.
    fn is_open(&self) -> bool {
        self.open.load(Ordering::SeqCst)
    }
}
#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
/// Server-side TLS configuration for whichever backend was selected at bind
/// time.
enum TlsAcceptorImpl {
    #[cfg(feature = "tls-rustls")]
    Rustls(std::sync::Arc<rustls::ServerConfig>),
    #[cfg(feature = "tls-native")]
    Native(native_tls::TlsAcceptor),
}

#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
/// Accepts TCP connections and wraps each in a [`TlsTcpChannel`].
pub struct TlsTcpChannelListener {
    listener: TcpListener,
    acceptor: TlsAcceptorImpl,
}

#[cfg(any(feature = "tls-rustls", feature = "tls-native"))]
impl TlsTcpChannelListener {
    /// Binds to `addr` and prepares the server TLS configuration from `tls`.
    /// rustls is used when the `tls-rustls` feature is enabled, native-tls
    /// otherwise.
    pub fn bind_with_tls(addr: SocketAddr, tls: &TlsConfig) -> Result<Self> {
        let listener = TcpListener::bind(addr)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        #[cfg(feature = "tls-rustls")]
        let acceptor = {
            let cfg = tls.to_rustls_server_config()?;
            TlsAcceptorImpl::Rustls(cfg)
        };
        #[cfg(all(feature = "tls-native", not(feature = "tls-rustls")))]
        let acceptor = {
            let a = tls.to_native_acceptor()?;
            TlsAcceptorImpl::Native(a)
        };
        Ok(Self { listener, acceptor })
    }

    /// Like [`bind_with_tls`](Self::bind_with_tls), but builds the rustls
    /// server config with `allowlist`.
    #[cfg(feature = "tls-rustls")]
    pub fn bind_with_tls_and_allowlist(
        addr: SocketAddr,
        tls: &TlsConfig,
        allowlist: crate::auth::PeerAllowlist,
    ) -> Result<Self> {
        let listener = TcpListener::bind(addr)
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        let cfg = tls.to_rustls_server_config_with_allowlist(allowlist)?;
        Ok(Self { listener, acceptor: TlsAcceptorImpl::Rustls(cfg) })
    }

    /// The address the listener is bound to.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.listener
            .local_addr()
            .map_err(|e| RepError::NetworkError(e.to_string()))
    }

    /// Accepts the next client and wraps it in a TLS server session.
    ///
    /// The accepted socket gets 30 s read and write timeouts before any TLS
    /// work, so a client that connects and then stays silent cannot block
    /// the handshake forever. With native-tls the handshake runs inside this
    /// call; with rustls, nothing here drives it, so it happens on the first
    /// `send`/`receive`.
    ///
    /// # Errors
    /// `RepError::NetworkError` if accepting, configuring the socket, TLS
    /// session setup, or (native-tls) the handshake fails.
    pub fn accept(&self) -> Result<TlsTcpChannel> {
        let (tcp, _peer) = self
            .listener
            .accept()
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        tcp.set_read_timeout(Some(Duration::from_secs(30)))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        tcp.set_write_timeout(Some(Duration::from_secs(30)))
            .map_err(|e| RepError::NetworkError(e.to_string()))?;
        match &self.acceptor {
            #[cfg(feature = "tls-rustls")]
            TlsAcceptorImpl::Rustls(cfg) => {
                let conn = rustls::ServerConnection::new(Arc::clone(cfg))
                    .map_err(|e| {
                        RepError::NetworkError(format!("TLS server init: {e}"))
                    })?;
                let stream = rustls::StreamOwned::new(conn, tcp);
                Ok(TlsTcpChannel::wrap(Box::new(stream)))
            }
            #[cfg(feature = "tls-native")]
            TlsAcceptorImpl::Native(acceptor) => {
                let stream = acceptor.accept(tcp).map_err(|e| {
                    RepError::NetworkError(format!("TLS handshake: {e}"))
                })?;
                Ok(TlsTcpChannel::wrap(Box::new(stream)))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Tests for the in-process, TCP and (with `tls-rustls`) TLS channels.
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_send_receive_basic() {
        let pair = LocalChannelPair::new();
        let msg = b"hello world";
        pair.channel_a.send(msg).unwrap();
        let received = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(received, Some(msg.to_vec()));
    }

    #[test]
    fn test_bidirectional() {
        let pair = LocalChannelPair::new();
        pair.channel_a.send(b"from a").unwrap();
        pair.channel_b.send(b"from b").unwrap();
        let recv_b = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(recv_b, Some(b"from a".to_vec()));
        let recv_a = pair.channel_a.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(recv_a, Some(b"from b".to_vec()));
    }

    #[test]
    fn test_multiple_messages_fifo() {
        let pair = LocalChannelPair::new();
        pair.channel_a.send(b"first").unwrap();
        pair.channel_a.send(b"second").unwrap();
        pair.channel_a.send(b"third").unwrap();
        assert_eq!(
            pair.channel_b.receive(Duration::from_secs(1)).unwrap(),
            Some(b"first".to_vec())
        );
        assert_eq!(
            pair.channel_b.receive(Duration::from_secs(1)).unwrap(),
            Some(b"second".to_vec())
        );
        assert_eq!(
            pair.channel_b.receive(Duration::from_secs(1)).unwrap(),
            Some(b"third".to_vec())
        );
    }

    #[test]
    fn test_receive_timeout_empty_queue() {
        let pair = LocalChannelPair::new();
        let result = pair.channel_b.receive(Duration::from_millis(50)).unwrap();
        assert_eq!(result, None);
    }

    #[test]
    fn test_send_after_close_fails() {
        let pair = LocalChannelPair::new();
        pair.channel_a.close().unwrap();
        let result = pair.channel_a.send(b"should fail");
        assert!(result.is_err());
    }

    #[test]
    fn test_receive_after_close_fails() {
        let pair = LocalChannelPair::new();
        pair.channel_b.close().unwrap();
        let result = pair.channel_b.receive(Duration::from_millis(10));
        assert!(result.is_err());
    }

    #[test]
    fn test_is_open() {
        let pair = LocalChannelPair::new();
        assert!(pair.channel_a.is_open());
        assert!(pair.channel_b.is_open());
        pair.channel_a.close().unwrap();
        assert!(!pair.channel_a.is_open());
        assert!(pair.channel_b.is_open());
    }

    #[test]
    fn test_empty_message() {
        let pair = LocalChannelPair::new();
        pair.channel_a.send(b"").unwrap();
        let received = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(received, Some(vec![]));
    }

    #[test]
    fn test_large_message() {
        let pair = LocalChannelPair::new();
        let large = vec![0xABu8; 1024 * 1024];
        pair.channel_a.send(&large).unwrap();
        let received = pair.channel_b.receive(Duration::from_secs(1)).unwrap();
        assert_eq!(received, Some(large));
    }

    #[test]
    fn test_concurrent_send_receive() {
        let pair = LocalChannelPair::new();
        std::thread::scope(|s| {
            let a = &pair.channel_a;
            let b = &pair.channel_b;
            let handle = s.spawn(|| {
                let msg = b.receive(Duration::from_secs(5)).unwrap();
                assert_eq!(msg, Some(b"concurrent".to_vec()));
                b.send(b"ack").unwrap();
            });
            a.send(b"concurrent").unwrap();
            let ack = a.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(ack, Some(b"ack".to_vec()));
            handle.join().unwrap();
        });
    }

    #[test]
    fn test_default_trait() {
        let pair = LocalChannelPair::default();
        assert!(pair.channel_a.is_open());
        assert!(pair.channel_b.is_open());
    }

    #[test]
    fn test_tcp_channel_send_receive() {
        use std::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            let ch = TcpChannel::new(stream);
            let msg = ch.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(msg, Some(b"hello tcp".to_vec()));
            ch.send(b"world").unwrap();
        });
        let client = TcpChannel::connect(addr).unwrap();
        client.send(b"hello tcp").unwrap();
        let reply = client.receive(Duration::from_secs(5)).unwrap();
        assert_eq!(reply, Some(b"world".to_vec()));
        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_multiple_messages() {
        use std::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            let ch = TcpChannel::new(stream);
            for i in 0u8..5 {
                let msg = ch.receive(Duration::from_secs(5)).unwrap().unwrap();
                assert_eq!(msg, vec![i]);
            }
        });
        let client = TcpChannel::connect(addr).unwrap();
        for i in 0u8..5 {
            client.send(&[i]).unwrap();
        }
        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_receive_timeout() {
        use std::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = std::thread::spawn(move || {
            let (_stream, _) = listener.accept().unwrap();
            std::thread::sleep(Duration::from_secs(2));
        });
        let client = TcpChannel::connect(addr).unwrap();
        let result = client.receive(Duration::from_millis(100)).unwrap();
        assert_eq!(result, None, "expected timeout → None");
        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_is_open_and_close() {
        use std::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = std::thread::spawn(move || {
            let (_stream, _) = listener.accept().unwrap();
            std::thread::sleep(Duration::from_millis(200));
        });
        let client = TcpChannel::connect(addr).unwrap();
        assert!(client.is_open());
        client.close().unwrap();
        assert!(!client.is_open());
        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_large_payload() {
        use std::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let payload: Vec<u8> = (0..65536).map(|i| (i % 256) as u8).collect();
        let expected = payload.clone();
        let handle = std::thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            let ch = TcpChannel::new(stream);
            let msg = ch.receive(Duration::from_secs(5)).unwrap().unwrap();
            assert_eq!(msg, expected);
        });
        let client = TcpChannel::connect(addr).unwrap();
        client.send(&payload).unwrap();
        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_listener_bind_and_accept() {
        let listener =
            TcpChannelListener::bind("127.0.0.1:0".parse().unwrap()).unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = std::thread::spawn(move || {
            let ch = listener.accept().unwrap();
            let msg = ch.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(msg, Some(b"ping".to_vec()));
        });
        let client = TcpChannel::connect(addr).unwrap();
        client.send(b"ping").unwrap();
        handle.join().unwrap();
    }

    #[test]
    fn test_tcp_channel_rejects_oversize_frame() {
        use std::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let handle = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            // Write only a length header that exceeds the limit.
            let oversized = (MAX_FRAME_PAYLOAD as u32).saturating_add(1);
            stream.write_all(&oversized.to_le_bytes()).unwrap();
            std::thread::sleep(Duration::from_millis(200));
        });
        let client = TcpChannel::connect(addr).unwrap();
        let err = client
            .receive(Duration::from_secs(5))
            .expect_err("oversize frame must be rejected");
        match err {
            RepError::ProtocolError(msg) => {
                assert!(
                    msg.contains("frame payload too large"),
                    "unexpected protocol-error message: {}",
                    msg
                );
            }
            other => panic!("expected ProtocolError, got {:?}", other),
        }
        handle.join().unwrap();
    }

    #[cfg(feature = "tls-rustls")]
    mod tls_tests {
        use super::*;
        use crate::tls::TlsConfig;

        #[test]
        fn test_tls_tcp_send_receive() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();
            let handle = std::thread::spawn(move || {
                let ch = listener.accept().unwrap();
                let msg = ch.receive(Duration::from_secs(5)).unwrap();
                assert_eq!(msg, Some(b"hello tls".to_vec()));
                ch.send(b"world tls").unwrap();
            });
            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            client.send(b"hello tls").unwrap();
            let reply = client.receive(Duration::from_secs(5)).unwrap();
            assert_eq!(reply, Some(b"world tls".to_vec()));
            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_multiple_messages() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();
            let handle = std::thread::spawn(move || {
                let ch = listener.accept().unwrap();
                for i in 0u8..4 {
                    let msg =
                        ch.receive(Duration::from_secs(5)).unwrap().unwrap();
                    assert_eq!(msg, vec![i]);
                }
            });
            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            for i in 0u8..4 {
                client.send(&[i]).unwrap();
            }
            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_large_payload() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();
            let payload: Vec<u8> =
                (0..65536).map(|i| (i % 256) as u8).collect();
            let expected = payload.clone();
            let handle = std::thread::spawn(move || {
                let ch = listener.accept().unwrap();
                let msg = ch.receive(Duration::from_secs(5)).unwrap().unwrap();
                assert_eq!(msg, expected);
            });
            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            client.send(&payload).unwrap();
            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_receive_timeout() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();
            let handle = std::thread::spawn(move || {
                let _ch = listener.accept().unwrap();
                std::thread::sleep(Duration::from_secs(2));
            });
            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            let result = client.receive(Duration::from_millis(500)).unwrap();
            assert_eq!(result, None, "expected timeout → None");
            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_close() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();
            let handle = std::thread::spawn(move || {
                let _ch = listener.accept().unwrap();
                std::thread::sleep(Duration::from_millis(200));
            });
            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            assert!(client.is_open());
            client.close().unwrap();
            assert!(!client.is_open());
            handle.join().unwrap();
        }

        #[test]
        fn test_tls_tcp_rejects_oversize_frame() {
            let tls = TlsConfig::insecure("localhost");
            let listener = TlsTcpChannelListener::bind_with_tls(
                "127.0.0.1:0".parse().unwrap(),
                &tls,
            )
            .unwrap();
            let addr = listener.local_addr().unwrap();
            let handle = std::thread::spawn(move || {
                let ch = listener.accept().unwrap();
                let oversized = vec![0u8; MAX_FRAME_PAYLOAD + 1];
                // The client aborts mid-frame, so this send may fail.
                let _ = ch.send(&oversized);
            });
            let client = TlsTcpChannel::connect_with_tls(addr, &tls).unwrap();
            let result = client.receive(Duration::from_secs(10));
            let _ = client.close();
            let err = result.expect_err("oversize TLS frame must be rejected");
            match err {
                RepError::ProtocolError(msg) => {
                    assert!(
                        msg.contains("frame payload too large"),
                        "unexpected protocol-error message: {}",
                        msg
                    );
                }
                other => panic!("expected ProtocolError, got {:?}", other),
            }
            let _ = handle.join();
        }
    }
}