use datum::{Keep, Sink, Source, StreamCompletion, StreamError, testkit::TestSink};
use datum_net::tls::rustls::{
ClientConfig, RootCertStore, ServerConfig,
pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName},
};
use datum_net::{TlsIncomingConnection, TokioTls};
use rcgen::{CertifiedKey, generate_simple_self_signed};
use std::net::SocketAddr;
use std::sync::{
Arc, Barrier,
atomic::{AtomicUsize, Ordering},
mpsc,
};
use std::thread;
use std::time::{Duration, Instant};
/// Chunk size passed as the final argument to every `TokioTls::bind` /
/// `TokioTls::outgoing_connection` call in this file. It is also the raw echo server's
/// read-buffer length and the byte length of each element in the backpressure test.
const CHUNK_SIZE: usize = 8192;
/// Polls `condition` every 5 ms until it returns `true` or `timeout` elapses.
///
/// Returns `true` as soon as the condition holds. Once the time budget is spent, the
/// condition is evaluated one final time and that result is returned.
fn wait_until(timeout: Duration, condition: impl Fn() -> bool) -> bool {
    let deadline = Instant::now() + timeout;
    loop {
        if Instant::now() >= deadline {
            // Out of time: give the condition one last chance.
            return condition();
        }
        if condition() {
            return true;
        }
        thread::sleep(Duration::from_millis(5));
    }
}
/// Asserts that `error` is one of the `StreamError` variants this file accepts as a
/// TLS failure: `Failed`, `AbruptTermination`, or `Cancelled`.
///
/// # Panics
/// Panics, printing the offending error, for any other variant.
fn assert_failed(error: StreamError) {
    let failure_shaped = matches!(
        &error,
        StreamError::Failed(_) | StreamError::AbruptTermination | StreamError::Cancelled
    );
    assert!(
        failure_shaped,
        "expected TLS failure-shaped stream error, got {error:?}"
    );
}
fn tls_configs() -> (
Arc<ServerConfig>,
Arc<ClientConfig>,
Arc<ClientConfig>,
ServerName<'static>,
) {
let CertifiedKey { cert, key_pair } =
generate_simple_self_signed(["localhost".to_owned()]).expect("self-signed cert");
let cert_der: CertificateDer<'static> = cert.der().clone();
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
let server_config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert_der.clone()], key_der)
.expect("server config");
let mut roots = RootCertStore::empty();
roots.add(cert_der).expect("trust self-signed cert");
let trusted_client_config = ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
let untrusted_client_config = ClientConfig::builder()
.with_root_certificates(RootCertStore::empty())
.with_no_client_auth();
(
Arc::new(server_config),
Arc::new(trusted_client_config),
Arc::new(untrusted_client_config),
ServerName::try_from("localhost")
.expect("server name")
.to_owned(),
)
}
/// Happy-path round trip over TLS.
///
/// The test binds a TLS listener on an ephemeral localhost port and connects a client
/// that trusts the server certificate and sends `b"ping"`. The server reads the request
/// and writes it back, and the test checks that the client receives the same bytes.
#[test]
fn tls_bind_and_outgoing_connection_echo_round_trip() {
let (server_config, client_config, _untrusted, server_name) = tls_configs();
// `Keep::both` yields the binding completion and the first incoming connection
// (taken by `Sink::head()`).
let (binding_completion, incoming_completion) =
TokioTls::bind("127.0.0.1:0", server_config, CHUNK_SIZE)
.to_mat(Sink::head(), Keep::both)
.run()
.expect("tls bind source materializes");
let binding = binding_completion.wait().expect("tls binding succeeds");
// `Keep::right` keeps the outgoing-connection flow's materialized value (the
// connection completion). The outer `Keep::both` pairs it with the completion of the
// client's first received element.
let (connection_completion, client_completion) = Source::single(b"ping".to_vec())
.via_mat(
TokioTls::outgoing_connection(
binding.local_addr(),
server_name,
client_config,
CHUNK_SIZE,
),
Keep::right,
)
.to_mat(Sink::head(), Keep::both)
.run()
.expect("client TLS stream materializes");
let incoming = incoming_completion
.wait()
.expect("incoming TLS connection accepted");
let connection = connection_completion
.wait()
.expect("client TLS connection completes handshake");
// The client's connection must report the listener's address as its peer.
assert_eq!(connection.remote_addr(), binding.local_addr());
// Split the server side into a source (bytes from the client) and a sink (bytes to
// the client).
let (incoming_source, incoming_sink) = incoming.into_parts();
let server_read = incoming_source
.run_with(Sink::head())
.expect("server read materializes")
.wait()
.expect("server reads request");
assert_eq!(server_read, b"ping".to_vec());
// Echo the request back to the client.
Source::single(server_read)
.run_with(incoming_sink)
.expect("server write materializes")
.wait()
.expect("server write completes");
assert_eq!(
client_completion.wait().expect("client receives echo"),
b"ping".to_vec()
);
}
/// A client whose root store does not trust the server certificate must see the
/// handshake failure as a failure-shaped `StreamError` on both its connection and
/// stream completions. The server's first incoming connection must fail the same way.
#[test]
fn tls_handshake_failure_surfaces_as_stream_error() {
    let (server_config, _trusted, untrusted_client_config, server_name) = tls_configs();
    let (bound, first_incoming) = TokioTls::bind("127.0.0.1:0", server_config, CHUNK_SIZE)
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("tls bind source materializes");
    let listener = bound.wait().expect("tls binding succeeds");

    let request = Source::single(b"ping".to_vec());
    let tls_client = TokioTls::outgoing_connection(
        listener.local_addr(),
        server_name,
        untrusted_client_config,
        CHUNK_SIZE,
    );
    let (handshake, response) = request
        .via_mat(tls_client, Keep::right)
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("client TLS stream materializes");

    let handshake_error = handshake
        .wait()
        .expect_err("client connection materialization fails");
    assert_failed(handshake_error);

    let response_error = response
        .wait()
        .expect_err("client stream fails without panic");
    assert_failed(response_error);

    // Only the error is needed, so `let … else` replaces a full match.
    let Err(server_error) = first_incoming.wait() else {
        panic!("server incoming stream should see handshake failure");
    };
    assert_failed(server_error);
}
/// Checks that dropping a client stream while its read is pending does not hang the peer.
///
/// After the server has read the client's request, the test drops the client's
/// completion handle. A subsequent server write must then finish within 3 seconds
/// instead of hanging. The write may succeed, or it may fail with a failure-shaped error.
#[test]
fn dropping_tls_stream_cancels_pending_read_and_closes_peer_promptly() {
let (server_config, client_config, _untrusted, server_name) = tls_configs();
let (binding_completion, incoming_completion) =
TokioTls::bind("127.0.0.1:0", server_config, CHUNK_SIZE)
.to_mat(Sink::head(), Keep::both)
.run()
.expect("tls bind source materializes");
let binding = binding_completion.wait().expect("tls binding succeeds");
// The client sends `b"cancel"` and then waits for its first response element.
// The server only writes after the handle is dropped below, so that read is still
// outstanding at drop time.
let client_completion = Source::single(b"cancel".to_vec())
.via(TokioTls::outgoing_connection(
binding.local_addr(),
server_name,
client_config,
CHUNK_SIZE,
))
.run_with(Sink::head())
.expect("client stream materializes");
let incoming = incoming_completion
.wait()
.expect("incoming TLS connection accepted");
let (incoming_source, incoming_sink) = incoming.into_parts();
let request = incoming_source
.run_with(Sink::head())
.expect("server read materializes")
.wait()
.expect("server reads request");
assert_eq!(request, b"cancel".to_vec());
// Drop the client's completion while its read is pending. Per the test name, this is
// expected to cancel the read and tear the connection down.
drop(client_completion);
// Write to a peer that may already be closing.
let server_write = Source::single(request)
.run_with(incoming_sink)
.expect("server write materializes");
// Call `wait()` on a helper thread so the test can bound it with `recv_timeout`
// instead of risking an unbounded block.
let (done_sender, done_receiver) = mpsc::channel();
thread::spawn(move || {
let _ = done_sender.send(server_write.wait());
});
let result = done_receiver
.recv_timeout(Duration::from_secs(3))
.expect("server write observes peer teardown promptly");
// Either outcome is accepted. The write may finish before the teardown is noticed,
// but if it fails, the error must be failure-shaped.
match result {
Ok(_) => {}
Err(error) => assert_failed(error),
}
}
/// Verifies that downstream demand on the client side limits how much the server's
/// TLS writer pulls from its source.
///
/// The client consumes through a test probe and grants demand one element at a time.
/// The test checks that the server writer starts producing, but does not drain all
/// `total_chunks` while demand is withheld.
#[test]
fn slow_tls_consumer_backpressures_writer() {
let (server_config, client_config, _untrusted, server_name) = tls_configs();
let (binding_completion, incoming_completion) =
TokioTls::bind("127.0.0.1:0", server_config, CHUNK_SIZE)
.to_mat(Sink::head(), Keep::both)
.run()
.expect("tls bind source materializes");
let binding = binding_completion.wait().expect("tls binding succeeds");
// The client sends nothing (empty source). It reads through a `TestSink` probe so the
// test controls downstream demand explicitly.
let (connection_completion, probe) = Source::<Vec<u8>>::empty()
.via_mat(
TokioTls::outgoing_connection(
binding.local_addr(),
server_name,
client_config,
CHUNK_SIZE,
),
Keep::right,
)
.to_mat(TestSink::probe(), Keep::both)
.run()
.expect("client probe stream materializes");
// Grant demand for exactly one element.
probe.request(1);
let incoming: TlsIncomingConnection = incoming_completion
.wait()
.expect("incoming TLS connection accepted");
connection_completion
.wait()
.expect("client TLS connection completes handshake");
// The server only writes in this test, so its read half is unused.
let (_incoming_source, incoming_sink) = incoming.into_parts();
// Counts how many chunks the server's writer has pulled from its source.
let produced = Arc::new(AtomicUsize::new(0));
// 8192 chunks of CHUNK_SIZE (8192) bytes = 64 MiB in total. This is presumably far
// more than socket and TLS buffering can absorb without the client reading.
// TODO confirm.
let total_chunks = 8192usize;
let write_completion: StreamCompletion<datum::NotUsed> = Source::unfold(0usize, {
let produced = Arc::clone(&produced);
move |index| {
if index == total_chunks {
None
} else {
produced.fetch_add(1, Ordering::SeqCst);
Some((index + 1, vec![b'x'; CHUNK_SIZE]))
}
}
})
.run_with(incoming_sink)
.expect("server write materializes");
// NOTE(review): the length checks below assume every element the client emits is
// exactly CHUNK_SIZE bytes, i.e. client-side TLS reads are neither split nor
// coalesced. Verify that this holds for the transport.
let first = probe.expect_next();
assert_eq!(first.len(), CHUNK_SIZE);
// The writer should keep pulling beyond the first chunk (buffering is allowed)...
assert!(
wait_until(Duration::from_secs(1), || produced.load(Ordering::SeqCst)
> 1),
"server writer should begin producing after first downstream demand"
);
// ...but with no further demand, it must stall before exhausting its source.
assert!(
!wait_until(Duration::from_millis(250), || {
produced.load(Ordering::SeqCst) == total_chunks
}),
"withheld downstream demand should prevent the TLS writer from consuming all chunks"
);
// Drain the remaining `total_chunks - 1` elements one demand at a time.
for _ in 1..total_chunks {
probe.request(1);
assert_eq!(probe.expect_next().len(), CHUNK_SIZE);
}
// One more demand lets the probe observe stream completion.
probe.request(1);
probe.expect_complete();
write_completion
.wait()
.expect("server write completes after downstream drains");
}
/// Starts a plain `tokio-rustls` echo server on a dedicated thread.
///
/// The server does not go through `datum_net`. It runs a current-thread Tokio runtime
/// and listens on an ephemeral `127.0.0.1` port. For each accepted TCP connection it
/// performs the TLS handshake with `server_config`, then echoes every read (up to
/// `CHUNK_SIZE` bytes), flushing after each write. If the handshake fails, that
/// connection is dropped.
///
/// The accept loop only stops when `accept` returns an error, so joining the returned
/// handle normally blocks.
///
/// Returns the listening address and the server thread's handle.
///
/// # Panics
/// Panics if the server does not report its address within 5 seconds.
fn start_raw_tls_echo_server(
    server_config: Arc<ServerConfig>,
) -> (SocketAddr, thread::JoinHandle<()>) {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;
    use tokio_rustls::TlsAcceptor;

    let (addr_tx, addr_rx) = mpsc::channel();
    let server_thread = thread::spawn(move || {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("raw TLS echo server runtime");
        runtime.block_on(async move {
            let listener = TcpListener::bind("127.0.0.1:0")
                .await
                .expect("raw TLS echo server bind");
            let local = listener.local_addr().expect("raw TLS echo server addr");
            addr_tx.send(local).expect("send addr");
            let acceptor = TlsAcceptor::from(server_config);
            while let Ok((tcp, _peer)) = listener.accept().await {
                let acceptor = acceptor.clone();
                tokio::spawn(async move {
                    // A failed handshake simply abandons this connection.
                    let Ok(mut tls) = acceptor.accept(tcp).await else {
                        return;
                    };
                    let mut scratch = vec![0u8; CHUNK_SIZE];
                    // Echo until EOF (`Ok(0)`), a read error, or a write error.
                    while let Ok(read) = AsyncReadExt::read(&mut tls, &mut scratch).await {
                        if read == 0 {
                            break;
                        }
                        let echoed =
                            AsyncWriteExt::write_all(&mut tls, &scratch[..read]).await;
                        if echoed.is_err() {
                            break;
                        }
                        let _ = AsyncWriteExt::flush(&mut tls).await;
                    }
                    let _ = AsyncWriteExt::shutdown(&mut tls).await;
                });
            }
        });
    });
    let addr = addr_rx
        .recv_timeout(Duration::from_secs(5))
        .expect("raw TLS echo server starts");
    (addr, server_thread)
}
/// Four clients connect concurrently to a raw TLS echo server while sharding is forced
/// on (4 shards, `min_connections = 1`). A barrier releases all four together.
///
/// Each client must receive exactly the payload it sent. The sharded-carrier connection
/// counter must also grow by at least four.
#[test]
fn sharded_tls_success_data_integrity_across_multiple_shards() {
    let (server_config, client_config, _untrusted, server_name) = tls_configs();
    let (echo_addr, _server_handle) = start_raw_tls_echo_server(server_config);
    datum_net::with_sharded_tokio_test_config(
        datum_net::ShardedTokioTestConfig {
            shard_count: Some(4),
            min_connections: Some(1),
        },
        || {
            let before = datum_net::sharded_tokio_carrier_connection_count();
            // Four client threads plus this one, so all connections start together.
            let barrier = Arc::new(Barrier::new(5));
            let handles: Vec<thread::JoinHandle<Vec<u8>>> = (0..4)
                .map(|i| {
                    let barrier = Arc::clone(&barrier);
                    let payload = format!("shard-{i}").into_bytes();
                    let client_config = Arc::clone(&client_config);
                    let server_name = server_name.clone();
                    thread::spawn(move || {
                        barrier.wait();
                        // `payload` is not used afterwards, so it is moved, not cloned.
                        Source::single(payload)
                            .via(TokioTls::outgoing_connection(
                                echo_addr,
                                server_name,
                                client_config,
                                CHUNK_SIZE,
                            ))
                            .run_with(Sink::head())
                            .expect("client stream materializes")
                            .wait()
                            .expect("client receives echo")
                    })
                })
                .collect();
            barrier.wait();
            for (i, handle) in handles.into_iter().enumerate() {
                let data = handle.join().expect("client thread panicked");
                assert_eq!(
                    data,
                    format!("shard-{i}").into_bytes(),
                    "sharded connection {i} received wrong echo"
                );
            }
            let after = datum_net::sharded_tokio_carrier_connection_count();
            // A single check covers both "the counter grew" and "by at least 4". Comparing
            // against `before + 4` avoids an underflowing `after - before`.
            assert!(
                after >= before + 4,
                "expected at least 4 sharded connections (before={before}, after={after})"
            );
        },
    );
}
/// Under forced sharding (2 shards, `min_connections = 1`), a client that does not trust
/// the echo server's certificate must report a failure-shaped `StreamError` on both its
/// connection completion and its stream completion.
#[test]
fn sharded_tls_handshake_failure_surfaces_error() {
    let (server_config, _trusted, untrusted_client_config, server_name) = tls_configs();
    let (echo_addr, _server_handle) = start_raw_tls_echo_server(server_config);
    let sharding = datum_net::ShardedTokioTestConfig {
        shard_count: Some(2),
        min_connections: Some(1),
    };
    datum_net::with_sharded_tokio_test_config(sharding, move || {
        let request = Source::single(b"ping".to_vec());
        let tls_client = TokioTls::outgoing_connection(
            echo_addr,
            server_name,
            untrusted_client_config,
            CHUNK_SIZE,
        );
        let (handshake, response) = request
            .via_mat(tls_client, Keep::right)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .expect("client TLS stream materializes");

        let handshake_error = handshake
            .wait()
            .expect_err("client connection materialization fails under sharding");
        assert_failed(handshake_error);

        let response_error = response
            .wait()
            .expect_err("client stream fails under sharding");
        assert_failed(response_error);
    });
}
/// Under forced sharding (2 shards, `min_connections = 1`), the test starts a TLS client
/// stream against a raw echo server and immediately drops its completion handle. It then
/// checks that a second connection with the same configuration still gets back exactly
/// the payload it sent.
#[test]
fn sharded_tls_cancellation_tears_down_cleanly() {
    let (server_config, client_config, _untrusted, server_name) = tls_configs();
    let client_config2 = Arc::clone(&client_config);
    let server_name2 = server_name.clone();
    let (echo_addr, _server_handle) = start_raw_tls_echo_server(server_config);
    datum_net::with_sharded_tokio_test_config(
        datum_net::ShardedTokioTestConfig {
            shard_count: Some(2),
            min_connections: Some(1),
        },
        move || {
            let completion = Source::single(b"cancel".to_vec())
                .via(TokioTls::outgoing_connection(
                    echo_addr,
                    server_name,
                    client_config,
                    CHUNK_SIZE,
                ))
                .run_with(Sink::head())
                .expect("client stream materializes");
            // Drop the first stream's completion handle without waiting on it.
            drop(completion);
            let echoed = Source::single(b"still-works".to_vec())
                .via(TokioTls::outgoing_connection(
                    echo_addr,
                    server_name2,
                    client_config2,
                    CHUNK_SIZE,
                ))
                .run_with(Sink::head())
                .expect("second client stream materializes")
                .wait()
                .expect("second client receives echo after cancellation");
            // Receiving some bytes is not enough: the echo must match what was sent.
            assert_eq!(
                echoed, b"still-works",
                "second client received wrong echo after cancellation"
            );
        },
    );
}
/// Checks the connection-count threshold that decides whether a connection is sharded.
///
/// - With `min_connections = 4`, a single connection is expected to stay off the sharded
///   carrier, so the counter is unchanged.
/// - With `min_connections = 1`, a single connection is expected to be sharded, so the
///   counter grows.
///
/// The first half tolerates counter movement caused by other activity: it retries up to
/// three times and fails only if every attempt saw the counter change.
#[test]
fn sharded_tls_fallback_boundary_below_threshold_not_sharded_above_threshold_is_sharded() {
    /// Sends `payload` over a fresh TLS connection to `addr` and returns the first
    /// element echoed back.
    fn echo_round_trip(
        addr: SocketAddr,
        name: ServerName<'static>,
        config: Arc<ClientConfig>,
        payload: &[u8],
    ) -> Vec<u8> {
        Source::single(payload.to_vec())
            .via(TokioTls::outgoing_connection(addr, name, config, CHUNK_SIZE))
            .run_with(Sink::head())
            .expect("client stream materializes")
            .wait()
            .expect("client receives echo")
    }

    let (server_config, client_config, _untrusted, server_name) = tls_configs();
    let below_name = server_name.clone();
    let below_config = Arc::clone(&client_config);
    let (below_addr, _below_server) = start_raw_tls_echo_server(Arc::clone(&server_config));
    let (above_addr, _above_server) = start_raw_tls_echo_server(Arc::clone(&server_config));

    datum_net::with_sharded_tokio_test_config(
        datum_net::ShardedTokioTestConfig {
            shard_count: Some(2),
            min_connections: Some(4),
        },
        move || {
            // `any` stops at the first attempt whose counter did not move.
            let stayed_unsharded = (0..3).any(|attempt| {
                let before = datum_net::sharded_tokio_carrier_connection_count();
                let data = echo_round_trip(
                    below_addr,
                    below_name.clone(),
                    Arc::clone(&below_config),
                    b"fallback",
                );
                assert_eq!(data, b"fallback");
                let after = datum_net::sharded_tokio_carrier_connection_count();
                let unchanged = after == before;
                if !unchanged {
                    eprintln!(
                        "sharded-counter interference on attempt {attempt} (before={before}, after={after}); retrying"
                    );
                }
                unchanged
            });
            assert!(
                stayed_unsharded,
                "below min_connections threshold sharded on every attempt — genuine routing regression"
            );
        },
    );

    datum_net::with_sharded_tokio_test_config(
        datum_net::ShardedTokioTestConfig {
            shard_count: Some(2),
            min_connections: Some(1),
        },
        move || {
            let before = datum_net::sharded_tokio_carrier_connection_count();
            let data = echo_round_trip(above_addr, server_name, client_config, b"sharded");
            assert_eq!(data, b"sharded");
            let after = datum_net::sharded_tokio_carrier_connection_count();
            assert!(
                after > before,
                "at min_connections threshold should shard (before={before}, after={after})"
            );
        },
    );
}