use datum::{Keep, Sink, Source, StreamCompletion, StreamError, testkit::TestSink};
use datum_net::tls::rustls::{
ClientConfig, RootCertStore, ServerConfig,
pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName},
};
use datum_net::{TlsIncomingConnection, TokioTls};
use rcgen::{CertifiedKey, generate_simple_self_signed};
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
mpsc,
};
use std::thread;
use std::time::{Duration, Instant};
/// Size argument passed to every `TokioTls::bind` and `TokioTls::outgoing_connection`
/// call in this file, and the length of each payload written and expected by the
/// backpressure test.
// NOTE(review): presumably the maximum number of bytes per emitted read chunk —
// confirm against the `datum_net` documentation.
const CHUNK_SIZE: usize = 8192;
/// Polls `condition` until it holds or `timeout` has elapsed.
///
/// The condition is checked, then the thread sleeps 5 ms, repeatedly, for as long
/// as the deadline has not been reached. Returns `true` as soon as a check
/// succeeds. Once the deadline has passed, the condition is evaluated one final
/// time and that result is returned, so a zero `timeout` still performs exactly
/// one check.
fn wait_until(timeout: Duration, condition: impl Fn() -> bool) -> bool {
    let poll_interval = Duration::from_millis(5);
    let deadline = Instant::now() + timeout;
    loop {
        if Instant::now() >= deadline {
            // Out of time: the outcome is whatever the condition says right now.
            return condition();
        }
        if condition() {
            return true;
        }
        thread::sleep(poll_interval);
    }
}
/// Asserts that `error` is one of the stream-error variants these tests accept as
/// a TLS failure: `Failed`, `AbruptTermination`, or `Cancelled`.
///
/// # Panics
/// Panics with the debug rendering of `error` when it is any other variant.
fn assert_failed(error: StreamError) {
    let is_failure_shaped = matches!(
        error,
        StreamError::Failed(_) | StreamError::AbruptTermination | StreamError::Cancelled
    );
    if !is_failure_shaped {
        panic!("expected TLS failure-shaped stream error, got {error:?}");
    }
}
fn tls_configs() -> (
Arc<ServerConfig>,
Arc<ClientConfig>,
Arc<ClientConfig>,
ServerName<'static>,
) {
let CertifiedKey { cert, key_pair } =
generate_simple_self_signed(["localhost".to_owned()]).expect("self-signed cert");
let cert_der: CertificateDer<'static> = cert.der().clone();
let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
let server_config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert_der.clone()], key_der)
.expect("server config");
let mut roots = RootCertStore::empty();
roots.add(cert_der).expect("trust self-signed cert");
let trusted_client_config = ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
let untrusted_client_config = ClientConfig::builder()
.with_root_certificates(RootCertStore::empty())
.with_no_client_auth();
(
Arc::new(server_config),
Arc::new(trusted_client_config),
Arc::new(untrusted_client_config),
ServerName::try_from("localhost")
.expect("server name")
.to_owned(),
)
}
/// Sends `ping` from a trusted client over TLS, has the server read it and write
/// it back, and checks that the client receives the same bytes.
#[test]
fn tls_bind_and_outgoing_connection_echo_round_trip() {
    let payload = b"ping".to_vec();
    let (server_config, client_config, _untrusted, server_name) = tls_configs();

    // Server side: bind on loopback port 0 and take the first incoming connection.
    let listener = TokioTls::bind("127.0.0.1:0", server_config, CHUNK_SIZE);
    let (binding_completion, incoming_completion) = listener
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("tls bind source materializes");
    let binding = binding_completion.wait().expect("tls binding succeeds");

    // Client side: push the payload through the TLS flow and keep the first reply.
    let client_flow = TokioTls::outgoing_connection(
        binding.local_addr(),
        server_name,
        client_config,
        CHUNK_SIZE,
    );
    let (connection_completion, client_completion) = Source::single(payload.clone())
        .via_mat(client_flow, Keep::right)
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("client TLS stream materializes");

    let incoming = incoming_completion
        .wait()
        .expect("incoming TLS connection accepted");
    let connection = connection_completion
        .wait()
        .expect("client TLS connection completes handshake");
    assert_eq!(connection.remote_addr(), binding.local_addr());

    // Server reads the request and echoes it back unchanged.
    let (server_inbound, server_outbound) = incoming.into_parts();
    let server_read = server_inbound
        .run_with(Sink::head())
        .expect("server read materializes")
        .wait()
        .expect("server reads request");
    assert_eq!(server_read, payload);
    Source::single(server_read)
        .run_with(server_outbound)
        .expect("server write materializes")
        .wait()
        .expect("server write completes");

    let echoed = client_completion.wait().expect("client receives echo");
    assert_eq!(echoed, payload);
}
/// Connects with a client whose root store is empty and checks that the client
/// connection, the client stream, and the server's incoming stream each report a
/// failure-shaped `StreamError`.
#[test]
fn tls_handshake_failure_surfaces_as_stream_error() {
    let (server_config, _trusted, untrusted_client_config, server_name) = tls_configs();

    let listener = TokioTls::bind("127.0.0.1:0", server_config, CHUNK_SIZE);
    let (binding_completion, incoming_completion) = listener
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("tls bind source materializes");
    let binding = binding_completion.wait().expect("tls binding succeeds");

    let untrusting_flow = TokioTls::outgoing_connection(
        binding.local_addr(),
        server_name,
        untrusted_client_config,
        CHUNK_SIZE,
    );
    let (connection_completion, client_completion) = Source::single(b"ping".to_vec())
        .via_mat(untrusting_flow, Keep::right)
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("client TLS stream materializes");

    // Client-side materialized connection value.
    let connection_error = connection_completion
        .wait()
        .expect_err("client connection materialization fails");
    assert_failed(connection_error);

    // Client-side data stream.
    let client_error = client_completion
        .wait()
        .expect_err("client stream fails without panic");
    assert_failed(client_error);

    // Server-side accept stream.
    let Err(server_error) = incoming_completion.wait() else {
        panic!("server incoming stream should see handshake failure");
    };
    assert_failed(server_error);
}
/// Drops the client's stream completion after the server has read the request,
/// then checks that the server's subsequent write finishes within three seconds,
/// either successfully or with a failure-shaped `StreamError`.
#[test]
fn dropping_tls_stream_cancels_pending_read_and_closes_peer_promptly() {
    let (server_config, client_config, _untrusted, server_name) = tls_configs();

    let listener = TokioTls::bind("127.0.0.1:0", server_config, CHUNK_SIZE);
    let (binding_completion, incoming_completion) = listener
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("tls bind source materializes");
    let binding = binding_completion.wait().expect("tls binding succeeds");

    let client_flow = TokioTls::outgoing_connection(
        binding.local_addr(),
        server_name,
        client_config,
        CHUNK_SIZE,
    );
    let client_completion = Source::single(b"cancel".to_vec())
        .via(client_flow)
        .run_with(Sink::head())
        .expect("client stream materializes");

    let incoming = incoming_completion
        .wait()
        .expect("incoming TLS connection accepted");
    let (server_inbound, server_outbound) = incoming.into_parts();
    let request = server_inbound
        .run_with(Sink::head())
        .expect("server read materializes")
        .wait()
        .expect("server reads request");
    assert_eq!(request, b"cancel".to_vec());

    // Give up on the client while it is still waiting for a reply.
    drop(client_completion);

    let server_write = Source::single(request)
        .run_with(server_outbound)
        .expect("server write materializes");

    // Wait on a helper thread so that a write which never finishes shows up as a
    // receive timeout instead of hanging the test.
    let (outcome_tx, outcome_rx) = mpsc::channel();
    thread::spawn(move || {
        let _ = outcome_tx.send(server_write.wait());
    });
    let write_outcome = outcome_rx
        .recv_timeout(Duration::from_secs(3))
        .expect("server write observes peer teardown promptly");

    // A completed write is acceptable; an error must be failure-shaped.
    if let Err(error) = write_outcome {
        assert_failed(error);
    }
}
/// Has the server write 8192 chunks of `CHUNK_SIZE` bytes while the client probe
/// requests only one element, and checks that the writer's generator does not run
/// to completion until the probe requests and drains the remaining chunks.
#[test]
fn slow_tls_consumer_backpressures_writer() {
    let (server_config, client_config, _untrusted, server_name) = tls_configs();

    let listener = TokioTls::bind("127.0.0.1:0", server_config, CHUNK_SIZE);
    let (binding_completion, incoming_completion) = listener
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("tls bind source materializes");
    let binding = binding_completion.wait().expect("tls binding succeeds");

    // The client sends nothing; it only reads, and only on explicit probe demand.
    let client_flow = TokioTls::outgoing_connection(
        binding.local_addr(),
        server_name,
        client_config,
        CHUNK_SIZE,
    );
    let (connection_completion, probe) = Source::<Vec<u8>>::empty()
        .via_mat(client_flow, Keep::right)
        .to_mat(TestSink::probe(), Keep::both)
        .run()
        .expect("client probe stream materializes");
    probe.request(1);

    let incoming: TlsIncomingConnection = incoming_completion
        .wait()
        .expect("incoming TLS connection accepted");
    connection_completion
        .wait()
        .expect("client TLS connection completes handshake");
    // The read half stays bound to a named local for the rest of the test.
    let (_server_inbound, server_outbound) = incoming.into_parts();

    // Count every chunk the generator hands to the TLS sink.
    let total_chunks = 8192usize;
    let produced = Arc::new(AtomicUsize::new(0));
    let writer_counter = Arc::clone(&produced);
    let chunk_source = Source::unfold(0usize, move |index| {
        (index != total_chunks).then(|| {
            writer_counter.fetch_add(1, Ordering::SeqCst);
            (index + 1, vec![b'x'; CHUNK_SIZE])
        })
    });
    let write_completion: StreamCompletion<datum::NotUsed> = chunk_source
        .run_with(server_outbound)
        .expect("server write materializes");

    let first_chunk = probe.expect_next();
    assert_eq!(first_chunk.len(), CHUNK_SIZE);

    let produced_so_far = || produced.load(Ordering::SeqCst);
    let writer_started = wait_until(Duration::from_secs(1), || produced_so_far() > 1);
    assert!(
        writer_started,
        "server writer should begin producing after first downstream demand"
    );
    let writer_ran_dry = wait_until(Duration::from_millis(250), || {
        produced_so_far() == total_chunks
    });
    assert!(
        !writer_ran_dry,
        "withheld downstream demand should prevent the TLS writer from consuming all chunks"
    );

    // Drain the remaining chunks one demand at a time.
    for _ in 1..total_chunks {
        probe.request(1);
        let chunk = probe.expect_next();
        assert_eq!(chunk.len(), CHUNK_SIZE);
    }
    probe.request(1);
    probe.expect_complete();

    write_completion
        .wait()
        .expect("server write completes after downstream drains");
}