use std::io;
use std::io::IoSlice;
use std::sync::Arc;
use std::time::Duration;
use bytes::BytesMut;
use futures::io::{AsyncReadExt, AsyncWriteExt};
use moonpool_core::{NetworkProvider, Providers, TimeProvider};
use rustls::ClientConnection;
use rustls::pki_types::ServerName;
use crate::EngineError;
use crate::dns::DnsResolver;
use crate::tls::RustlsByteAdapter;
/// Size in bytes (16 KiB) of every per-connection read scratch buffer and the
/// initial capacity of the TLS plaintext overflow buffer.
///
/// NOTE(review): presumably chosen to match the TLS maximum record plaintext
/// size (2^14 bytes) — confirm before changing.
const TLS_WIRE_BUFFER: usize = 16_384;
/// A connection to a broker: either a raw TCP stream or a TCP stream wrapped in
/// a client-side TLS session driven by [`RustlsByteAdapter`].
///
/// Field order is significant: fields drop in declaration order, so the stream
/// is dropped before the TLS adapter and buffers.
pub(crate) enum Transport<P: Providers> {
    /// Unencrypted TCP.
    Plain {
        /// The connected stream from the network provider.
        stream: <P::Network as NetworkProvider>::TcpStream,
        /// Reusable staging buffer for a single socket read.
        read_scratch: Box<[u8]>,
    },
    /// TCP carrying TLS records; application bytes are encrypted/decrypted by
    /// `adapter`.
    Tls {
        /// The connected stream carrying TLS ciphertext.
        stream: <P::Network as NetworkProvider>::TcpStream,
        /// Wrapper around the rustls client session.
        /// NOTE(review): boxed presumably to keep the enum small — confirm.
        adapter: Box<RustlsByteAdapter>,
        /// Decrypted bytes to hand to the caller before reading the socket again.
        /// NOTE(review): nothing in this file appends to it; `read_buf` only drains it.
        plaintext_overflow: BytesMut,
        /// Reusable staging buffer for a single ciphertext read.
        read_scratch: Box<[u8]>,
    },
}
impl<P: Providers> Transport<P> {
    /// Performs one read from `stream` into `scratch` and appends the bytes read
    /// to `buf`.
    ///
    /// Returns the number of bytes appended. With a non-empty `scratch`, `Ok(0)`
    /// means the peer has closed its side of the stream.
    async fn read_into<S: futures::io::AsyncRead + Unpin>(
        stream: &mut S,
        scratch: &mut [u8],
        buf: &mut BytesMut,
    ) -> io::Result<usize> {
        let n = stream.read(scratch).await?;
        buf.extend_from_slice(&scratch[..n]);
        Ok(n)
    }

    /// Dials `addr` over plain TCP and fails if the dial does not finish within
    /// `connect_timeout`. The timeout is measured with the injected `time`
    /// provider.
    ///
    /// # Errors
    ///
    /// Returns [`EngineError::Io`] wrapping either the dial error, or an
    /// [`io::ErrorKind::TimedOut`] error when the timeout elapses first.
    pub(crate) async fn connect(
        network: &P::Network,
        addr: &str,
        time: &P::Time,
        connect_timeout: Duration,
    ) -> Result<Self, EngineError> {
        // Saturate rather than truncate absurdly large timeouts in the log field.
        tracing::debug!(
            addr = %addr,
            tls = false,
            connect_timeout_ms = u64::try_from(connect_timeout.as_millis()).unwrap_or(u64::MAX),
            "dialling broker"
        );
        let connect_fut = network.connect(addr);
        tokio::pin!(connect_fut);
        // `biased` polls the dial before the timer, so a dial that becomes ready
        // in the same poll as the timeout still wins.
        let stream = tokio::select! {
            biased;
            res = &mut connect_fut => res,
            _ = time.sleep(connect_timeout) => Err(io::Error::new(
                io::ErrorKind::TimedOut,
                format!("connect dial to {addr} exceeded connect_timeout ({connect_timeout:?})"),
            )),
        }
        .map_err(EngineError::Io)?;
        Ok(Self::Plain {
            stream,
            read_scratch: new_read_scratch(),
        })
    }

