datum-net 0.9.0

Network sources and sinks for Datum streams, built on datum-core
Documentation
use datum::{Keep, Sink, Source, StreamError};
use datum_net::tls::rustls::{
    ClientConfig, RootCertStore, ServerConfig,
    pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName},
};
use datum_net::{ConnectionLifecycleExt, ConnectionSettings, RetryPolicy, TokioTls};
use rcgen::{CertifiedKey, generate_simple_self_signed};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::{Arc, mpsc};
use std::thread;
use std::time::{Duration, Instant};

/// Chunk size passed as the third argument to `TokioTls::bind` by the server-side tests.
// NOTE(review): presumably the maximum number of bytes per emitted read chunk — confirm
// against the `TokioTls::bind` documentation.
const CHUNK_SIZE: usize = 8192;

/// Builds a matching pair of rustls configurations around a freshly generated
/// self-signed certificate for `localhost`.
///
/// Returns, in order: a server config presenting that certificate (no client
/// auth), a client config whose only trust root is that same certificate, and
/// the `ServerName` the client should use when connecting.
fn tls_configs() -> (Arc<ServerConfig>, Arc<ClientConfig>, ServerName<'static>) {
    let certified =
        generate_simple_self_signed(["localhost".to_owned()]).expect("self-signed cert");
    let cert_der: CertificateDer<'static> = certified.cert.der().clone();
    let private_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(
        certified.key_pair.serialize_der(),
    ));

    let server = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der.clone()], private_key)
        .expect("server config");

    // The client trusts exactly the certificate the server presents.
    let client = {
        let mut trusted = RootCertStore::empty();
        trusted.add(cert_der).expect("trust self-signed cert");
        ClientConfig::builder()
            .with_root_certificates(trusted)
            .with_no_client_auth()
    };

    let name = ServerName::try_from("localhost")
        .expect("server name")
        .to_owned();

    (Arc::new(server), Arc::new(client), name)
}

/// Baseline connection settings for the lifecycle tests.
///
/// Connect and handshake timeouts are both 100 ms. The retry policy allows a
/// single attempt, with backoff bounds of 10–20 ms. Individual tests override
/// whichever knob they exercise.
fn lifecycle_settings() -> ConnectionSettings {
    let quick_timeout = Duration::from_millis(100);
    let single_attempt = RetryPolicy::default()
        .max_attempts(1)
        .initial_backoff(Duration::from_millis(10))
        .max_backoff(Duration::from_millis(20));

    ConnectionSettings::default()
        .connect_timeout(quick_timeout)
        .handshake_timeout(quick_timeout)
        .retry_policy(single_attempt)
}

/// Returns a loopback address whose port was just free.
///
/// An ephemeral port is reserved by binding a listener. The listener is
/// released before returning, so nothing is listening on the address
/// (barring another process grabbing the port in the meantime).
fn unused_tcp_addr() -> SocketAddr {
    let reservation = TcpListener::bind("127.0.0.1:0").expect("reserve TCP port");
    let addr = reservation.local_addr().expect("reserved TCP local addr");
    drop(reservation);
    addr
}

/// Panics unless `error` is one of the failure-shaped `StreamError` variants:
/// `Failed(_)`, `AbruptTermination`, or `Cancelled`.
fn assert_failed(error: &StreamError) {
    let is_failure_shaped = matches!(
        error,
        StreamError::Failed(_) | StreamError::AbruptTermination | StreamError::Cancelled
    );
    if !is_failure_shaped {
        panic!("expected failure-shaped stream error, got {error:?}");
    }
}

/// Asserts that `error` is failure-shaped (see `assert_failed`) and that its
/// `Display` rendering contains `needle` as a substring.
fn assert_error_contains(error: &StreamError, needle: &str) {
    assert_failed(error);
    let rendered = error.to_string();
    assert!(
        rendered.contains(needle),
        "expected {error:?} to contain {needle:?}"
    );
}

