use datum::{Keep, Sink, Source, StreamError};
use datum_net::tls::rustls::{
ClientConfig, RootCertStore, ServerConfig,
pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName},
};
use datum_net::{ConnectionLifecycleExt, ConnectionSettings, RetryPolicy, TokioTls};
use rcgen::{CertifiedKey, generate_simple_self_signed};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::{Arc, mpsc};
use std::thread;
use std::time::{Duration, Instant};
const CHUNK_SIZE: usize = 8192;
/// Builds a matching TLS server/client configuration pair around a freshly
/// generated self-signed certificate for `localhost`.
///
/// The client's root store contains only that certificate, so a client using
/// the returned `ClientConfig` and `ServerName` trusts a server built from the
/// returned `ServerConfig`.
fn tls_configs() -> (Arc<ServerConfig>, Arc<ClientConfig>, ServerName<'static>) {
    let CertifiedKey { cert, key_pair } =
        generate_simple_self_signed(["localhost".to_owned()]).expect("self-signed cert");
    let certificate: CertificateDer<'static> = cert.der().clone();
    let private_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));

    let server = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![certificate.clone()], private_key)
        .expect("server config");

    // Trust exactly the certificate the server presents.
    let trust_anchors = {
        let mut store = RootCertStore::empty();
        store.add(certificate).expect("trust self-signed cert");
        store
    };
    let client = ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();

    let name = ServerName::try_from("localhost")
        .expect("server name")
        .to_owned();

    (Arc::new(server), Arc::new(client), name)
}
/// Baseline connection settings for the lifecycle tests.
///
/// Connect and handshake timeouts are both 100 ms. The retry policy allows a
/// single attempt with a 10–20 ms backoff window. Individual tests override
/// fields on the returned value as needed.
fn lifecycle_settings() -> ConnectionSettings {
    let single_attempt = RetryPolicy::default()
        .max_attempts(1)
        .initial_backoff(Duration::from_millis(10))
        .max_backoff(Duration::from_millis(20));
    let short_timeout = Duration::from_millis(100);

    ConnectionSettings::default()
        .connect_timeout(short_timeout)
        .handshake_timeout(short_timeout)
        .retry_policy(single_attempt)
}
/// Returns a loopback address whose port the OS reported as free a moment ago.
///
/// A listener bound to port 0 reserves the port and is dropped before
/// returning, so the port is released again. Another socket could, in
/// principle, claim it before the caller uses it.
fn unused_tcp_addr() -> SocketAddr {
    let reservation = TcpListener::bind("127.0.0.1:0").expect("reserve TCP port");
    let addr = reservation.local_addr().expect("reserved TCP local addr");
    // Release the port before handing the address out.
    drop(reservation);
    addr
}
/// Asserts that `error` is one of the failure-shaped variants:
/// `StreamError::Failed(_)`, `StreamError::AbruptTermination`, or
/// `StreamError::Cancelled`.
fn assert_failed(error: &StreamError) {
    let failure_shaped = matches!(
        error,
        StreamError::Failed(_) | StreamError::AbruptTermination | StreamError::Cancelled
    );
    assert!(
        failure_shaped,
        "expected failure-shaped stream error, got {error:?}"
    );
}
/// Asserts that `error` is failure-shaped (via `assert_failed`) and that its
/// `Display` text contains `needle`.
///
/// The containment check runs against `Display` output. On mismatch the panic
/// message therefore shows that exact text together with the `Debug` form,
/// rather than only the `Debug` form, which may be formatted differently.
fn assert_error_contains(error: &StreamError, needle: &str) {
    assert_failed(error);
    let message = error.to_string();
    assert!(
        message.contains(needle),
        "expected error message {message:?} to contain {needle:?} (error: {error:?})"
    );
}
/// Accepts a single TCP connection from `listener`, waiting at most `timeout`.
///
/// The listener is polled in non-blocking mode, sleeping 5 ms between polls, so
/// the deadline can be enforced using only `std`. The listener is consumed and
/// closed when this function returns. The returned stream is always in
/// blocking mode.
///
/// # Panics
///
/// Panics if no connection arrives before the deadline, if `accept` fails with
/// anything other than `WouldBlock`, or if a socket's blocking mode cannot be
/// changed.
fn accept_one(listener: TcpListener, timeout: Duration) -> TcpStream {
    listener
        .set_nonblocking(true)
        .expect("set listener nonblocking");
    let deadline = Instant::now() + timeout;
    loop {
        match listener.accept() {
            Ok((stream, _)) => {
                // Linux does not propagate O_NONBLOCK to accepted sockets, but
                // BSD-derived stacks (e.g. macOS) do. Reset it so callers get the
                // same blocking stream on every platform.
                stream
                    .set_nonblocking(false)
                    .expect("set accepted stream blocking");
                return stream;
            }
            Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {
                assert!(
                    Instant::now() < deadline,
                    "timed out waiting for TCP accept"
                );
                thread::sleep(Duration::from_millis(5));
            }
            Err(error) => panic!("TCP accept failed: {error}"),
        }
    }
}
/// Connecting to a black-hole address must fail with a `TCP connect timed out`
/// stream error, well within a generous wall-clock bound.
///
/// NOTE(review): this relies on 10.255.255.1 silently dropping SYNs. On hosts
/// that report it unreachable immediately, the connect presumably fails with a
/// different error and the message assertion will not match.
#[test]
fn connect_timeout_surfaces_stream_error_promptly() {
    let (_server_config, client_config, server_name) = tls_configs();
    let settings = lifecycle_settings().connect_timeout(Duration::from_millis(50));
    let started = Instant::now();
    let (connection_completion, stream_completion) = Source::<Vec<u8>>::empty()
        .via_mat(
            TokioTls::outgoing_connection_with_lifecycle(
                "10.255.255.1:9",
                server_name,
                client_config,
                settings,
            ),
            Keep::right,
        )
        .to_mat(Sink::ignore(), Keep::both)
        .run()
        .expect("client stream materializes");
    let error = connection_completion
        .wait()
        .expect_err("black-hole connect times out");
    // Measure once, as soon as the failure surfaces, so the checked duration and
    // the one reported on failure are the same value.
    let elapsed = started.elapsed();
    drop(stream_completion);
    assert_error_contains(&error, "TCP connect timed out");
    assert!(
        elapsed < Duration::from_secs(2),
        "connect timeout should complete promptly, elapsed {elapsed:?}"
    );
}
/// A peer that accepts TCP but never speaks TLS must make the client fail with
/// a `TLS handshake timed out` stream error.
#[test]
fn handshake_timeout_surfaces_stream_error() {
    let raw_listener = TcpListener::bind("127.0.0.1:0").expect("raw TCP listener");
    let addr = raw_listener.local_addr().expect("raw TCP addr");

    // The silent peer keeps the accepted socket open until it is released (or
    // until 3 s pass), without ever writing to it.
    let (release_tx, release_rx) = mpsc::channel();
    let silent_peer = thread::spawn(move || {
        let held = accept_one(raw_listener, Duration::from_secs(3));
        let _ = release_rx.recv_timeout(Duration::from_secs(3));
        drop(held);
    });

    let (_server_config, client_config, server_name) = tls_configs();
    let settings = lifecycle_settings().handshake_timeout(Duration::from_millis(50));
    let outgoing =
        TokioTls::outgoing_connection_with_lifecycle(addr, server_name, client_config, settings);
    let (connection_completion, stream_completion) = Source::<Vec<u8>>::empty()
        .via_mat(outgoing, Keep::right)
        .to_mat(Sink::ignore(), Keep::both)
        .run()
        .expect("client stream materializes");

    let error = connection_completion
        .wait()
        .expect_err("silent peer handshake times out");
    drop(stream_completion);

    let _ = release_tx.send(());
    silent_peer.join().expect("raw server joins");
    assert_error_contains(&error, "TLS handshake timed out");
}
/// The first TCP connection is accepted by a placeholder listener and closed
/// immediately, so the first establishment attempt fails. A real TLS server
/// then binds the same address, and the client's retry policy must reach it and
/// receive an echo of its request.
#[test]
fn retry_succeeds_after_initial_establishment_failure() {
    let (server_config, client_config, server_name) = tls_configs();
    let placeholder = TcpListener::bind("127.0.0.1:0").expect("dummy TCP listener");
    let addr = placeholder.local_addr().expect("dummy TCP addr");
    let (done_tx, done_rx) = mpsc::channel();

    let server = thread::spawn(move || {
        // First attempt: accept, then hang up before any TLS bytes are exchanged.
        // `accept_one` consumes (and closes) the placeholder listener.
        drop(accept_one(placeholder, Duration::from_secs(3)));

        // Subsequent attempts: a real TLS server on the same address echoes one message.
        let (bound, first_incoming) = TokioTls::bind(addr, server_config, CHUNK_SIZE)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .expect("real TLS server materializes");
        bound
            .wait()
            .expect("real TLS server binds after first failure");
        let (inbound, outbound) = first_incoming
            .wait()
            .expect("retry accepts TLS connection")
            .into_parts();
        let echoed = inbound
            .run_with(Sink::head())
            .expect("server read materializes")
            .wait()
            .expect("server reads request");
        Source::single(echoed)
            .run_with(outbound)
            .expect("server echo materializes")
            .wait()
            .expect("server echo completes");
        done_tx.send(()).expect("send server done");
    });

    let generous_retries = RetryPolicy::default()
        .max_attempts(8)
        .initial_backoff(Duration::from_millis(10))
        .max_backoff(Duration::from_millis(40));
    let settings = lifecycle_settings().retry_policy(generous_retries);
    let outgoing =
        TokioTls::outgoing_connection_with_lifecycle(addr, server_name, client_config, settings);
    let (connection_completion, response_completion) = Source::single(b"retry".to_vec())
        .via_mat(outgoing, Keep::right)
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("client stream materializes");

    let connection = connection_completion
        .wait()
        .expect("client connects after retry");
    assert_eq!(connection.remote_addr(), addr);
    let response = response_completion.wait().expect("client receives echo");
    assert_eq!(response, b"retry".to_vec());

    done_rx
        .recv_timeout(Duration::from_secs(3))
        .expect("server finishes");
    server.join().expect("server thread joins");
}
/// Every attempt targets a port that was just released (presumably with nothing
/// listening), so after three attempts the connection must fail with a
/// `connection establishment failed after 3 attempts` error.
#[test]
fn retry_exhausted_surfaces_final_stream_error() {
    let (_server_config, client_config, server_name) = tls_configs();
    let three_attempts = RetryPolicy::default()
        .max_attempts(3)
        .initial_backoff(Duration::from_millis(10))
        .max_backoff(Duration::from_millis(20));
    let settings = lifecycle_settings().retry_policy(three_attempts);
    let target = unused_tcp_addr();
    let outgoing =
        TokioTls::outgoing_connection_with_lifecycle(target, server_name, client_config, settings);

    let (connection_completion, stream_completion) = Source::<Vec<u8>>::empty()
        .via_mat(outgoing, Keep::right)
        .to_mat(Sink::ignore(), Keep::both)
        .run()
        .expect("client stream materializes");

    let error = connection_completion
        .wait()
        .expect_err("retry policy exhausts");
    drop(stream_completion);
    assert_error_contains(&error, "connection establishment failed after 3 attempts");
}
/// With `half_close_on_upstream_finish`, finishing the client's upstream must
/// close only its write side. The server sees its inbound stream complete, yet
/// can still send a response that the client reads.
#[test]
fn half_close_keeps_read_side_open_for_response() {
    let (server_config, client_config, server_name) = tls_configs();
    let (binding_completion, incoming_completion) =
        TokioTls::bind("127.0.0.1:0", server_config, CHUNK_SIZE)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .expect("TLS server materializes");
    // Keep the binding value alive for the whole test.
    let binding = binding_completion.wait().expect("TLS server binds");

    let client_flow = TokioTls::outgoing_connection_with_lifecycle(
        binding.local_addr(),
        server_name,
        client_config,
        lifecycle_settings(),
    )
    .half_close_on_upstream_finish();
    let (connection_completion, response_completion) = Source::single(b"half-close".to_vec())
        .via_mat(client_flow, Keep::right)
        .to_mat(Sink::collect(), Keep::both)
        .run()
        .expect("client stream materializes");
    connection_completion
        .wait()
        .expect("client TLS connection completes");

    let (server_source, server_sink) = incoming_completion
        .wait()
        .expect("server accepts TLS connection")
        .into_parts();
    // Collecting runs to completion only once the client's write side is closed.
    let request = server_source
        .run_with(Sink::collect())
        .expect("server collect materializes")
        .wait()
        .expect("server sees client half-close")
        .concat();
    assert_eq!(request, b"half-close".to_vec());

    Source::single(b"response-after-half-close".to_vec())
        .run_with(server_sink)
        .expect("server response materializes")
        .wait()
        .expect("server response completes");

    let response = response_completion
        .wait()
        .expect("client reads after half-close")
        .concat();
    assert_eq!(response, b"response-after-half-close".to_vec());
}