    /// Dials `addr`, resolving it through `resolver` first when one is given.
    ///
    /// Without a resolver this is exactly [`Self::connect`]. With one, `addr` is
    /// split into host and port and resolved. Each resolved address is then
    /// dialled in order, each with its own `connect_timeout`, and the first
    /// success is returned.
    ///
    /// # Errors
    ///
    /// - [`EngineError::Config`] if `addr` is not a valid `host:port` literal or
    ///   the resolver returns no addresses.
    /// - Any error propagated from the resolver.
    /// - The error of the *last* candidate when every dial fails.
    pub(crate) async fn connect_with_resolver(
        network: &P::Network,
        addr: &str,
        resolver: Option<&dyn DnsResolver>,
        time: &P::Time,
        connect_timeout: Duration,
    ) -> Result<Self, EngineError> {
        let Some(resolver) = resolver else {
            return Self::connect(network, addr, time, connect_timeout).await;
        };
        let (host, port) = split_host_port(addr)?;
        let addrs = resolver.resolve(host, port).await?;
        if addrs.is_empty() {
            return Err(EngineError::Config(format!(
                "dns resolver returned no addresses for {host}:{port}"
            )));
        }
        let mut last_err: Option<EngineError> = None;
        for sa in addrs {
            let formatted = sa.to_string();
            match Self::connect(network, &formatted, time, connect_timeout).await {
                Ok(transport) => return Ok(transport),
                Err(e) => last_err = Some(e),
            }
        }
        // `addrs` was non-empty and every iteration either returned or recorded
        // an error, so `last_err` must be set here.
        debug_assert!(
            last_err.is_some(),
            "all-candidates-failed arm reached without recording any connect error",
        );
        Err(last_err.unwrap_or_else(|| {
            EngineError::Io(io::Error::new(
                io::ErrorKind::NotConnected,
                "no resolved candidate could be dialled",
            ))
        }))
    }

    /// Dials `addr` (through `resolver` when given) and performs a TLS client
    /// handshake, using `host` as the server name.
    ///
    /// The server name is validated and the rustls session is created *before*
    /// any network I/O. A TLS misconfiguration is therefore reported without
    /// opening a TCP connection or waiting on `connect_timeout`.
    /// `connect_timeout` bounds only the TCP dial, not the handshake.
    ///
    /// # Errors
    ///
    /// - [`EngineError::Config`] if `host` is not a valid TLS server name.
    /// - [`EngineError::Tls`] if rustls rejects the session or the handshake.
    /// - Any error from [`Self::connect_with_resolver`].
    /// - [`EngineError::Io`] / [`EngineError::PeerClosed`] from handshake I/O.
    pub(crate) async fn connect_tls(
        network: &P::Network,
        addr: &str,
        host: &str,
        tls_config: Arc<rustls::ClientConfig>,
        resolver: Option<&dyn DnsResolver>,
        time: &P::Time,
        connect_timeout: Duration,
    ) -> Result<Self, EngineError> {
        tracing::debug!(
            addr = %addr,
            host = %host,
            tls = true,
            connect_timeout_ms = u64::try_from(connect_timeout.as_millis()).unwrap_or(u64::MAX),
            "dialling broker"
        );
        // Fail fast on configuration problems before touching the network.
        let server_name = ServerName::try_from(host.to_owned()).map_err(|err| {
            EngineError::Config(format!("invalid TLS server name {host:?}: {err}"))
        })?;
        let session = ClientConnection::new(tls_config, server_name).map_err(EngineError::Tls)?;
        let plain =
            Self::connect_with_resolver(network, addr, resolver, time, connect_timeout).await?;
        let stream = match plain {
            Self::Plain { stream, .. } => stream,
            Self::Tls { .. } => unreachable!("connect_with_resolver only yields Plain"),
        };
        let mut transport = Self::Tls {
            stream,
            adapter: Box::new(RustlsByteAdapter::new(session)),
            plaintext_overflow: BytesMut::with_capacity(TLS_WIRE_BUFFER),
            read_scratch: new_read_scratch(),
        };
        transport.tls_handshake().await?;
        Ok(transport)
    }