/// Accepts exactly one connection on `listener`, waiting at most `timeout`.
///
/// The listener is switched to non-blocking mode and polled every 5 ms, so the
/// wait is bounded by a deadline instead of blocking forever in `accept`. The
/// listener is consumed and closed when this function returns.
///
/// The returned stream is always in blocking mode, regardless of platform.
///
/// # Panics
///
/// Panics if:
/// - no connection arrives before the deadline;
/// - `accept` fails with anything other than `WouldBlock`;
/// - the blocking mode of the listener or the accepted stream cannot be changed.
fn accept_one(listener: TcpListener, timeout: Duration) -> TcpStream {
    listener
        .set_nonblocking(true)
        .expect("set listener nonblocking");
    let deadline = Instant::now() + timeout;
    loop {
        match listener.accept() {
            Ok((stream, _)) => {
                // Accepted sockets inherit the listener's O_NONBLOCK flag on
                // BSD-derived kernels and Windows, but not on Linux. Set the mode
                // explicitly so callers get the same kind of stream everywhere.
                stream
                    .set_nonblocking(false)
                    .expect("set accepted stream blocking");
                return stream;
            }
            Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {
                assert!(
                    Instant::now() < deadline,
                    "timed out waiting for TCP accept"
                );
                thread::sleep(Duration::from_millis(5));
            }
            Err(error) => panic!("TCP accept failed: {error}"),
        }
    }
}

/// A connect to an address that never answers must fail with the connect-timeout
/// error, and must do so well within a generous 2 s bound.
#[test]
fn connect_timeout_surfaces_stream_error_promptly() {
    // 10.255.255.1 is in the private 10/8 range and normally unroutable, so the SYN
    // should go unanswered and only the connect timeout can end the attempt.
    // NOTE(review): hosts that reject 10/8 immediately (e.g. "network unreachable")
    // would fail this test for environmental reasons.
    const BLACK_HOLE_ADDR: &str = "10.255.255.1:9";

    let (_server_config, client_config, server_name) = tls_configs();
    let settings = lifecycle_settings().connect_timeout(Duration::from_millis(50));
    let started = Instant::now();

    let (connection_completion, stream_completion) = Source::<Vec<u8>>::empty()
        .via_mat(
            TokioTls::outgoing_connection_with_lifecycle(
                BLACK_HOLE_ADDR,
                server_name,
                client_config,
                settings,
            ),
            Keep::right,
        )
        .to_mat(Sink::ignore(), Keep::both)
        .run()
        .expect("client stream materializes");

    let error = connection_completion
        .wait()
        .expect_err("black-hole connect times out");
    // Measure once, as soon as the error surfaces. Stream teardown and assertion
    // work are then excluded, and the message reports exactly the checked value.
    let elapsed = started.elapsed();
    drop(stream_completion);

    assert_error_contains(&error, "TCP connect timed out");
    assert!(
        elapsed < Duration::from_secs(2),
        "connect timeout should complete promptly, elapsed {elapsed:?}"
    );
}

/// A peer that accepts TCP but never speaks TLS must trip the handshake timeout.
#[test]
fn handshake_timeout_surfaces_stream_error() {
    // Silent raw-TCP peer. It holds the accepted connection open until released,
    // or for at most 3 s.
    let raw_listener = TcpListener::bind("127.0.0.1:0").expect("raw TCP listener");
    let addr = raw_listener.local_addr().expect("raw TCP addr");
    let (release_tx, release_rx) = mpsc::channel();
    let silent_peer = thread::spawn(move || {
        let held = accept_one(raw_listener, Duration::from_secs(3));
        let _ = release_rx.recv_timeout(Duration::from_secs(3));
        drop(held);
    });

    let (_server_config, client_config, server_name) = tls_configs();
    let settings = lifecycle_settings().handshake_timeout(Duration::from_millis(50));
    let flow =
        TokioTls::outgoing_connection_with_lifecycle(addr, server_name, client_config, settings);
    let (connection_completion, stream_completion) = Source::<Vec<u8>>::empty()
        .via_mat(flow, Keep::right)
        .to_mat(Sink::ignore(), Keep::both)
        .run()
        .expect("client stream materializes");

    let error = connection_completion
        .wait()
        .expect_err("silent peer handshake times out");
    drop(stream_completion);

    // Let the peer close its side, then make sure its thread finished cleanly.
    let _ = release_tx.send(());
    silent_peer.join().expect("raw server joins");

    assert_error_contains(&error, "TLS handshake timed out");
}

