use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
};
use std::time::Duration;
use rs_netty::{
codec::{LineCodec, Utf8DatagramCodec},
datagram_pipeline, pipeline, CloseReason, ConnInfo, ConnectionStats, Context, DatagramContext,
DatagramHandler, Error, Life, Result, TcpClient, TcpServer, UdpClient, UdpServer,
};
#[cfg(feature = "tls")]
use rs_netty::{
Business, BusinessContext, ClientTlsContext, Flow, Inbound, InboundContext, Outbound,
OutboundContext, ServerTlsContext, TlsContextBuilder, TlsInfo,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpStream, UdpSocket},
sync::mpsc,
};
/// Starts a TCP server on an ephemeral port with a `CountLife` lifecycle and
/// checks that `started` reads 1 after start and `stopped` reads 1 after
/// `shutdown` followed by `wait`.
#[tokio::test]
async fn tcp_server_shutdown_stops_server() -> Result<()> {
    let counters = CountLife::default();
    let server = TcpServer::bind("127.0.0.1:0")
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .life(counters.clone())
        .start()
        .await?;

    // Binding to port 0 must have been resolved to a concrete port.
    let bound_port = server.local_addr().port();
    assert_ne!(bound_port, 0);
    let start_count = counters.started.load(Ordering::SeqCst);
    assert_eq!(start_count, 1);

    server.shutdown();
    server.wait().await?;

    let stop_count = counters.stopped.load(Ordering::SeqCst);
    assert_eq!(stop_count, 1);
    Ok(())
}
/// Starts a UDP server on an ephemeral port with a `CountLife` lifecycle and
/// checks that `started` reads 1 after start and `stopped` reads 1 after
/// `shutdown` followed by `wait`.
#[tokio::test]
async fn udp_server_shutdown_stops_socket() -> Result<()> {
    let counters = CountLife::default();
    let server = UdpServer::bind("127.0.0.1:0")
        .pipeline(|| datagram_pipeline().codec(Utf8DatagramCodec).handler(UdpEcho))
        .life(counters.clone())
        .start()
        .await?;

    // Binding to port 0 must have been resolved to a concrete port.
    let bound_port = server.local_addr().port();
    assert_ne!(bound_port, 0);

    // NOTE(review): the yield presumably gives the server task a chance to run
    // its start hook before the counter is read — confirm against UdpServer.
    tokio::task::yield_now().await;
    let start_count = counters.started.load(Ordering::SeqCst);
    assert_eq!(start_count, 1);

    server.shutdown();
    server.wait().await?;

    let stop_count = counters.stopped.load(Ordering::SeqCst);
    assert_eq!(stop_count, 1);
    Ok(())
}
/// Opens a TCP connection to a server configured with a 20 ms idle timeout,
/// leaves it silent for 80 ms, and expects the `ReasonLife` lifecycle to have
/// seen `CloseReason::IdleTimeout` once the server has stopped.
#[tokio::test]
async fn tcp_idle_timeout_closes_idle_connection() -> Result<()> {
    let reasons = ReasonLife::default();
    let idle_limit = Duration::from_millis(20);
    let server = TcpServer::bind("127.0.0.1:0")
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .idle_timeout(idle_limit)
        .life(reasons.clone())
        .start()
        .await?;

    // Keep the socket alive (but never write to it) for the whole wait.
    let _idle_connection = TcpStream::connect(server.local_addr()).await?;
    let quiet_period = Duration::from_millis(80);
    tokio::time::sleep(quiet_period).await;

    server.shutdown();
    server.wait().await?;

    assert!(reasons.contains(CloseReason::IdleTimeout));
    Ok(())
}
/// Enables `track_connection_stats`, sends one line through a `StatsEcho`
/// pipeline, and checks the byte and frame counters of the stats value left in
/// the shared `seen` slot.
#[tokio::test]
async fn tcp_connection_stats_are_opt_in() -> Result<()> {
    // Slot shared with the handler; it starts out empty.
    let seen = Arc::new(Mutex::new(None));
    let seen_stats = seen.clone();
    let server = TcpServer::bind("127.0.0.1:0")
        .pipeline(move || {
            pipeline().codec(LineCodec::new()).handler(StatsEcho {
                seen: seen_stats.clone(),
            })
        })
        .track_connection_stats()
        .start()
        .await?;

    let mut stream = TcpStream::connect(server.local_addr()).await?;
    stream.write_all(b"hello\n").await?;
    // Wait for a 6-byte reply before tearing the connection down.
    let mut response = vec![0; 6];
    stream.read_exact(&mut response).await?;
    drop(stream);

    server.shutdown();
    server.wait().await?;

    // Distinct panic messages: a failure now says whether the lock was
    // poisoned or whether nothing was ever stored in the slot.
    let stats = seen
        .lock()
        .expect("stats mutex poisoned")
        .clone()
        .expect("no connection stats were recorded in `seen`");
    assert!(stats.bytes_read() >= 6);
    assert!(stats.bytes_written() >= 6);
    assert_eq!(stats.frames_read(), 1);
    assert_eq!(stats.frames_written(), 1);
    Ok(())
}
#[tokio::test]
async fn tcp_context_write_and_flush_flushes_before_handler_returns() -> Result<()> {
let server = TcpServer::bind("127.0.0.1:0")
.pipeline(|| pipeline().codec(LineCodec::new()).handler(FlushTwice))
.start()
.await?;
let mut stream = TcpStream::connect(server.local_addr()).await?;
stream.write_all(b"go\n").await?;
let mut first = vec![0; b"first\n".len()];
tokio::time::timeout(Duration::from_millis(50), stream.read_exact(&mut first))
.await
.map_err(|err| Error::Pipeline(err.to_string()))??;
assert_eq!(first, b"first\n");
let mut second = vec![0; b"second\n".len()];
tokio::time::timeout(Duration::from_millis(200), stream.read_exact(&mut second))
.await
.map_err(|err| Error::Pipeline(err.to_string()))??;
assert_eq!(second, b"second\n");
drop(stream);
server.shutdown();
server.wait().await
}
#[tokio::test]
async fn tcp_context_write_and_flush_fire_and_forget_flushes_before_handler_returns() -> Result<()>
{
let server = TcpServer::bind("127.0.0.1:0")
.pipeline(|| {
pipeline()
.codec(LineCodec::new())
.handler(FireAndForgetFlushTwice)
})
.start()
.await?;
let mut stream = TcpStream::connect(server.local_addr()).await?;
stream.write_all(b"go\n").await?;
let mut first = vec![0; b"first\n".len()];
tokio::time::timeout(Duration::from_millis(50), stream.read_exact(&mut first))
.await
.map_err(|err| Error::Pipeline(err.to_string()))??;
assert_eq!(first, b"first\n");
let mut second = vec![0; b"second\n".len()];
tokio::time::timeout(Duration::from_millis(200), stream.read_exact(&mut second))
.await
.map_err(|err| Error::Pipeline(err.to_string()))??;
assert_eq!(second, b"second\n");
drop(stream);
server.shutdown();
server.wait().await
}
#[tokio::test]
async fn udp_context_write_and_flush_sends_before_handler_returns() -> Result<()> {
let server = UdpServer::bind("127.0.0.1:0")
.pipeline(|| {
datagram_pipeline()
.codec(Utf8DatagramCodec)
.handler(UdpFlushTwice)
})
.start()
.await?;
let socket = UdpSocket::bind("127.0.0.1:0").await?;
socket.send_to(b"go", server.local_addr()).await?;
let mut first = vec![0; b"first".len()];
let (first_len, _) =
tokio::time::timeout(Duration::from_millis(50), socket.recv_from(&mut first))
.await
.map_err(|err| Error::Pipeline(err.to_string()))??;
assert_eq!(&first[..first_len], b"first");
let mut second = vec![0; b"second".len()];
let (second_len, _) =
tokio::time::timeout(Duration::from_millis(200), socket.recv_from(&mut second))
.await
.map_err(|err| Error::Pipeline(err.to_string()))??;
assert_eq!(&second[..second_len], b"second");
server.shutdown();
server.wait().await
}
#[tokio::test]
async fn udp_context_write_and_flush_fire_and_forget_sends_before_handler_returns() -> Result<()> {
let server = UdpServer::bind("127.0.0.1:0")
.pipeline(|| {
datagram_pipeline()
.codec(Utf8DatagramCodec)
.handler(UdpFireAndForgetFlushTwice)
})
.start()
.await?;
let socket = UdpSocket::bind("127.0.0.1:0").await?;
socket.send_to(b"go", server.local_addr()).await?;
let mut first = vec![0; b"first".len()];
let (first_len, _) =
tokio::time::timeout(Duration::from_millis(50), socket.recv_from(&mut first))
.await
.map_err(|err| Error::Pipeline(err.to_string()))??;
assert_eq!(&first[..first_len], b"first");
let mut second = vec![0; b"second".len()];
let (second_len, _) =
tokio::time::timeout(Duration::from_millis(200), socket.recv_from(&mut second))
.await
.map_err(|err| Error::Pipeline(err.to_string()))??;
assert_eq!(&second[..second_len], b"second");
server.shutdown();
server.wait().await
}
/// Sends `go\n` to a server running the `TcpWriteOnly` handler and expects
/// that not even one byte can be read back within 50 ms.
#[tokio::test]
async fn tcp_context_write_buffers_until_explicit_flush() -> Result<()> {
    let server = TcpServer::bind("127.0.0.1:0")
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(TcpWriteOnly))
        .start()
        .await?;
    let mut stream = TcpStream::connect(server.local_addr()).await?;
    stream.write_all(b"go\n").await?;

    // The timeout elapsing (an `Err` from `timeout`) is the success condition.
    let mut probe = [0_u8; 1];
    let probe_read =
        tokio::time::timeout(Duration::from_millis(50), stream.read_exact(&mut probe)).await;
    assert!(probe_read.is_err());

    drop(stream);
    server.shutdown();
    server.wait().await
}
#[tokio::test]
async fn tcp_context_write_then_flush_sends() -> Result<()> {
let server = TcpServer::bind("127.0.0.1:0")
.pipeline(|| {
pipeline()
.codec(LineCodec::new())
.handler(TcpWriteThenFlush)
})
.start()
.await?;
let mut stream = TcpStream::connect(server.local_addr()).await?;
stream.write_all(b"go\n").await?;
let mut response = vec![0; b"sent\n".len()];
tokio::time::timeout(Duration::from_millis(200), stream.read_exact(&mut response))
.await
.map_err(|err| Error::Pipeline(err.to_string()))??;
assert_eq!(response, b"sent\n");
drop(stream);
server.shutdown();
server.wait().await
}
#[tokio::test]
async fn tcp_channel_write_buffers_until_flush() -> Result<()> {
let server = TcpServer::bind("127.0.0.1:0")
.pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
.start()
.await?;
let (tx, mut rx) = mpsc::unbounded_channel();
let client = TcpClient::connect(server.local_addr().to_string())
.pipeline(move || {
pipeline()
.codec(LineCodec::new())
.handler(NotifyTcp { tx: tx.clone() })
})
.run()
.await?;
let channel = client.channel();
channel.write("hello".to_string()).await?;
assert!(tokio::time::timeout(Duration::from_millis(50), rx.recv())
.await
.is_err());
channel.flush().await?;
let received = tokio::time::timeout(Duration::from_millis(200), rx.recv())
.await
.map_err(|err| Error::Pipeline(err.to_string()))?
.ok_or_else(|| Error::Pipeline("tcp response channel closed".to_string()))?;
assert_eq!(received, "hello");
client.close().await?;
client.wait().await?;
server.shutdown();
server.wait().await
}
/// Sends the datagram `go` to a server running the `UdpWriteOnly` handler and
/// expects that no datagram comes back within 50 ms.
#[tokio::test]
async fn udp_context_write_buffers_until_explicit_flush() -> Result<()> {
    let server = UdpServer::bind("127.0.0.1:0")
        .pipeline(|| datagram_pipeline().codec(Utf8DatagramCodec).handler(UdpWriteOnly))
        .start()
        .await?;
    let socket = UdpSocket::bind("127.0.0.1:0").await?;
    socket.send_to(b"go", server.local_addr()).await?;

    // The timeout elapsing (an `Err` from `timeout`) is the success condition.
    let mut probe = [0_u8; 1];
    let probe_recv =
        tokio::time::timeout(Duration::from_millis(50), socket.recv_from(&mut probe)).await;
    assert!(probe_recv.is_err());

    server.shutdown();
    server.wait().await
}
#[tokio::test]
async fn udp_context_write_then_flush_sends() -> Result<()> {
let server = UdpServer::bind("127.0.0.1:0")
.pipeline(|| {
datagram_pipeline()
.codec(Utf8DatagramCodec)
.handler(UdpWriteThenFlush)
})
.start()
.await?;
let socket = UdpSocket::bind("127.0.0.1:0").await?;
socket.send_to(b"go", server.local_addr()).await?;
let mut response = vec![0; b"sent".len()];
let (len, _) =
tokio::time::timeout(Duration::from_millis(200), socket.recv_from(&mut response))
.await
.map_err(|err| Error::Pipeline(err.to_string()))??;
assert_eq!(&response[..len], b"sent");
server.shutdown();
server.wait().await
}
#[tokio::test]
async fn udp_client_write_preserves_datagrams_until_flush() -> Result<()> {
let server = UdpServer::bind("127.0.0.1:0")
.pipeline(|| {
datagram_pipeline()
.codec(Utf8DatagramCodec)
.handler(UdpEcho)
})
.start()
.await?;
let (tx, mut rx) = mpsc::unbounded_channel();
let client = UdpClient::connect(server.local_addr().to_string())
.pipeline(move || {
datagram_pipeline()
.codec(Utf8DatagramCodec)
.handler(NotifyUdp { tx: tx.clone() })
})
.run()
.await?;
client.write("one".to_string()).await?;
client.write("two".to_string()).await?;
assert!(tokio::time::timeout(Duration::from_millis(50), rx.recv())
.await
.is_err());
client.flush().await?;
let mut received = Vec::new();
for _ in 0..2 {
let value = tokio::time::timeout(Duration::from_millis(200), rx.recv())
.await
.map_err(|err| Error::Pipeline(err.to_string()))?
.ok_or_else(|| Error::Pipeline("udp response channel closed".to_string()))?;
received.push(value);
}
received.sort();
assert_eq!(received, ["one".to_string(), "two".to_string()]);
client.close().await?;
client.wait().await?;
server.shutdown();
server.wait().await
}
/// Connects to a TLS server as `localhost:<port>` using the contexts from
/// `localhost_tls_contexts(None)`, sends `hello`, and expects `hello` on the
/// `NotifyTcp` channel within 500 ms.
#[cfg(feature = "tls")]
#[tokio::test]
async fn tls_line_echo_uses_custom_ca_and_host_sni() -> Result<()> {
    let (server_tls, client_tls) = localhost_tls_contexts(None)?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(server_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .start()
        .await?;

    let (line_tx, mut line_rx) = mpsc::unbounded_channel();
    // Connect by host name rather than by IP address.
    let target = format!("localhost:{}", server.local_addr().port());
    let client = TcpClient::connect(target)
        .tls(client_tls)
        .pipeline(move || {
            pipeline().codec(LineCodec::new()).handler(NotifyTcp {
                tx: line_tx.clone(),
            })
        })
        .run()
        .await?;

    client.write_and_flush(String::from("hello")).await?;
    let reply = tokio::time::timeout(Duration::from_millis(500), line_rx.recv())
        .await
        .map_err(|elapsed| Error::Pipeline(elapsed.to_string()))?;
    let received =
        reply.ok_or_else(|| Error::Pipeline(String::from("tls response channel closed")))?;
    assert_eq!(received, "hello");

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// Connects to a TLS server by its socket address while the client context
/// carries an explicit `localhost` server name, and expects the connection to
/// be established and closed without error.
#[cfg(feature = "tls")]
#[tokio::test]
async fn tls_client_can_override_server_name_for_ip_connection() -> Result<()> {
    let name_override = Some("localhost");
    let (server_tls, client_tls) = localhost_tls_contexts(name_override)?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(server_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .start()
        .await?;

    // The target is the raw `ip:port` string, not a host name.
    let ip_target = server.local_addr().to_string();
    let client = TcpClient::connect(ip_target)
        .tls(client_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .run()
        .await?;

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// Connects to a TLS server by its socket address without any server-name
/// override and expects `run` to fail with `Error::Tls`.
#[cfg(feature = "tls")]
#[tokio::test]
async fn tls_client_ip_connection_fails_without_matching_server_name() -> Result<()> {
    let (server_tls, client_tls) = localhost_tls_contexts(None)?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(server_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .start()
        .await?;

    // The target is the raw `ip:port` string, not a host name.
    let ip_target = server.local_addr().to_string();
    let attempt = TcpClient::connect(ip_target)
        .tls(client_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .run()
        .await;
    assert!(matches!(attempt, Err(Error::Tls(_))));

    server.shutdown();
    server.wait().await
}
/// Expects a client whose trust root comes from a second, independently
/// generated certificate to fail with `Error::Tls`, and a correctly configured
/// client to connect to the same server afterwards.
#[cfg(feature = "tls")]
#[tokio::test]
async fn tls_client_bad_root_fails_without_stopping_server() -> Result<()> {
    let (server_tls, trusted_client_tls) = localhost_tls_contexts(Some("localhost"))?;
    // A second call generates a fresh certificate, so this client trusts the wrong root.
    let (_, untrusted_client_tls) = localhost_tls_contexts(Some("localhost"))?;

    let server = TcpServer::bind("127.0.0.1:0")
        .tls(server_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .start()
        .await?;
    let target = format!("localhost:{}", server.local_addr().port());

    let rejected = TcpClient::connect(target.clone())
        .tls(untrusted_client_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .run()
        .await;
    assert!(matches!(rejected, Err(Error::Tls(_))));

    // The listener must still accept a valid client after the failed handshake.
    let client = TcpClient::connect(target)
        .tls(trusted_client_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .run()
        .await?;

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// Sends `go` over TLS to a server running the `FlushTwice` handler and
/// expects `first` on the `NotifyTcp` channel within 50 ms and `second`
/// within a further 300 ms.
#[cfg(feature = "tls")]
#[tokio::test]
async fn tls_context_write_and_flush_flushes_before_handler_returns() -> Result<()> {
    let (server_tls, client_tls) = localhost_tls_contexts(None)?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(server_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(FlushTwice))
        .start()
        .await?;

    let (line_tx, mut line_rx) = mpsc::unbounded_channel();
    let target = format!("localhost:{}", server.local_addr().port());
    let client = TcpClient::connect(target)
        .tls(client_tls)
        .pipeline(move || {
            pipeline().codec(LineCodec::new()).handler(NotifyTcp {
                tx: line_tx.clone(),
            })
        })
        .run()
        .await?;

    client.write_and_flush(String::from("go")).await?;

    // NOTE(review): the tight first deadline presumably relies on the handler
    // pausing between its two writes — confirm against `FlushTwice`.
    let first_reply = tokio::time::timeout(Duration::from_millis(50), line_rx.recv())
        .await
        .map_err(|elapsed| Error::Pipeline(elapsed.to_string()))?;
    let first = first_reply
        .ok_or_else(|| Error::Pipeline(String::from("tls response channel closed")))?;
    assert_eq!(first, "first");

    let second_reply = tokio::time::timeout(Duration::from_millis(300), line_rx.recv())
        .await
        .map_err(|elapsed| Error::Pipeline(elapsed.to_string()))?;
    let second = second_reply
        .ok_or_else(|| Error::Pipeline(String::from("tls response channel closed")))?;
    assert_eq!(second, "second");

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// Connects with the `client` context from `localhost_mtls_contexts`, sends
/// `hello`, and expects `hello` on the `NotifyTcp` channel within 500 ms.
#[cfg(feature = "tls")]
#[tokio::test]
async fn mtls_line_echo_uses_client_identity() -> Result<()> {
    let contexts = localhost_mtls_contexts()?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(contexts.server)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .start()
        .await?;

    let (line_tx, mut line_rx) = mpsc::unbounded_channel();
    let target = format!("localhost:{}", server.local_addr().port());
    let client = TcpClient::connect(target)
        .tls(contexts.client)
        .pipeline(move || {
            pipeline().codec(LineCodec::new()).handler(NotifyTcp {
                tx: line_tx.clone(),
            })
        })
        .run()
        .await?;

    client.write_and_flush(String::from("hello")).await?;
    let reply = tokio::time::timeout(Duration::from_millis(500), line_rx.recv())
        .await
        .map_err(|elapsed| Error::Pipeline(elapsed.to_string()))?;
    let received =
        reply.ok_or_else(|| Error::Pipeline(String::from("mtls response channel closed")))?;
    assert_eq!(received, "hello");

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// Hands the identity-less client context to `assert_mtls_client_rejected`
/// against a server built from `localhost_mtls_contexts`.
#[cfg(feature = "tls")]
#[tokio::test]
async fn mtls_client_without_identity_fails_handshake() -> Result<()> {
    let contexts = localhost_mtls_contexts()?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(contexts.server)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .start()
        .await?;

    let anonymous_client = contexts.client_without_identity;
    let port = server.local_addr().port();
    assert_mtls_client_rejected(anonymous_client, port).await?;

    server.shutdown();
    server.wait().await
}
/// Hands the `wrong_client` context (identity issued by an unrelated CA) to
/// `assert_mtls_client_rejected` against a server built from
/// `localhost_mtls_contexts`.
#[cfg(feature = "tls")]
#[tokio::test]
async fn mtls_client_signed_by_wrong_ca_fails_handshake() -> Result<()> {
    let contexts = localhost_mtls_contexts()?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(contexts.server)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .start()
        .await?;

    let untrusted_client = contexts.wrong_client;
    let port = server.local_addr().port();
    assert_mtls_client_rejected(untrusted_client, port).await?;

    server.shutdown();
    server.wait().await
}
/// After `assert_mtls_client_rejected` has run with the `wrong_client`
/// context, expects the correctly configured `client` context to connect to
/// the same server.
#[cfg(feature = "tls")]
#[tokio::test]
async fn mtls_failed_handshake_does_not_stop_server() -> Result<()> {
    let contexts = localhost_mtls_contexts()?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(contexts.server)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .start()
        .await?;

    let untrusted_client = contexts.wrong_client;
    assert_mtls_client_rejected(untrusted_client, server.local_addr().port()).await?;

    // The listener must still accept a valid client after the rejected one.
    let target = format!("localhost:{}", server.local_addr().port());
    let client = TcpClient::connect(target)
        .tls(contexts.client)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .run()
        .await?;

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// Sends `go` over mutual TLS to a server running the `FlushTwice` handler and
/// expects `first` on the `NotifyTcp` channel within 50 ms and `second`
/// within a further 300 ms.
#[cfg(feature = "tls")]
#[tokio::test]
async fn mtls_context_write_and_flush_flushes_before_handler_returns() -> Result<()> {
    let contexts = localhost_mtls_contexts()?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(contexts.server)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(FlushTwice))
        .start()
        .await?;

    let (line_tx, mut line_rx) = mpsc::unbounded_channel();
    let target = format!("localhost:{}", server.local_addr().port());
    let client = TcpClient::connect(target)
        .tls(contexts.client)
        .pipeline(move || {
            pipeline().codec(LineCodec::new()).handler(NotifyTcp {
                tx: line_tx.clone(),
            })
        })
        .run()
        .await?;

    client.write_and_flush(String::from("go")).await?;

    let first_deadline = Duration::from_millis(50);
    let second_deadline = Duration::from_millis(300);

    let first_reply = tokio::time::timeout(first_deadline, line_rx.recv())
        .await
        .map_err(|elapsed| Error::Pipeline(elapsed.to_string()))?;
    let first = first_reply
        .ok_or_else(|| Error::Pipeline(String::from("mtls response channel closed")))?;
    assert_eq!(first, "first");

    let second_reply = tokio::time::timeout(second_deadline, line_rx.recv())
        .await
        .map_err(|elapsed| Error::Pipeline(elapsed.to_string()))?;
    let second = second_reply
        .ok_or_else(|| Error::Pipeline(String::from("mtls response channel closed")))?;
    assert_eq!(second, "second");

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// With `rs-netty` offered as ALPN on both sides, checks the TLS snapshot
/// reported by the client's `NotifyTlsSnapshot` handler (server name, ALPN
/// protocol, peer certificate count) and that the `TlsInfoLife` lifecycle
/// recorded a connection carrying TLS info.
#[cfg(feature = "tls")]
#[tokio::test]
async fn tls_metadata_is_visible_to_handler_and_lifecycle() -> Result<()> {
    let alpn: &'static [u8] = b"rs-netty";
    let (server_tls, client_tls) = localhost_tls_contexts_with_alpn(vec![alpn], vec![alpn], None)?;

    let life = TlsInfoLife::default();
    let (snapshot_tx, mut snapshot_rx) = mpsc::unbounded_channel();
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(server_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .life(life.clone())
        .start()
        .await?;

    let target = format!("localhost:{}", server.local_addr().port());
    let client = TcpClient::connect(target)
        .tls(client_tls)
        .pipeline(move || {
            pipeline().codec(LineCodec::new()).handler(NotifyTlsSnapshot {
                tx: snapshot_tx.clone(),
            })
        })
        .run()
        .await?;

    client.write_and_flush(String::from("hello")).await?;
    let snapshot = recv_tls_snapshot(&mut snapshot_rx).await?;

    // Handler-side view of the session.
    assert_eq!(snapshot.server_name.as_deref(), Some("localhost"));
    assert_eq!(snapshot.selected_alpn_protocol.as_deref(), Some(alpn));
    assert!(snapshot.peer_certificates > 0);

    // Lifecycle-side view: the recorded connection info must expose TLS data.
    let lifecycle_saw_tls = life
        .opened
        .lock()
        .expect("opened")
        .as_ref()
        .and_then(ConnInfo::tls)
        .is_some();
    assert!(lifecycle_saw_tls);

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// Runs one `hello` round trip through a server pipeline with `RecordTls*`
/// inbound, business, handler, and outbound stages, then expects every flag in
/// the shared `TlsStageHits` to be set.
#[cfg(feature = "tls")]
#[tokio::test]
async fn tls_metadata_is_visible_to_all_pipeline_stages() -> Result<()> {
    let (server_tls, client_tls) = localhost_tls_contexts(None)?;
    let hits = Arc::new(Mutex::new(TlsStageHits::default()));
    let stage_hits = hits.clone();
    let (line_tx, mut line_rx) = mpsc::unbounded_channel();

    // Every stage gets its own handle to the same hit record.
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(server_tls)
        .pipeline(move || {
            pipeline()
                .codec(LineCodec::new())
                .inbound(RecordTlsInbound {
                    hits: stage_hits.clone(),
                })
                .business(RecordTlsBusiness {
                    hits: stage_hits.clone(),
                })
                .handler(RecordTlsHandler {
                    hits: stage_hits.clone(),
                })
                .outbound(RecordTlsOutbound {
                    hits: stage_hits.clone(),
                })
        })
        .start()
        .await?;

    let target = format!("localhost:{}", server.local_addr().port());
    let client = TcpClient::connect(target)
        .tls(client_tls)
        .pipeline(move || {
            pipeline().codec(LineCodec::new()).handler(NotifyTcp {
                tx: line_tx.clone(),
            })
        })
        .run()
        .await?;

    client.write_and_flush(String::from("hello")).await?;
    let reply = tokio::time::timeout(Duration::from_millis(500), line_rx.recv())
        .await
        .map_err(|elapsed| Error::Pipeline(elapsed.to_string()))?;
    let response = reply
        .ok_or_else(|| Error::Pipeline(String::from("tls stage response channel closed")))?;
    assert_eq!(response, "hello");

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await?;

    // Inspect the record only after both sides have fully stopped.
    let recorded = hits.lock().expect("tls stage hits");
    assert!(recorded.inbound);
    assert!(recorded.business);
    assert!(recorded.handler);
    assert!(recorded.outbound);
    Ok(())
}
/// Builds the client from a single pipeline value via `pipeline_instance` and
/// checks that the TLS snapshot reported by `NotifyTlsSnapshot` still carries
/// the `localhost` server name and at least one peer certificate.
#[cfg(feature = "tls")]
#[tokio::test]
async fn tls_pipeline_instance_preserves_client_metadata() -> Result<()> {
    let (server_tls, client_tls) = localhost_tls_contexts(None)?;
    let (snapshot_tx, mut snapshot_rx) = mpsc::unbounded_channel();
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(server_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .start()
        .await?;

    // A ready-made pipeline value rather than a factory closure.
    let client_pipeline = pipeline()
        .codec(LineCodec::new())
        .handler(NotifyTlsSnapshot { tx: snapshot_tx });
    let target = format!("localhost:{}", server.local_addr().port());
    let client = TcpClient::connect(target)
        .tls(client_tls)
        .pipeline_instance(client_pipeline)
        .run()
        .await?;

    client.write_and_flush(String::from("hello")).await?;
    let snapshot = recv_tls_snapshot(&mut snapshot_rx).await?;
    assert_eq!(snapshot.server_name.as_deref(), Some("localhost"));
    assert!(snapshot.peer_certificates > 0);

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// Connects an identity-less client to a server built from
/// `localhost_optional_mtls_contexts` and expects the server-side TLS snapshot
/// to report zero peer certificates.
#[cfg(feature = "tls")]
#[tokio::test]
async fn optional_mtls_allows_clients_without_identity() -> Result<()> {
    let contexts = localhost_optional_mtls_contexts()?;
    let (snapshot_tx, mut snapshot_rx) = mpsc::unbounded_channel();
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(contexts.server)
        .pipeline(move || {
            pipeline().codec(LineCodec::new()).handler(EchoTlsSnapshot {
                tx: snapshot_tx.clone(),
            })
        })
        .start()
        .await?;

    let target = format!("localhost:{}", server.local_addr().port());
    let client = TcpClient::connect(target)
        .tls(contexts.client_without_identity)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .run()
        .await?;

    client.write_and_flush(String::from("hello")).await?;
    let snapshot = recv_tls_snapshot(&mut snapshot_rx).await?;
    let presented = snapshot.peer_certificates;
    assert_eq!(presented, 0);

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// Connects a client that has an identity to a server built from
/// `localhost_optional_mtls_contexts` and expects the server-side TLS snapshot
/// to report at least one peer certificate.
#[cfg(feature = "tls")]
#[tokio::test]
async fn optional_mtls_exposes_valid_client_certificate() -> Result<()> {
    let contexts = localhost_optional_mtls_contexts()?;
    let (snapshot_tx, mut snapshot_rx) = mpsc::unbounded_channel();
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(contexts.server)
        .pipeline(move || {
            pipeline().codec(LineCodec::new()).handler(EchoTlsSnapshot {
                tx: snapshot_tx.clone(),
            })
        })
        .start()
        .await?;

    let target = format!("localhost:{}", server.local_addr().port());
    let client = TcpClient::connect(target)
        .tls(contexts.client)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .run()
        .await?;

    client.write_and_flush(String::from("hello")).await?;
    let snapshot = recv_tls_snapshot(&mut snapshot_rx).await?;
    assert!(snapshot.peer_certificates > 0);

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// Hands the `wrong_client` context to `assert_mtls_client_rejected` against a
/// server built from `localhost_optional_mtls_contexts`.
#[cfg(feature = "tls")]
#[tokio::test]
async fn optional_mtls_rejects_untrusted_client_certificate() -> Result<()> {
    let contexts = localhost_optional_mtls_contexts()?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(contexts.server)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .start()
        .await?;

    let untrusted_client = contexts.wrong_client;
    let port = server.local_addr().port();
    assert_mtls_client_rejected(untrusted_client, port).await?;

    server.shutdown();
    server.wait().await
}
/// Builds both TLS contexts from DER material (`test_ca_der` /
/// `test_identity_der`), with client auth optional on the server, and expects
/// the server-side TLS snapshot to report at least one peer certificate.
#[cfg(feature = "tls")]
#[tokio::test]
async fn optional_mtls_der_identity_is_accepted() -> Result<()> {
    // Separate CAs issue the server and the client identities.
    let (server_ca_der, server_ca) = test_ca_der()?;
    let server_identity = test_identity_der(
        &server_ca,
        "localhost",
        rcgen::ExtendedKeyUsagePurpose::ServerAuth,
    )?;
    let (client_ca_der, client_ca) = test_ca_der()?;
    let client_identity = test_identity_der(
        &client_ca,
        "client",
        rcgen::ExtendedKeyUsagePurpose::ClientAuth,
    )?;

    let server_tls = TlsContextBuilder::for_server()
        .certificate_der(server_identity.cert_der.clone())
        .private_key_der(server_identity.key_der.clone())
        .client_auth_optional_der(client_ca_der)
        .build()?;
    let client_tls = TlsContextBuilder::for_client()
        .root_certificate_der(server_ca_der)
        .client_identity_der([client_identity.cert_der], client_identity.key_der)
        .build()?;

    let (snapshot_tx, mut snapshot_rx) = mpsc::unbounded_channel();
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(server_tls)
        .pipeline(move || {
            pipeline().codec(LineCodec::new()).handler(EchoTlsSnapshot {
                tx: snapshot_tx.clone(),
            })
        })
        .start()
        .await?;

    let target = format!("localhost:{}", server.local_addr().port());
    let client = TcpClient::connect(target)
        .tls(client_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .run()
        .await?;

    client.write_and_flush(String::from("hello")).await?;
    let snapshot = recv_tls_snapshot(&mut snapshot_rx).await?;
    assert!(snapshot.peer_certificates > 0);

    client.close().await?;
    client.wait().await?;
    server.shutdown();
    server.wait().await
}
/// With the server offering only `h2` and the client only `http/1.1`, expects
/// the client's `run` to fail with `Error::Tls`.
#[cfg(feature = "tls")]
#[tokio::test]
async fn alpn_mismatch_rejects_connection() -> Result<()> {
    let server_alpn: Vec<&'static [u8]> = vec![b"h2"];
    let client_alpn: Vec<&'static [u8]> = vec![b"http/1.1"];
    let (server_tls, client_tls) =
        localhost_tls_contexts_with_alpn(server_alpn, client_alpn, None)?;

    let server = TcpServer::bind("127.0.0.1:0")
        .tls(server_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .start()
        .await?;

    let target = format!("localhost:{}", server.local_addr().port());
    let attempt = TcpClient::connect(target)
        .tls(client_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .run()
        .await;
    assert!(matches!(attempt, Err(Error::Tls(_))));

    server.shutdown();
    server.wait().await
}
/// Expects the builder to reject an empty ALPN protocol name on the server
/// side and a 256-byte name on the client side, each with the specific error
/// text checked by `assert_tls_error_contains`.
#[cfg(feature = "tls")]
#[test]
fn alpn_rejects_invalid_protocol_names() -> Result<()> {
    let empty_name = TlsContextBuilder::for_server()
        .alpn_protocols([b"".as_slice()])
        .build();
    assert_tls_error_contains(empty_name, "ALPN protocol name must not be empty");

    let root = self_signed_identity("localhost")?;
    // One byte over the 255-byte limit named in the expected message.
    let too_long = vec![b'a'; 256];
    let oversized_name = TlsContextBuilder::for_client()
        .root_certificate_pem(root.cert_pem.as_bytes())
        .alpn_protocols([too_long.as_slice()])
        .build();
    assert_tls_error_contains(
        oversized_name,
        "ALPN protocol name must be at most 255 bytes",
    );
    Ok(())
}
/// Runs the `api` and `mqtt` clients against one listener running the
/// `ReportServerName` handler and expects `run_sni_client` to return
/// `api.localhost` and `mqtt.localhost` respectively.
#[cfg(feature = "tls")]
#[tokio::test]
async fn sni_selects_named_certificates_on_one_listener() -> Result<()> {
    let contexts = localhost_sni_contexts()?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(contexts.server)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(ReportServerName))
        .start()
        .await?;

    // Both clients go to the same port, one after the other.
    let api_name = run_sni_client(server.local_addr().port(), contexts.api_client).await?;
    let mqtt_name = run_sni_client(server.local_addr().port(), contexts.mqtt_client).await?;
    assert_eq!(api_name, "api.localhost");
    assert_eq!(mqtt_name, "mqtt.localhost");

    server.shutdown();
    server.wait().await
}
/// With contexts from `localhost_sni_contexts_without_fallback`, expects the
/// `unmatched_client` to fail with `Error::Tls` and the `api_client` to be
/// served as `api.localhost` afterwards.
#[cfg(feature = "tls")]
#[tokio::test]
async fn sni_without_fallback_rejects_unmatched_name_and_keeps_serving() -> Result<()> {
    let contexts = localhost_sni_contexts_without_fallback()?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(contexts.server)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(ReportServerName))
        .start()
        .await?;

    let target = format!("localhost:{}", server.local_addr().port());
    let attempt = TcpClient::connect(target)
        .tls(contexts.unmatched_client)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .run()
        .await;
    assert!(matches!(attempt, Err(Error::Tls(_))));

    // The listener must still serve a matching name after the rejection.
    let api_name = run_sni_client(server.local_addr().port(), contexts.api_client).await?;
    assert_eq!(api_name, "api.localhost");

    server.shutdown();
    server.wait().await
}
/// Runs the `fallback_client` from `localhost_sni_contexts` against the
/// `ReportServerName` server and expects `run_sni_client` to return
/// `localhost`.
#[cfg(feature = "tls")]
#[tokio::test]
async fn sni_fallback_serves_non_matching_server_name() -> Result<()> {
    let contexts = localhost_sni_contexts()?;
    let server = TcpServer::bind("127.0.0.1:0")
        .tls(contexts.server)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(ReportServerName))
        .start()
        .await?;

    let port = server.local_addr().port();
    let reported_name = run_sni_client(port, contexts.fallback_client).await?;
    assert_eq!(reported_name, "localhost");

    server.shutdown();
    server.wait().await
}
/// Registers a DER-encoded self-signed identity for `der.localhost` as an SNI
/// certificate and expects `run_sni_client` to return `der.localhost` for a
/// client that trusts that certificate and requests that name.
#[cfg(feature = "tls")]
#[tokio::test]
async fn sni_der_identity_selects_named_certificate() -> Result<()> {
    let identity = self_signed_identity_der("der.localhost")?;

    let chain = [identity.cert_der.clone()];
    let server_tls = TlsContextBuilder::for_server()
        .sni_certificate_der("der.localhost", chain, identity.key_der.clone())
        .build()?;
    // The self-signed certificate doubles as the client's trust root.
    let client_tls = TlsContextBuilder::for_client()
        .root_certificate_der(identity.cert_der)
        .server_name("der.localhost")
        .build()?;

    let server = TcpServer::bind("127.0.0.1:0")
        .tls(server_tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(ReportServerName))
        .start()
        .await?;

    let reported_name = run_sni_client(server.local_addr().port(), client_tls).await?;
    assert_eq!(reported_name, "der.localhost");

    server.shutdown();
    server.wait().await
}
/// Expects the server builder to reject an IP address used as an SNI name and
/// an SNI entry with an empty certificate chain, each with the specific error
/// text checked by `assert_tls_error_contains`.
#[cfg(feature = "tls")]
#[test]
fn sni_rejects_invalid_identity_options() {
    let ip_as_name = TlsContextBuilder::for_server()
        .sni_certificate_der("127.0.0.1", Vec::<Vec<u8>>::new(), Vec::new())
        .build();
    assert_tls_error_contains(ip_as_name, "IP addresses are not valid SNI names");

    let empty_chain = TlsContextBuilder::for_server()
        .sni_certificate_der("api.localhost", Vec::<Vec<u8>>::new(), Vec::new())
        .build();
    assert_tls_error_contains(empty_chain, "requires a certificate chain");
}
/// Builds a matching server/client TLS context pair for `localhost` with no
/// ALPN protocols on either side.
///
/// `server_name` is forwarded unchanged to
/// `localhost_tls_contexts_with_alpn`.
#[cfg(feature = "tls")]
fn localhost_tls_contexts(
    server_name: Option<&str>,
) -> Result<(ServerTlsContext, ClientTlsContext)> {
    let server_alpn = Vec::new();
    let client_alpn = Vec::new();
    localhost_tls_contexts_with_alpn(server_alpn, client_alpn, server_name)
}
/// Generates a fresh self-signed certificate for `localhost` and builds a
/// server context that presents it plus a client context that trusts it.
///
/// `server_alpn` and `client_alpn` are passed to the respective builders.
/// When `server_name` is `Some`, it is set on the client builder via
/// `server_name`; otherwise the client builder is left as is.
///
/// # Errors
///
/// Certificate generation failures are mapped to `Error::Tls`; builder
/// failures are propagated.
#[cfg(feature = "tls")]
fn localhost_tls_contexts_with_alpn(
    server_alpn: Vec<&'static [u8]>,
    client_alpn: Vec<&'static [u8]>,
    server_name: Option<&str>,
) -> Result<(ServerTlsContext, ClientTlsContext)> {
    let subject_alt_names = vec![String::from("localhost")];
    let generated = rcgen::generate_simple_self_signed(subject_alt_names)
        .map_err(|failure| Error::Tls(failure.to_string()))?;
    let cert_pem = generated.cert.pem();
    let key_pem = generated.signing_key.serialize_pem();

    let server = TlsContextBuilder::for_server()
        .certificate_chain_pem(cert_pem.as_bytes())
        .private_key_pem(key_pem.as_bytes())
        .alpn_protocols(server_alpn)
        .build()?;

    // The self-signed certificate is also the client's only trust root.
    let client_builder = TlsContextBuilder::for_client()
        .root_certificate_pem(cert_pem.as_bytes())
        .alpn_protocols(client_alpn);
    let client_builder = match server_name {
        Some(name) => client_builder.server_name(name),
        None => client_builder,
    };
    let client = client_builder.build()?;

    Ok((server, client))
}
/// TLS contexts for the client-certificate tests: one server context and
/// three client variants, all of which trust the server's CA.
#[cfg(feature = "tls")]
struct MtlsContextSet {
/// Server context presenting a `localhost` certificate and verifying client
/// certificates against the client CA.
server: ServerTlsContext,
/// Client with an identity issued by the client CA the server trusts.
client: ClientTlsContext,
/// Client with no identity at all.
client_without_identity: ClientTlsContext,
/// Client whose identity was issued by a CA the server does not trust.
wrong_client: ClientTlsContext,
}
/// A certificate and its private key, both held as PEM text.
// NOTE(review): presumably the return type of `test_identity` and
// `self_signed_identity`, which are defined outside this view — confirm.
#[cfg(feature = "tls")]
struct TestIdentity {
/// PEM text of the certificate.
cert_pem: String,
/// PEM text of the private key.
key_pem: String,
}
/// Builds the mutual-TLS context set in which the server requires a client
/// certificate issued by the client CA.
#[cfg(feature = "tls")]
fn localhost_mtls_contexts() -> Result<MtlsContextSet> {
    build_localhost_mtls_contexts(true)
}

/// Builds the mutual-TLS context set in which a client certificate is
/// optional on the server, but is verified against the client CA if presented.
#[cfg(feature = "tls")]
fn localhost_optional_mtls_contexts() -> Result<MtlsContextSet> {
    build_localhost_mtls_contexts(false)
}

/// Shared builder behind `localhost_mtls_contexts` and
/// `localhost_optional_mtls_contexts`.
///
/// Three independent CAs are generated: one issues the server's `localhost`
/// certificate, one issues the trusted client identity, and one issues the
/// `wrong-client` identity that the server does not trust. Every client
/// context trusts the server CA.
///
/// `client_auth_required` selects `client_auth_required_pem` (`true`) or
/// `client_auth_optional_pem` (`false`) on the server builder; nothing else
/// differs between the two modes.
///
/// # Errors
///
/// Propagates any error from `test_ca`, `test_identity`, or the TLS builders.
#[cfg(feature = "tls")]
fn build_localhost_mtls_contexts(client_auth_required: bool) -> Result<MtlsContextSet> {
    let (server_ca_pem, server_ca) = test_ca()?;
    let server_identity = test_identity(
        &server_ca,
        "localhost",
        rcgen::ExtendedKeyUsagePurpose::ServerAuth,
    )?;
    let (client_ca_pem, client_ca) = test_ca()?;
    let client_identity = test_identity(
        &client_ca,
        "client",
        rcgen::ExtendedKeyUsagePurpose::ClientAuth,
    )?;
    // Only the identity of the third CA is needed; its certificate is never trusted.
    let (_, wrong_client_ca) = test_ca()?;
    let wrong_client_identity = test_identity(
        &wrong_client_ca,
        "wrong-client",
        rcgen::ExtendedKeyUsagePurpose::ClientAuth,
    )?;

    // Each branch spells out the full builder chain so that no assumption is
    // made about the intermediate builder types.
    let server = if client_auth_required {
        TlsContextBuilder::for_server()
            .certificate_chain_pem(server_identity.cert_pem.as_bytes())
            .private_key_pem(server_identity.key_pem.as_bytes())
            .client_auth_required_pem(client_ca_pem.as_bytes())
            .build()?
    } else {
        TlsContextBuilder::for_server()
            .certificate_chain_pem(server_identity.cert_pem.as_bytes())
            .private_key_pem(server_identity.key_pem.as_bytes())
            .client_auth_optional_pem(client_ca_pem.as_bytes())
            .build()?
    };

    let client = TlsContextBuilder::for_client()
        .root_certificate_pem(server_ca_pem.as_bytes())
        .client_identity_pem(
            client_identity.cert_pem.as_bytes(),
            client_identity.key_pem.as_bytes(),
        )
        .build()?;
    let client_without_identity = TlsContextBuilder::for_client()
        .root_certificate_pem(server_ca_pem.as_bytes())
        .build()?;
    let wrong_client = TlsContextBuilder::for_client()
        .root_certificate_pem(server_ca_pem.as_bytes())
        .client_identity_pem(
            wrong_client_identity.cert_pem.as_bytes(),
            wrong_client_identity.key_pem.as_bytes(),
        )
        .build()?;

    Ok(MtlsContextSet {
        server,
        client,
        client_without_identity,
        wrong_client,
    })
}
#[cfg(feature = "tls")]
/// TLS contexts for the SNI fixture that has a fallback certificate.
struct SniContextSet {
/// Server with a `localhost` default certificate plus SNI certificates for
/// `api.localhost` and `mqtt.localhost`.
server: ServerTlsContext,
/// Client that trusts the `api.localhost` certificate and sets that server name.
api_client: ClientTlsContext,
/// Client that trusts the `mqtt.localhost` certificate and sets that server name.
mqtt_client: ClientTlsContext,
/// Client that trusts the default certificate and sets the server name `localhost`.
fallback_client: ClientTlsContext,
}
#[cfg(feature = "tls")]
/// TLS contexts for the SNI fixture whose server has no default certificate.
struct SniNoFallbackContextSet {
/// Server configured with a single SNI certificate for `api.localhost`.
server: ServerTlsContext,
/// Client that trusts the `api.localhost` certificate and sets that server name.
api_client: ClientTlsContext,
/// Client that trusts the same certificate but sets the server name `unknown.localhost`.
unmatched_client: ClientTlsContext,
}
#[cfg(feature = "tls")]
/// Builds an SNI-enabled server with a `localhost` default certificate and dedicated
/// certificates for `api.localhost` and `mqtt.localhost`, plus one client per host name.
///
/// Each client trusts only the self-signed certificate generated for its own host name.
fn localhost_sni_contexts() -> Result<SniContextSet> {
    // Client that pins `identity`'s certificate as its root and requests `host`.
    fn client_for(identity: &TestIdentity, host: &str) -> Result<ClientTlsContext> {
        let context = TlsContextBuilder::for_client()
            .root_certificate_pem(identity.cert_pem.as_bytes())
            .server_name(host)
            .build()?;
        Ok(context)
    }

    let default_identity = self_signed_identity("localhost")?;
    let api_identity = self_signed_identity("api.localhost")?;
    let mqtt_identity = self_signed_identity("mqtt.localhost")?;

    let server = TlsContextBuilder::for_server()
        .certificate_chain_pem(default_identity.cert_pem.as_bytes())
        .private_key_pem(default_identity.key_pem.as_bytes())
        .sni_certificate_pem(
            "api.localhost",
            api_identity.cert_pem.as_bytes(),
            api_identity.key_pem.as_bytes(),
        )
        .sni_certificate_pem(
            "mqtt.localhost",
            mqtt_identity.cert_pem.as_bytes(),
            mqtt_identity.key_pem.as_bytes(),
        )
        .build()?;

    // Field initialisers run top to bottom, so the clients are built in api, mqtt,
    // fallback order.
    Ok(SniContextSet {
        server,
        api_client: client_for(&api_identity, "api.localhost")?,
        mqtt_client: client_for(&mqtt_identity, "mqtt.localhost")?,
        fallback_client: client_for(&default_identity, "localhost")?,
    })
}
#[cfg(feature = "tls")]
/// Builds an SNI server that only has a certificate for `api.localhost` (no default
/// certificate), one client requesting that name and one requesting `unknown.localhost`.
///
/// Both clients trust the same `api.localhost` certificate; only the requested name differs.
fn localhost_sni_contexts_without_fallback() -> Result<SniNoFallbackContextSet> {
    let api_identity = self_signed_identity("api.localhost")?;

    let server = TlsContextBuilder::for_server()
        .sni_certificate_pem(
            "api.localhost",
            api_identity.cert_pem.as_bytes(),
            api_identity.key_pem.as_bytes(),
        )
        .build()?;

    // Client that trusts the api certificate and requests `host`.
    let client_requesting = |host: &str| -> Result<ClientTlsContext> {
        let context = TlsContextBuilder::for_client()
            .root_certificate_pem(api_identity.cert_pem.as_bytes())
            .server_name(host)
            .build()?;
        Ok(context)
    };

    let api_client = client_requesting("api.localhost")?;
    let unmatched_client = client_requesting("unknown.localhost")?;

    Ok(SniNoFallbackContextSet {
        server,
        api_client,
        unmatched_client,
    })
}
#[cfg(feature = "tls")]
/// Connects to `localhost:{port}` with `tls`, sends the message `name`, and returns the first
/// message received in response.
///
/// Fails with `Error::Pipeline` when nothing arrives within 500 ms or the reply channel
/// closes first. Connect, write, close and wait failures are propagated with `?`.
async fn run_sni_client(port: u16, tls: ClientTlsContext) -> Result<String> {
    let (reply_tx, mut reply_rx) = mpsc::unbounded_channel();
    let address = format!("localhost:{port}");
    let client = TcpClient::connect(address)
        .tls(tls)
        .pipeline(move || {
            pipeline().codec(LineCodec::new()).handler(NotifyTcp {
                tx: reply_tx.clone(),
            })
        })
        .run()
        .await?;

    client.write_and_flush("name".to_string()).await?;

    let reply_wait = tokio::time::timeout(Duration::from_millis(500), reply_rx.recv()).await;
    let reply = match reply_wait {
        Err(elapsed) => return Err(Error::Pipeline(elapsed.to_string())),
        Ok(None) => return Err(Error::Pipeline("sni response channel closed".to_string())),
        Ok(Some(message)) => message,
    };

    client.close().await?;
    client.wait().await?;
    Ok(reply)
}
#[cfg(feature = "tls")]
/// Generates a self-signed certificate for `name` and returns it with its key, PEM-encoded.
///
/// Generation failures are reported as `Error::Tls` carrying the underlying message.
fn self_signed_identity(name: &str) -> Result<TestIdentity> {
    let names = vec![name.to_owned()];
    let generated = match rcgen::generate_simple_self_signed(names) {
        Ok(generated) => generated,
        Err(cause) => return Err(Error::Tls(cause.to_string())),
    };
    let cert_pem = generated.cert.pem();
    let key_pem = generated.signing_key.serialize_pem();
    Ok(TestIdentity { cert_pem, key_pem })
}
#[cfg(feature = "tls")]
/// Generates a self-signed certificate for `name` and returns it with its key, DER-encoded.
///
/// Generation failures are reported as `Error::Tls` carrying the underlying message.
fn self_signed_identity_der(name: &str) -> Result<TestIdentityDer> {
    let names = vec![name.to_owned()];
    let generated = match rcgen::generate_simple_self_signed(names) {
        Ok(generated) => generated,
        Err(cause) => return Err(Error::Tls(cause.to_string())),
    };
    let cert_der = generated.cert.der().to_vec();
    let key_der = generated.signing_key.serialize_der();
    Ok(TestIdentityDer { cert_der, key_der })
}
#[cfg(feature = "tls")]
/// Waits up to 500 ms for the next [`TlsSnapshot`] on `rx`.
///
/// Returns `Error::Pipeline` when the wait times out or when the channel is closed.
async fn recv_tls_snapshot(rx: &mut mpsc::UnboundedReceiver<TlsSnapshot>) -> Result<TlsSnapshot> {
    let deadline = Duration::from_millis(500);
    match tokio::time::timeout(deadline, rx.recv()).await {
        Err(elapsed) => Err(Error::Pipeline(elapsed.to_string())),
        Ok(None) => Err(Error::Pipeline("tls metadata channel closed".to_string())),
        Ok(Some(snapshot)) => Ok(snapshot),
    }
}
#[cfg(feature = "tls")]
/// Asserts that a client using `tls` cannot use the server listening on `localhost:{port}`.
///
/// A connect attempt failing with `Error::Tls` or `Error::Io` counts as a rejection; any
/// other connect error is returned. If the connect succeeds, either the first write or the
/// connection's `wait()` must fail. `wait()` is given 500 ms; a timeout is returned as
/// `Error::Pipeline`.
async fn assert_mtls_client_rejected(tls: ClientTlsContext, port: u16) -> Result<()> {
    let attempt = TcpClient::connect(format!("localhost:{port}"))
        .tls(tls)
        .pipeline(|| pipeline().codec(LineCodec::new()).handler(Echo))
        .run()
        .await;
    let client = match attempt {
        Ok(connected) => connected,
        Err(Error::Tls(_) | Error::Io(_)) => return Ok(()),
        Err(unexpected) => return Err(unexpected),
    };

    // The connect may succeed before the rejection is observed, so probe the connection
    // with a write and then wait for it to finish.
    let write = client.write_and_flush("hello".to_string()).await;
    let wait = match tokio::time::timeout(Duration::from_millis(500), client.wait()).await {
        Ok(outcome) => outcome,
        Err(elapsed) => return Err(Error::Pipeline(elapsed.to_string())),
    };
    assert!(write.is_err() || wait.is_err());
    Ok(())
}
#[cfg(feature = "tls")]
/// Generates a fresh self-signed CA and returns its certificate together with an issuer
/// that can sign further certificates.
///
/// Shared by [`test_ca`] and [`test_ca_der`], which only differ in how they encode the
/// certificate. Every `rcgen` failure is mapped to `Error::Tls`.
fn issue_test_ca() -> Result<(rcgen::Certificate, rcgen::Issuer<'static, rcgen::KeyPair>)> {
    let mut params = rcgen::CertificateParams::new(Vec::<String>::new())
        .map_err(|err| Error::Tls(err.to_string()))?;
    params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
    params.key_usages = vec![
        rcgen::KeyUsagePurpose::DigitalSignature,
        rcgen::KeyUsagePurpose::KeyCertSign,
        rcgen::KeyUsagePurpose::CrlSign,
    ];
    let key_pair = rcgen::KeyPair::generate().map_err(|err| Error::Tls(err.to_string()))?;
    let cert = params
        .self_signed(&key_pair)
        .map_err(|err| Error::Tls(err.to_string()))?;
    // The issuer takes ownership of the parameters and key, so it is created last.
    Ok((cert, rcgen::Issuer::new(params, key_pair)))
}

#[cfg(feature = "tls")]
/// Creates a test CA and returns its certificate as PEM text plus the signing issuer.
fn test_ca() -> Result<(String, rcgen::Issuer<'static, rcgen::KeyPair>)> {
    let (cert, issuer) = issue_test_ca()?;
    Ok((cert.pem(), issuer))
}

#[cfg(feature = "tls")]
/// Creates a test CA and returns its certificate as DER bytes plus the signing issuer.
fn test_ca_der() -> Result<(Vec<u8>, rcgen::Issuer<'static, rcgen::KeyPair>)> {
    let (cert, issuer) = issue_test_ca()?;
    Ok((cert.der().to_vec(), issuer))
}
#[cfg(feature = "tls")]
/// Generates a key pair and a certificate for `name` signed by `issuer`.
///
/// The certificate carries the `DigitalSignature` key usage and the single
/// `extended_key_usage` purpose. Shared by [`test_identity`] and [`test_identity_der`],
/// which only differ in output encoding. Every `rcgen` failure is mapped to `Error::Tls`.
fn issue_test_leaf(
    issuer: &rcgen::Issuer<'static, rcgen::KeyPair>,
    name: &str,
    extended_key_usage: rcgen::ExtendedKeyUsagePurpose,
) -> Result<(rcgen::Certificate, rcgen::KeyPair)> {
    let mut params = rcgen::CertificateParams::new(vec![name.to_string()])
        .map_err(|err| Error::Tls(err.to_string()))?;
    params.key_usages = vec![rcgen::KeyUsagePurpose::DigitalSignature];
    params.extended_key_usages = vec![extended_key_usage];
    let key_pair = rcgen::KeyPair::generate().map_err(|err| Error::Tls(err.to_string()))?;
    let cert = params
        .signed_by(&key_pair, issuer)
        .map_err(|err| Error::Tls(err.to_string()))?;
    Ok((cert, key_pair))
}

#[cfg(feature = "tls")]
/// Issues an identity for `name` from `issuer` and returns it PEM-encoded.
fn test_identity(
    issuer: &rcgen::Issuer<'static, rcgen::KeyPair>,
    name: &str,
    extended_key_usage: rcgen::ExtendedKeyUsagePurpose,
) -> Result<TestIdentity> {
    let (cert, key_pair) = issue_test_leaf(issuer, name, extended_key_usage)?;
    Ok(TestIdentity {
        cert_pem: cert.pem(),
        key_pem: key_pair.serialize_pem(),
    })
}

#[cfg(feature = "tls")]
/// Issues an identity for `name` from `issuer` and returns it DER-encoded.
fn test_identity_der(
    issuer: &rcgen::Issuer<'static, rcgen::KeyPair>,
    name: &str,
    extended_key_usage: rcgen::ExtendedKeyUsagePurpose,
) -> Result<TestIdentityDer> {
    let (cert, key_pair) = issue_test_leaf(issuer, name, extended_key_usage)?;
    Ok(TestIdentityDer {
        cert_der: cert.der().to_vec(),
        key_der: key_pair.serialize_der(),
    })
}
/// `Life` implementation that counts start and stop notifications.
///
/// The counters sit behind `Arc`, so clones of a `CountLife` share the same counts.
#[derive(Clone, Default)]
struct CountLife {
    /// Incremented by both `tcp_server_started` and `udp_socket_started`.
    started: Arc<AtomicUsize>,
    /// Incremented by both `tcp_server_stopped` and `udp_socket_stopped`.
    stopped: Arc<AtomicUsize>,
}

impl CountLife {
    /// Adds one to `counter` and reports success, which is all every hook below does.
    fn bump(counter: &AtomicUsize) -> Result<()> {
        counter.fetch_add(1, Ordering::SeqCst);
        Ok(())
    }
}

impl Life for CountLife {
    async fn tcp_server_started(&self, _local_addr: std::net::SocketAddr) -> Result<()> {
        Self::bump(&self.started)
    }

    async fn tcp_server_stopped(&self, _local_addr: std::net::SocketAddr) -> Result<()> {
        Self::bump(&self.stopped)
    }

    async fn udp_socket_started(&self, _local_addr: std::net::SocketAddr) -> Result<()> {
        Self::bump(&self.started)
    }

    async fn udp_socket_stopped(&self, _local_addr: std::net::SocketAddr) -> Result<()> {
        Self::bump(&self.stopped)
    }
}
/// `Life` implementation that records the close reason of every TCP connection.
///
/// The list sits behind `Arc<Mutex<_>>`, so clones of a `ReasonLife` share the same list.
#[derive(Clone, Default)]
struct ReasonLife {
    /// Close reasons in the order `tcp_connection_closed` received them.
    reasons: Arc<Mutex<Vec<CloseReason>>>,
}

impl ReasonLife {
    /// Returns `true` when `reason` has been recorded at least once.
    ///
    /// Panics if the mutex is poisoned.
    fn contains(&self, reason: CloseReason) -> bool {
        let recorded = self.reasons.lock().expect("reasons");
        recorded.iter().any(|seen| *seen == reason)
    }
}

impl Life for ReasonLife {
    /// Appends `reason` to the recorded list; the connection info is ignored.
    async fn tcp_connection_closed(&self, _info: ConnInfo, reason: CloseReason) -> Result<()> {
        {
            let mut recorded = self.reasons.lock().expect("reasons");
            recorded.push(reason);
        }
        Ok(())
    }
}
#[cfg(feature = "tls")]
/// `Life` implementation that remembers the `ConnInfo` of the most recently opened
/// TCP connection.
#[derive(Clone, Default)]
struct TlsInfoLife {
    /// `None` until a connection opens; later connections overwrite earlier ones.
    opened: Arc<Mutex<Option<ConnInfo>>>,
}

#[cfg(feature = "tls")]
impl Life for TlsInfoLife {
    /// Stores `info`, replacing whatever was stored before.
    async fn tcp_connection_opened(&self, info: ConnInfo) -> Result<()> {
        {
            let mut slot = self.opened.lock().expect("opened");
            *slot = Some(info);
        }
        Ok(())
    }
}
#[cfg(feature = "tls")]
/// Owned copy of the `TlsInfo` details the tests inspect, so it can be sent over a channel.
#[derive(Debug)]
struct TlsSnapshot {
/// Number of entries in `TlsInfo::peer_certificates()`.
peer_certificates: usize,
/// Copy of `TlsInfo::selected_alpn_protocol()`, if one was reported.
selected_alpn_protocol: Option<Vec<u8>>,
/// Copy of `TlsInfo::server_name()`, if one was reported.
server_name: Option<String>,
}
#[cfg(feature = "tls")]
/// A generated certificate together with its private key, both DER-encoded.
struct TestIdentityDer {
/// DER bytes of the certificate.
cert_der: Vec<u8>,
/// DER bytes of the private key belonging to `cert_der`.
key_der: Vec<u8>,
}
#[cfg(feature = "tls")]
/// Records, per pipeline stage, whether that stage's context exposed TLS metadata.
///
/// Each flag holds the result of `ctx.tls().is_some()` from the stage's most recent call;
/// all flags default to `false`.
#[derive(Default)]
struct TlsStageHits {
/// Set by `RecordTlsInbound`.
inbound: bool,
/// Set by `RecordTlsBusiness`.
business: bool,
/// Set by `RecordTlsHandler`.
handler: bool,
/// Set by `RecordTlsOutbound`.
outbound: bool,
}
#[cfg(feature = "tls")]
/// Copies the fields the tests care about out of `tls` into an owned [`TlsSnapshot`].
fn tls_snapshot(tls: &TlsInfo) -> TlsSnapshot {
    let peer_certificates = tls.peer_certificates().len();
    let selected_alpn_protocol = tls
        .selected_alpn_protocol()
        .map(|protocol| protocol.to_vec());
    let server_name = tls.server_name().map(|name| name.to_string());
    TlsSnapshot {
        peer_certificates,
        selected_alpn_protocol,
        server_name,
    }
}
#[cfg(feature = "tls")]
/// Asserts that `result` is an `Error::Tls` whose message contains `expected`.
///
/// Panics when `result` is `Ok`, when the error is a different variant, or when the TLS
/// message does not contain `expected`.
fn assert_tls_error_contains<T>(result: Result<T>, expected: &str) {
    match result {
        Ok(_) => panic!("expected TLS error containing `{expected}`"),
        Err(Error::Tls(message)) => assert!(
            message.contains(expected),
            "expected TLS error containing `{expected}`, got `{message}`"
        ),
        Err(other) => panic!("expected TLS error containing `{expected}`, got {other:?}"),
    }
}
/// TCP handler that sends every received message straight back to the peer.
struct Echo;
impl rs_netty::Handler<String> for Echo {
type Write = String;
/// Writes and flushes `msg` unchanged, returning the outcome of that call.
async fn read(&mut self, ctx: &mut Context<Self::Write>, msg: String) -> Result<()> {
ctx.write_and_flush(msg).await
}
}
#[cfg(feature = "tls")]
/// Inbound stage that records whether its context exposes TLS metadata.
struct RecordTlsInbound {
    /// Shared flags; this stage writes `inbound`.
    hits: Arc<Mutex<TlsStageHits>>,
}

#[cfg(feature = "tls")]
impl Inbound<String> for RecordTlsInbound {
    type Out = String;

    /// Stores `ctx.tls().is_some()` in `hits.inbound` and passes `msg` on unchanged.
    async fn read(&mut self, ctx: &mut InboundContext, msg: String) -> Result<Flow<Self::Out>> {
        let saw_tls = ctx.tls().is_some();
        self.hits.lock().expect("tls stage hits").inbound = saw_tls;
        Ok(Flow::Next(msg))
    }
}
#[cfg(feature = "tls")]
/// Business stage that records whether its context exposes TLS metadata.
struct RecordTlsBusiness {
    /// Shared flags; this stage writes `business`.
    hits: Arc<Mutex<TlsStageHits>>,
}

#[cfg(feature = "tls")]
impl Business<String> for RecordTlsBusiness {
    type Out = String;

    /// Stores `ctx.tls().is_some()` in `hits.business` and passes `msg` on unchanged.
    async fn handle(&mut self, ctx: &mut BusinessContext, msg: String) -> Result<Flow<Self::Out>> {
        let saw_tls = ctx.tls().is_some();
        self.hits.lock().expect("tls stage hits").business = saw_tls;
        Ok(Flow::Next(msg))
    }
}
#[cfg(feature = "tls")]
/// Terminal handler that records whether its context exposes TLS metadata, then echoes.
struct RecordTlsHandler {
    /// Shared flags; this handler writes `handler`.
    hits: Arc<Mutex<TlsStageHits>>,
}

#[cfg(feature = "tls")]
impl rs_netty::Handler<String> for RecordTlsHandler {
    type Write = String;

    /// Stores `ctx.tls().is_some()` in `hits.handler`, then writes and flushes `msg` back.
    async fn read(&mut self, ctx: &mut Context<Self::Write>, msg: String) -> Result<()> {
        let saw_tls = ctx.tls().is_some();
        // The guard is a temporary of this statement, so it is released before the await.
        self.hits.lock().expect("tls stage hits").handler = saw_tls;
        ctx.write_and_flush(msg).await
    }
}
#[cfg(feature = "tls")]
/// Outbound stage that records whether its context exposes TLS metadata.
struct RecordTlsOutbound {
    /// Shared flags; this stage writes `outbound`.
    hits: Arc<Mutex<TlsStageHits>>,
}

#[cfg(feature = "tls")]
impl Outbound<String> for RecordTlsOutbound {
    type Out = String;

    /// Stores `ctx.tls().is_some()` in `hits.outbound` and passes `msg` on unchanged.
    async fn write(&mut self, ctx: &mut OutboundContext, msg: String) -> Result<Flow<Self::Out>> {
        let saw_tls = ctx.tls().is_some();
        self.hits.lock().expect("tls stage hits").outbound = saw_tls;
        Ok(Flow::Next(msg))
    }
}
#[cfg(feature = "tls")]
/// Handler that reports the connection's TLS details on a channel and then echoes.
struct EchoTlsSnapshot {
    /// Receives one [`TlsSnapshot`] per message read.
    tx: mpsc::UnboundedSender<TlsSnapshot>,
}

#[cfg(feature = "tls")]
impl rs_netty::Handler<String> for EchoTlsSnapshot {
    type Write = String;

    /// Sends a snapshot of `ctx.tls()` on `tx`, then writes and flushes `msg` back.
    ///
    /// Fails with `Error::Pipeline` when the context has no TLS metadata. A closed
    /// receiver is ignored.
    async fn read(&mut self, ctx: &mut Context<Self::Write>, msg: String) -> Result<()> {
        let Some(tls) = ctx.tls() else {
            return Err(Error::Pipeline("missing TLS metadata".to_string()));
        };
        let snapshot = tls_snapshot(tls);
        let _ = self.tx.send(snapshot);
        ctx.write_and_flush(msg).await
    }
}
#[cfg(feature = "tls")]
/// Handler that answers every message with the TLS server name of the connection.
struct ReportServerName;

#[cfg(feature = "tls")]
impl rs_netty::Handler<String> for ReportServerName {
    type Write = String;

    /// Writes and flushes `TlsInfo::server_name()`, or an empty string when there is no
    /// TLS metadata or no server name. The incoming message is ignored.
    async fn read(&mut self, ctx: &mut Context<Self::Write>, _msg: String) -> Result<()> {
        let server_name = match ctx.tls() {
            Some(tls) => tls.server_name().unwrap_or("").to_string(),
            None => String::new(),
        };
        ctx.write_and_flush(server_name).await
    }
}
/// Echo handler that also stores the latest `ctx.stats()` value for the test to inspect.
struct StatsEcho {
    /// Overwritten with `ctx.stats()` on every message read.
    seen: Arc<Mutex<Option<ConnectionStats>>>,
}

impl rs_netty::Handler<String> for StatsEcho {
    type Write = String;

    /// Stores the current `ctx.stats()` in `seen`, then writes and flushes `msg` back.
    async fn read(&mut self, ctx: &mut Context<Self::Write>, msg: String) -> Result<()> {
        let latest = ctx.stats();
        // The guard is a temporary of this statement, so it is released before the await.
        *self.seen.lock().expect("stats") = latest;
        ctx.write_and_flush(msg).await
    }
}
/// TCP handler that answers each message with `first`, pauses 100 ms, then sends `second`.
struct FlushTwice;

impl rs_netty::Handler<String> for FlushTwice {
    type Write = String;

    /// Awaits both writes; a failure of the first one skips the pause and the second write.
    async fn read(&mut self, ctx: &mut Context<Self::Write>, _msg: String) -> Result<()> {
        let pause = Duration::from_millis(100);
        ctx.write_and_flush(String::from("first")).await?;
        tokio::time::sleep(pause).await;
        ctx.write_and_flush(String::from("second")).await
    }
}
/// TCP handler that issues two `write_and_flush` calls, 100 ms apart, without awaiting them.
struct FireAndForgetFlushTwice;
impl rs_netty::Handler<String> for FireAndForgetFlushTwice {
type Write = String;
/// Issues `first`, sleeps 100 ms, issues `second`, and always returns `Ok(())`.
///
/// The incoming message is ignored and neither write's result is observed.
async fn read(&mut self, ctx: &mut Context<Self::Write>, _msg: String) -> Result<()> {
// NOTE(review): the value returned by `write_and_flush` is dropped on purpose;
// presumably the write is queued eagerly even when not awaited — confirm in rs_netty.
ctx.write_and_flush("first".to_string());
tokio::time::sleep(Duration::from_millis(100)).await;
ctx.write_and_flush("second".to_string());
Ok(())
}
}
/// Datagram handler that sends every received message straight back.
struct UdpEcho;
impl DatagramHandler<String> for UdpEcho {
type Write = String;
/// Writes and flushes `msg` unchanged, returning the outcome of that call.
async fn read(&mut self, ctx: &mut DatagramContext<Self::Write>, msg: String) -> Result<()> {
ctx.write_and_flush(msg).await
}
}
/// Datagram handler that answers each message with `first`, pauses 100 ms, then sends
/// `second`.
struct UdpFlushTwice;

impl DatagramHandler<String> for UdpFlushTwice {
    type Write = String;

    /// Awaits both writes; a failure of the first one skips the pause and the second write.
    async fn read(&mut self, ctx: &mut DatagramContext<Self::Write>, _msg: String) -> Result<()> {
        let pause = Duration::from_millis(100);
        ctx.write_and_flush(String::from("first")).await?;
        tokio::time::sleep(pause).await;
        ctx.write_and_flush(String::from("second")).await
    }
}
/// Datagram handler that issues two `write_and_flush` calls, 100 ms apart, without
/// awaiting them.
struct UdpFireAndForgetFlushTwice;
impl DatagramHandler<String> for UdpFireAndForgetFlushTwice {
type Write = String;
/// Issues `first`, sleeps 100 ms, issues `second`, and always returns `Ok(())`.
///
/// The incoming message is ignored and neither write's result is observed.
async fn read(&mut self, ctx: &mut DatagramContext<Self::Write>, _msg: String) -> Result<()> {
// NOTE(review): the value returned by `write_and_flush` is dropped on purpose;
// presumably the write is queued eagerly even when not awaited — confirm in rs_netty.
ctx.write_and_flush("first".to_string());
tokio::time::sleep(Duration::from_millis(100)).await;
ctx.write_and_flush("second".to_string());
Ok(())
}
}
/// TCP handler that answers every message with a `write` that is never followed by a flush.
struct TcpWriteOnly;
impl rs_netty::Handler<String> for TcpWriteOnly {
type Write = String;
/// Writes `hidden` without flushing and returns the outcome of the write.
///
/// NOTE(review): the payload name suggests an unflushed write is not expected to reach
/// the peer — confirm against the tests that use this handler.
async fn read(&mut self, ctx: &mut Context<Self::Write>, _msg: String) -> Result<()> {
ctx.write("hidden".to_string()).await
}
}
/// TCP handler that answers every message with a separate `write` followed by a `flush`.
struct TcpWriteThenFlush;

impl rs_netty::Handler<String> for TcpWriteThenFlush {
    type Write = String;

    /// Writes `sent`, then flushes. A failed write is returned without flushing.
    async fn read(&mut self, ctx: &mut Context<Self::Write>, _msg: String) -> Result<()> {
        let payload = String::from("sent");
        ctx.write(payload).await?;
        ctx.flush().await
    }
}
/// TCP handler that forwards every received message to a channel and writes nothing back.
struct NotifyTcp {
/// Destination for the received messages.
tx: mpsc::UnboundedSender<String>,
}
impl rs_netty::Handler<String> for NotifyTcp {
type Write = String;
/// Sends `msg` on `tx` and always returns `Ok(())`.
async fn read(&mut self, _ctx: &mut Context<Self::Write>, msg: String) -> Result<()> {
// A send error only means the receiver is gone; it is deliberately ignored.
let _ = self.tx.send(msg);
Ok(())
}
}
#[cfg(feature = "tls")]
/// Handler that reports the connection's TLS details on a channel and writes nothing back.
struct NotifyTlsSnapshot {
    /// Receives one [`TlsSnapshot`] per message read.
    tx: mpsc::UnboundedSender<TlsSnapshot>,
}

#[cfg(feature = "tls")]
impl rs_netty::Handler<String> for NotifyTlsSnapshot {
    type Write = String;

    /// Sends a snapshot of `ctx.tls()` on `tx`; the incoming message is ignored.
    ///
    /// Fails with `Error::Pipeline` when the context has no TLS metadata. A closed
    /// receiver is ignored.
    async fn read(&mut self, ctx: &mut Context<Self::Write>, _msg: String) -> Result<()> {
        let Some(tls) = ctx.tls() else {
            return Err(Error::Pipeline("missing TLS metadata".to_string()));
        };
        let snapshot = tls_snapshot(tls);
        let _ = self.tx.send(snapshot);
        Ok(())
    }
}
/// Datagram handler that answers every message with a `write` that is never followed by
/// a flush.
struct UdpWriteOnly;
impl DatagramHandler<String> for UdpWriteOnly {
type Write = String;
/// Writes `hidden` without flushing and returns the outcome of the write.
///
/// NOTE(review): the payload name suggests an unflushed write is not expected to reach
/// the peer — confirm against the tests that use this handler.
async fn read(&mut self, ctx: &mut DatagramContext<Self::Write>, _msg: String) -> Result<()> {
ctx.write("hidden".to_string()).await
}
}
/// Datagram handler that answers every message with a separate `write` followed by a
/// `flush`.
struct UdpWriteThenFlush;

impl DatagramHandler<String> for UdpWriteThenFlush {
    type Write = String;

    /// Writes `sent`, then flushes. A failed write is returned without flushing.
    async fn read(&mut self, ctx: &mut DatagramContext<Self::Write>, _msg: String) -> Result<()> {
        let payload = String::from("sent");
        ctx.write(payload).await?;
        ctx.flush().await
    }
}
/// Datagram handler that forwards every received message to a channel and writes nothing
/// back.
struct NotifyUdp {
/// Destination for the received messages.
tx: mpsc::UnboundedSender<String>,
}
impl DatagramHandler<String> for NotifyUdp {
type Write = String;
/// Sends `msg` on `tx` and always returns `Ok(())`.
async fn read(&mut self, _ctx: &mut DatagramContext<Self::Write>, msg: String) -> Result<()> {
// A send error only means the receiver is gone; it is deliberately ignored.
let _ = self.tx.send(msg);
Ok(())
}
}