    /// Drives the TLS handshake to completion over the underlying stream. This
    /// is a no-op for [`Transport::Plain`].
    ///
    /// Each round works as follows:
    /// - Write and flush any handshake bytes the session wants to send.
    /// - Read one chunk from the peer.
    /// - Feed that chunk back into the session.
    ///
    /// After the handshake completes, any remaining outbound bytes are also
    /// written and flushed.
    ///
    /// # Errors
    ///
    /// - [`EngineError::Tls`] for protocol failures.
    /// - [`EngineError::Io`] for socket errors.
    /// - [`EngineError::PeerClosed`] if the peer reaches EOF mid-handshake.
    async fn tls_handshake(&mut self) -> Result<(), EngineError> {
        let Self::Tls {
            stream,
            adapter,
            read_scratch,
            ..
        } = self
        else {
            return Ok(());
        };
        // Initial step with no input, so the session emits its first flight.
        adapter.step().map_err(EngineError::Tls)?;
        while adapter.is_handshaking() {
            let out = adapter.take_encrypted_outbound();
            if !out.is_empty() {
                stream.write_all(&out).await.map_err(EngineError::Io)?;
                stream.flush().await.map_err(EngineError::Io)?;
            }
            if !adapter.is_handshaking() {
                break;
            }
            let n = stream.read(read_scratch).await.map_err(EngineError::Io)?;
            if n == 0 {
                return Err(EngineError::PeerClosed);
            }
            adapter.push_encrypted(&read_scratch[..n]);
            adapter.step().map_err(EngineError::Tls)?;
        }
        // The final step may have queued bytes (e.g. the client's last flight).
        let trailing = adapter.take_encrypted_outbound();
        if !trailing.is_empty() {
            stream.write_all(&trailing).await.map_err(EngineError::Io)?;
            stream.flush().await.map_err(EngineError::Io)?;
        }
        Ok(())
    }

    /// Reads the next available chunk of application bytes and appends it to
    /// `buf`, returning the number of bytes appended.
    ///
    /// The TLS variant returns already-available plaintext before touching the
    /// socket, checked in this order:
    /// 1. The overflow buffer.
    /// 2. Plaintext the adapter already decrypted, e.g. application data carried
    ///    in the same read as the final handshake flight.
    ///
    /// Only after both are empty does it read ciphertext. It keeps reading until
    /// at least one plaintext byte is produced or EOF is reached. `Ok(0)` means
    /// EOF.
    ///
    /// # Errors
    ///
    /// Socket errors are returned unchanged. TLS processing failures are mapped
    /// to [`io::ErrorKind::InvalidData`].
    pub(crate) async fn read_buf(&mut self, buf: &mut bytes::BytesMut) -> io::Result<usize> {
        match self {
            Self::Plain {
                stream,
                read_scratch,
            } => Self::read_into(stream, read_scratch, buf).await,
            Self::Tls {
                stream,
                adapter,
                plaintext_overflow,
                read_scratch,
            } => {
                if !plaintext_overflow.is_empty() {
                    let n = plaintext_overflow.len();
                    buf.extend_from_slice(plaintext_overflow);
                    plaintext_overflow.clear();
                    return Ok(n);
                }
                // Without this drain, plaintext decrypted by an earlier `step()`
                // would stay stranded until the peer sends more bytes. If the
                // peer is waiting on us, that never happens.
                let buffered = adapter.take_plaintext();
                if !buffered.is_empty() {
                    buf.extend_from_slice(&buffered);
                    return Ok(buffered.len());
                }
                loop {
                    let read_n = stream.read(read_scratch).await?;
                    if read_n == 0 {
                        return Ok(0);
                    }
                    adapter.push_encrypted(&read_scratch[..read_n]);
                    adapter
                        .step()
                        .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
                    // A partial TLS record yields no plaintext yet; keep reading.
                    let plaintext = adapter.take_plaintext();
                    if !plaintext.is_empty() {
                        buf.extend_from_slice(&plaintext);
                        return Ok(plaintext.len());
                    }
                }
            }
        }
    }