/// The first establishment attempt is dropped at the TCP level. A retry against
/// the same address, now served by a real TLS server, must then succeed and
/// round-trip a payload.
#[test]
fn retry_succeeds_after_initial_establishment_failure() {
    let (server_config, client_config, server_name) = tls_configs();
    // A plain TCP listener holds the port first. Its one accepted connection is
    // dropped without any TLS, which makes the client's first attempt fail.
    let placeholder = TcpListener::bind("127.0.0.1:0").expect("dummy TCP listener");
    let addr = placeholder.local_addr().expect("dummy TCP addr");
    let (done_tx, done_rx) = mpsc::channel();

    let server = thread::spawn(move || {
        drop(accept_one(placeholder, Duration::from_secs(3)));

        // Take over the same address with a real TLS server for the retries.
        let (binding, incoming) = TokioTls::bind(addr, server_config, CHUNK_SIZE)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .expect("real TLS server materializes");
        binding
            .wait()
            .expect("real TLS server binds after first failure");

        let (request_source, response_sink) = incoming
            .wait()
            .expect("retry accepts TLS connection")
            .into_parts();
        let request = request_source
            .run_with(Sink::head())
            .expect("server read materializes")
            .wait()
            .expect("server reads request");
        Source::single(request)
            .run_with(response_sink)
            .expect("server echo materializes")
            .wait()
            .expect("server echo completes");
        done_tx.send(()).expect("send server done");
    });

    let patient_retries = RetryPolicy::default()
        .max_attempts(8)
        .initial_backoff(Duration::from_millis(10))
        .max_backoff(Duration::from_millis(40));
    let settings = lifecycle_settings().retry_policy(patient_retries);
    let flow =
        TokioTls::outgoing_connection_with_lifecycle(addr, server_name, client_config, settings);
    let (connection_completion, response_completion) = Source::single(b"retry".to_vec())
        .via_mat(flow, Keep::right)
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("client stream materializes");

    let connection = connection_completion
        .wait()
        .expect("client connects after retry");
    assert_eq!(connection.remote_addr(), addr);
    let echoed = response_completion.wait().expect("client receives echo");
    assert_eq!(echoed, b"retry".to_vec());

    done_rx
        .recv_timeout(Duration::from_secs(3))
        .expect("server finishes");
    server.join().expect("server thread joins");
}

/// Connecting to a port with nothing listening must exhaust a three-attempt
/// retry policy and report the attempt count in the final error.
#[test]
fn retry_exhausted_surfaces_final_stream_error() {
    let (_server_config, client_config, server_name) = tls_configs();
    let three_attempts = RetryPolicy::default()
        .max_attempts(3)
        .initial_backoff(Duration::from_millis(10))
        .max_backoff(Duration::from_millis(20));
    let settings = lifecycle_settings().retry_policy(three_attempts);
    let closed_addr = unused_tcp_addr();
    let flow = TokioTls::outgoing_connection_with_lifecycle(
        closed_addr,
        server_name,
        client_config,
        settings,
    );

    let (connection_completion, stream_completion) = Source::<Vec<u8>>::empty()
        .via_mat(flow, Keep::right)
        .to_mat(Sink::ignore(), Keep::both)
        .run()
        .expect("client stream materializes");

    let error = connection_completion
        .wait()
        .expect_err("retry policy exhausts");
    drop(stream_completion);

    assert_error_contains(&error, "connection establishment failed after 3 attempts");
}

/// With `half_close_on_upstream_finish`, the server must see the end of the
/// client's request, and the client must still receive the reply sent afterwards.
#[test]
fn half_close_keeps_read_side_open_for_response() {
    let (server_config, client_config, server_name) = tls_configs();
    let (binding_completion, incoming_completion) =
        TokioTls::bind("127.0.0.1:0", server_config, CHUNK_SIZE)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .expect("TLS server materializes");
    // Kept alive for the whole test so the server stays bound.
    let binding = binding_completion.wait().expect("TLS server binds");

    let client_flow = TokioTls::outgoing_connection_with_lifecycle(
        binding.local_addr(),
        server_name,
        client_config,
        lifecycle_settings(),
    )
    .half_close_on_upstream_finish();
    let (connection_completion, response_completion) = Source::single(b"half-close".to_vec())
        .via_mat(client_flow, Keep::right)
        .to_mat(Sink::collect(), Keep::both)
        .run()
        .expect("client stream materializes");

    connection_completion
        .wait()
        .expect("client TLS connection completes");

    // Server side: read the request until the client's half-close ends it.
    let (request_source, response_sink) = incoming_completion
        .wait()
        .expect("server accepts TLS connection")
        .into_parts();
    let request = request_source
        .run_with(Sink::collect())
        .expect("server collect materializes")
        .wait()
        .expect("server sees client half-close")
        .concat();
    assert_eq!(request, b"half-close".to_vec());

    // Reply only after the request stream has finished.
    Source::single(b"response-after-half-close".to_vec())
        .run_with(response_sink)
        .expect("server response materializes")
        .wait()
        .expect("server response completes");

    let response = response_completion
        .wait()
        .expect("client reads after half-close")
        .concat();
    assert_eq!(response, b"response-after-half-close".to_vec());
}