    /// Writes all of `buf` to the connection; the TLS variant encrypts it first.
    ///
    /// Bytes are handed to the stream but the stream is not flushed; call
    /// [`Self::flush`] for that.
    ///
    /// # Errors
    ///
    /// Socket errors are returned unchanged. TLS processing failures are mapped
    /// to [`io::ErrorKind::InvalidData`].
    pub(crate) async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        match self {
            Self::Plain { stream, .. } => stream.write_all(buf).await,
            Self::Tls {
                stream, adapter, ..
            } => {
                adapter.push_plaintext(buf);
                adapter
                    .step()
                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
                let ciphertext = adapter.take_encrypted_outbound();
                if !ciphertext.is_empty() {
                    stream.write_all(&ciphertext).await?;
                }
                Ok(())
            }
        }
    }

    /// Writes every segment in `segs`, in order, without concatenating them first.
    ///
    /// The plain variant loops on `write_vectored`, tracking how many bytes of
    /// each segment have been written so short (partial) writes are resumed. The
    /// TLS variant encrypts all segments and writes the resulting ciphertext.
    /// Neither variant flushes.
    ///
    /// # Errors
    ///
    /// - [`io::ErrorKind::WriteZero`] if the stream accepts 0 bytes while data
    ///   remains.
    /// - Socket errors are returned unchanged.
    /// - TLS processing failures are mapped to [`io::ErrorKind::InvalidData`].
    pub(crate) async fn write_all_vectored(&mut self, segs: &[bytes::Bytes]) -> io::Result<()> {
        match self {
            Self::Plain { stream, .. } => {
                // offsets[i] = bytes of segs[i] already accepted by the stream.
                let mut offsets: Vec<usize> = vec![0; segs.len()];
                loop {
                    let slices: Vec<IoSlice<'_>> = segs
                        .iter()
                        .zip(offsets.iter())
                        .filter_map(|(seg, &off)| {
                            let rest = &seg[off..];
                            if rest.is_empty() {
                                None
                            } else {
                                Some(IoSlice::new(rest))
                            }
                        })
                        .collect();
                    if slices.is_empty() {
                        return Ok(());
                    }
                    let n = stream.write_vectored(&slices).await?;
                    if n == 0 {
                        return Err(io::Error::new(
                            io::ErrorKind::WriteZero,
                            "write_vectored returned 0 with non-empty IoSlice array",
                        ));
                    }
                    // Distribute the `n` accepted bytes across segments in order.
                    let mut remaining = n;
                    for (seg, off) in segs.iter().zip(offsets.iter_mut()) {
                        let avail = seg.len().saturating_sub(*off);
                        if avail == 0 {
                            continue;
                        }
                        if remaining >= avail {
                            *off = seg.len();
                            remaining -= avail;
                        } else {
                            *off += remaining;
                            remaining = 0;
                            break;
                        }
                    }
                    debug_assert_eq!(remaining, 0, "kernel reported more bytes than queued");
                }
            }
            Self::Tls {
                stream, adapter, ..
            } => {
                for seg in segs {
                    adapter.push_plaintext(seg);
                }
                adapter
                    .step()
                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
                let ciphertext = adapter.take_encrypted_outbound();
                if !ciphertext.is_empty() {
                    stream.write_all(&ciphertext).await?;
                }
                Ok(())
            }
        }
    }

    /// Flushes the connection.
    ///
    /// The TLS variant first steps the session and writes any pending ciphertext,
    /// then flushes the stream.
    ///
    /// # Errors
    ///
    /// Socket errors are returned unchanged. TLS processing failures are mapped
    /// to [`io::ErrorKind::InvalidData`].
    pub(crate) async fn flush(&mut self) -> io::Result<()> {
        match self {
            Self::Plain { stream, .. } => stream.flush().await,
            Self::Tls {
                stream, adapter, ..
            } => {
                adapter
                    .step()
                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
                let pending = adapter.take_encrypted_outbound();
                if !pending.is_empty() {
                    stream.write_all(&pending).await?;
                }
                stream.flush().await
            }
        }
    }

    /// Closes the underlying TCP stream.
    ///
    /// For TLS this closes only the socket; no TLS `close_notify` alert is sent
    /// by this method.
    pub(crate) async fn shutdown(&mut self) -> io::Result<()> {
        match self {
            Self::Plain { stream, .. } | Self::Tls { stream, .. } => stream.close().await,
        }
    }
}
/// Allocates a zero-filled scratch buffer of [`TLS_WIRE_BUFFER`] bytes.
///
/// Both transport variants use it as the staging area for a single socket read.
fn new_read_scratch() -> Box<[u8]> {
    let zeroed: Vec<u8> = vec![0; TLS_WIRE_BUFFER];
    Box::from(zeroed)
}
/// Splits a `host:port` literal into its host part and numeric port.
///
/// The split is made at the *last* `:`, so bracketed IPv6 literals such as
/// `[::1]:6650` work. Exactly one matched pair of surrounding `[`/`]` is removed
/// from the host; a lone `[` or `]` is rejected rather than silently trimmed.
///
/// # Errors
///
/// Returns [`EngineError::Config`] when any of these hold:
/// - `addr` has no `:` separator.
/// - The port is not a valid `u16`.
/// - The host has unbalanced brackets.
/// - The host is empty.
fn split_host_port(addr: &str) -> Result<(&str, u16), EngineError> {
    let (raw_host, raw_port) = addr
        .rsplit_once(':')
        .ok_or_else(|| EngineError::Config(format!("invalid host:port literal {addr:?}")))?;
    // Parse the port first so inputs like `[::1]` (no port) report a port error.
    let port: u16 = raw_port
        .parse()
        .map_err(|e| EngineError::Config(format!("invalid port in {addr:?}: {e}")))?;
    let unbalanced = || EngineError::Config(format!("unbalanced brackets in host of {addr:?}"));
    let host = match raw_host.strip_prefix('[') {
        Some(bracketed) => bracketed.strip_suffix(']').ok_or_else(unbalanced)?,
        None if raw_host.ends_with(']') => return Err(unbalanced()),
        None => raw_host,
    };
    if host.is_empty() {
        return Err(EngineError::Config(format!("empty host in {addr:?}")));
    }
    Ok((host, port))
}
impl<P: Providers> std::fmt::Debug for Transport<P> {
    /// Prints the variant name. TLS transports also report whether the
    /// handshake is still in progress. Streams and buffers are omitted, and the
    /// output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self::Tls { adapter, .. } = self else {
            return f.debug_struct("Transport::Plain").finish_non_exhaustive();
        };
        let handshaking = adapter.is_handshaking();
        f.debug_struct("Transport::Tls")
            .field("is_handshaking", &handshaking)
            .finish_non_exhaustive()
    }
}
// Unit tests for `split_host_port`, plus transport tests that run against the
// moonpool simulated network.
#[cfg(test)]
mod tests {
    use super::split_host_port;

    /// A bare `host:port` splits at the colon.
    #[test]
    fn split_host_port_parses_plain() {
        let (host, port) = split_host_port("broker:6650").expect("parse");
        assert_eq!(host, "broker");
        assert_eq!(port, 6650);
    }

    /// Bracketed IPv6 literals come back without their brackets.
    #[test]
    fn split_host_port_strips_ipv6_brackets() {
        let (host, port) = split_host_port("[::1]:6650").expect("parse");
        assert_eq!(host, "::1");
        assert_eq!(port, 6650);
    }

    /// An address with no `:` separator is rejected.
    #[test]
    fn split_host_port_rejects_missing_port() {
        assert!(split_host_port("broker").is_err());
    }

    /// A non-numeric port is rejected, and the error's Debug output mentions "port".
    #[test]
    fn split_host_port_rejects_non_numeric_port() {
        let err = split_host_port("broker:abc")
            .expect_err("non-numeric port must surface as a config error");
        assert!(
            format!("{err:?}").contains("port"),
            "error message should mention port: {err:?}",
        );
    }

    /// `u16::MAX` is a valid port.
    #[test]
    fn split_host_port_handles_high_port() {
        let (host, port) = split_host_port("broker:65535").expect("parse");
        assert_eq!(host, "broker");
        assert_eq!(port, 65535);
    }

    // Imports used by the simulated-network transport tests below.
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::task::{Context, Poll};
    use bytes::Bytes;
    use futures::io::{AsyncRead, AsyncWriteExt};
    use moonpool_core::{NetworkProvider, TcpListenerTrait};
    use moonpool_sim::providers::SimProviders;
    use moonpool_sim::{NetworkConfiguration, SimWorld};
    use super::Transport;

    /// Polls `server` once with a no-op waker, without blocking.
    ///
    /// Returns `Some(n)` only when `n > 0` bytes were read immediately. Pending,
    /// EOF (`Ok(0)`) and errors all map to `None`.
    fn try_read(server: &mut (impl AsyncRead + Unpin), buf: &mut [u8]) -> Option<usize> {
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        match Pin::new(server).poll_read(&mut cx, buf) {
            Poll::Ready(Ok(n)) if n > 0 => Some(n),
            _ => None,
        }
    }

    /// A vectored write over a plain transport delivers each segment, in order,
    /// as a separately observable chunk on the server side.
    #[test]
    fn write_all_vectored_plain_delivers_segments_in_order() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .expect("build current-thread runtime");
        rt.block_on(async move {
            let mut sim = SimWorld::new_with_network_config(NetworkConfiguration::fast_local());
            let provider = sim.network_provider();
            let addr = "vectored-segments";
            let listener = provider.bind(addr).await.expect("bind");
            let client_stream = provider.connect(addr).await.expect("connect");
            let (mut server, _peer) = listener.accept().await.expect("accept");
            let mut transport: Transport<SimProviders> = Transport::Plain {
                stream: client_stream,
                read_scratch: super::new_read_scratch(),
            };
            let segs = vec![
                Bytes::from_static(b"AAAA"),
                Bytes::from_static(b"BBBBBB"),
                Bytes::from_static(b"CC"),
            ];
            let total: usize = segs.iter().map(Bytes::len).sum();
            transport
                .write_all_vectored(&segs)
                .await
                .expect("vectored write");
            // Step the simulation until it has no pending events. After each
            // step, take whatever the server can read without blocking as one
            // chunk.
            let mut chunks: Vec<Vec<u8>> = Vec::new();
            let mut buf = vec![0u8; 4096];
            while sim.pending_event_count() > 0 {
                sim.step();
                if let Some(n) = try_read(&mut server, &mut buf) {
                    chunks.push(buf[..n].to_vec());
                }
            }
            assert_eq!(
                chunks,
                vec![b"AAAA".to_vec(), b"BBBBBB".to_vec(), b"CC".to_vec()],
                "each IoSlice must surface as its own ordered delivery event",
            );
            let reassembled: Vec<u8> = chunks.concat();
            assert_eq!(reassembled.len(), total);
        });
    }

    /// A large vectored write is delivered completely and in order, which
    /// exercises the partial-write resume loop in `write_all_vectored`.
    ///
    /// NOTE(review): assumes 32 KiB segments exceed what the simulated stream
    /// accepts per `write_vectored` call, forcing partial writes — confirm
    /// against SimWorld's buffer limits.
    #[test]
    fn write_all_vectored_plain_handles_partial_accept() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .expect("build current-thread runtime");
        rt.block_on(async move {
            let mut sim = SimWorld::new_with_network_config(NetworkConfiguration::fast_local());
            let provider = sim.network_provider();
            let addr = "vectored-partial";
            let listener = provider.bind(addr).await.expect("bind");
            let client_stream = provider.connect(addr).await.expect("connect");
            let (mut server, _peer) = listener.accept().await.expect("accept");
            let seg_len = 32 * 1024;
            let segs = vec![
                Bytes::from(vec![1u8; seg_len]),
                Bytes::from(vec![2u8; seg_len]),
                Bytes::from(vec![3u8; seg_len]),
            ];
            let mut expected: Vec<u8> = Vec::with_capacity(seg_len * 3);
            for s in &segs {
                expected.extend_from_slice(s);
            }
            let total = expected.len();
            // The writer runs as a separate task so this loop can interleave
            // simulation steps with draining the server side.
            let done = Arc::new(AtomicBool::new(false));
            let done_writer = done.clone();
            let writer = tokio::spawn(async move {
                transport_write_all_vectored(client_stream, segs).await;
                done_writer.store(true, Ordering::SeqCst);
            });
            let mut received: Vec<u8> = Vec::with_capacity(total);
            let mut buf = vec![0u8; 16 * 1024];
            // Bounded iteration count so a stuck simulation fails the
            // assertions below instead of looping forever.
            for _ in 0..100_000 {
                if done.load(Ordering::SeqCst) && received.len() >= total {
                    break;
                }
                sim.step();
                // Give the spawned writer a chance to run on this current-thread runtime.
                tokio::task::yield_now().await;
                while let Some(n) = try_read(&mut server, &mut buf) {
                    received.extend_from_slice(&buf[..n]);
                }
            }
            writer.await.expect("writer task joined");
            assert_eq!(
                received.len(),
                total,
                "partial-accept loop must flush every byte",
            );
            assert_eq!(
                received, expected,
                "reassembled stream must equal the segment concatenation",
            );
        });
    }

    /// Wraps `stream` in a plain `Transport` and writes `segs` with
    /// `write_all_vectored`, panicking on error. It then closes the stream,
    /// ignoring any close error.
    async fn transport_write_all_vectored(
        stream: <<SimProviders as moonpool_core::Providers>::Network as NetworkProvider>::TcpStream,
        segs: Vec<Bytes>,
    ) {
        let mut transport: Transport<SimProviders> = Transport::Plain {
            stream,
            read_scratch: super::new_read_scratch(),
        };
        transport
            .write_all_vectored(&segs)
            .await
            .expect("vectored write (partial-accept)");
        let _ = AsyncWriteExt::close(&mut match transport {
            Transport::Plain { stream, .. } => stream,
            Transport::Tls { .. } => unreachable!("constructed Plain"),
        })
        .await;
    }
}