// datum-net 0.6.0
//
// Network sources and sinks for Datum streams, built on datum-core.
// Documentation
use datum::{Keep, Sink, Source, StreamCompletion, StreamError, testkit::TestSink};
use datum_net::tls::rustls::{
    ClientConfig, RootCertStore, ServerConfig,
    pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName},
};
use datum_net::{TlsIncomingConnection, TokioTls};
use rcgen::{CertifiedKey, generate_simple_self_signed};
use std::sync::{
    Arc,
    atomic::{AtomicUsize, Ordering},
    mpsc,
};
use std::thread;
use std::time::{Duration, Instant};

/// Chunk size handed to every `TokioTls::bind` and `TokioTls::outgoing_connection` call in
/// this file. The backpressure test also uses it as the byte length of each payload it writes
/// and of each element it expects to read.
const CHUNK_SIZE: usize = 8192;

/// Polls `condition` about every 5 ms until it returns `true` or `timeout` has elapsed.
///
/// Returns `true` as soon as the condition holds. Once the timeout has expired the condition
/// is evaluated one final time and that result is returned, so a zero timeout still performs
/// exactly one check.
///
/// Elapsed time is compared against `timeout` instead of computing an absolute deadline, so a
/// very large timeout (for example `Duration::MAX`) cannot overflow `Instant` arithmetic and
/// panic.
fn wait_until(timeout: Duration, condition: impl Fn() -> bool) -> bool {
    let poll_interval = Duration::from_millis(5);
    let started = Instant::now();
    while started.elapsed() < timeout {
        if condition() {
            return true;
        }
        thread::sleep(poll_interval);
    }
    // Last chance: the condition may have become true during the final sleep.
    condition()
}

/// Panics unless `error` is one of the variants these tests accept as the outcome of a broken
/// TLS stream: `Failed`, `AbruptTermination` or `Cancelled`.
fn assert_failed(error: StreamError) {
    match error {
        StreamError::Failed(_) | StreamError::AbruptTermination | StreamError::Cancelled => {}
        _ => panic!("expected TLS failure-shaped stream error, got {error:?}"),
    }
}

/// Builds the TLS fixtures shared by the tests in this file from a freshly generated
/// self-signed certificate for `localhost`.
///
/// Returns, in order:
/// 1. a server config presenting that certificate, with no client authentication;
/// 2. a client config whose root store contains the certificate;
/// 3. a client config with an empty root store;
/// 4. the owned `localhost` server name.
fn tls_configs() -> (
    Arc<ServerConfig>,
    Arc<ClientConfig>,
    Arc<ClientConfig>,
    ServerName<'static>,
) {
    let CertifiedKey { cert, key_pair } =
        generate_simple_self_signed(["localhost".to_owned()]).expect("self-signed cert");
    let certificate: CertificateDer<'static> = cert.der().clone();
    let private_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));

    // Server: present the self-signed certificate, never ask the client for one.
    let server = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![certificate.clone()], private_key)
        .expect("server config");

    // Trusting client: its root store holds exactly the server's certificate.
    let mut trust_anchors = RootCertStore::empty();
    trust_anchors
        .add(certificate)
        .expect("trust self-signed cert");
    let trusting_client = ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();

    // Untrusting client: no roots at all.
    let untrusting_client = ClientConfig::builder()
        .with_root_certificates(RootCertStore::empty())
        .with_no_client_auth();

    let server = Arc::new(server);
    let trusting_client = Arc::new(trusting_client);
    let untrusting_client = Arc::new(untrusting_client);
    let host = ServerName::try_from("localhost")
        .expect("server name")
        .to_owned();

    (server, trusting_client, untrusting_client, host)
}

/// Happy path: a client sends `ping` over TLS, the server reads it from the accepted
/// connection and writes it back, and the client's stream yields the echoed bytes.
#[test]
fn tls_bind_and_outgoing_connection_echo_round_trip() {
    let (server_tls, client_tls, _untrusted, host) = tls_configs();

    // Server: bind on a port chosen by the OS and take the head of the incoming stream.
    let listener = TokioTls::bind("127.0.0.1:0", server_tls, CHUNK_SIZE);
    let (bound, accepted) = listener
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("tls bind source materializes");
    let binding = bound.wait().expect("tls binding succeeds");

    // Client: push one payload through the TLS flow and keep the first reply.
    let payload = b"ping".to_vec();
    let request = Source::single(payload.clone());
    let tls_flow =
        TokioTls::outgoing_connection(binding.local_addr(), host, client_tls, CHUNK_SIZE);
    let (handshake, echoed) = request
        .via_mat(tls_flow, Keep::right)
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("client TLS stream materializes");

    let server_side = accepted.wait().expect("incoming TLS connection accepted");
    let client_side = handshake
        .wait()
        .expect("client TLS connection completes handshake");
    assert_eq!(client_side.remote_addr(), binding.local_addr());

    // Server reads the request from its half of the connection...
    let (inbound, outbound) = server_side.into_parts();
    let pending_read = inbound
        .run_with(Sink::head())
        .expect("server read materializes");
    let received = pending_read.wait().expect("server reads request");
    assert_eq!(received, payload);

    // ...and echoes exactly what it received.
    let pending_write = Source::single(received)
        .run_with(outbound)
        .expect("server write materializes");
    pending_write.wait().expect("server write completes");

    let reply = echoed.wait().expect("client receives echo");
    assert_eq!(reply, payload);
}

/// A client whose root store is empty connects to the self-signed server. The client's
/// connection handle, the client's stream result and the server's incoming-connection result
/// must each resolve to a failure-shaped `StreamError` rather than succeeding.
#[test]
fn tls_handshake_failure_surfaces_as_stream_error() {
    let (server_tls, _trusted, untrusting_client_tls, host) = tls_configs();

    let listener = TokioTls::bind("127.0.0.1:0", server_tls, CHUNK_SIZE);
    let (bound, accepted) = listener
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("tls bind source materializes");
    let binding = bound.wait().expect("tls binding succeeds");

    let request = Source::single(b"ping".to_vec());
    let tls_flow = TokioTls::outgoing_connection(
        binding.local_addr(),
        host,
        untrusting_client_tls,
        CHUNK_SIZE,
    );
    let (handshake, reply) = request
        .via_mat(tls_flow, Keep::right)
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("client TLS stream materializes");

    // Client side: both the connection handle and the stream result must fail.
    let handshake_error = handshake
        .wait()
        .expect_err("client connection materialization fails");
    assert_failed(handshake_error);

    let reply_error = reply
        .wait()
        .expect_err("client stream fails without panic");
    assert_failed(reply_error);

    // Server side: the accept result must be an error as well.
    let Err(accept_error) = accepted.wait() else {
        panic!("server incoming stream should see handshake failure");
    };
    assert_failed(accept_error);
}

/// After the server has read the client's request, the client's completion handle is dropped
/// and the server then tries to write back. The server-side write must finish within three
/// seconds, either successfully or with a failure-shaped `StreamError`.
#[test]
fn dropping_tls_stream_cancels_pending_read_and_closes_peer_promptly() {
    let (server_tls, client_tls, _untrusted, host) = tls_configs();

    let listener = TokioTls::bind("127.0.0.1:0", server_tls, CHUNK_SIZE);
    let (bound, accepted) = listener
        .to_mat(Sink::head(), Keep::both)
        .run()
        .expect("tls bind source materializes");
    let binding = bound.wait().expect("tls binding succeeds");

    let payload = b"cancel".to_vec();
    let outgoing = Source::single(payload.clone());
    let tls_flow =
        TokioTls::outgoing_connection(binding.local_addr(), host, client_tls, CHUNK_SIZE);
    let client_reply = outgoing
        .via(tls_flow)
        .run_with(Sink::head())
        .expect("client stream materializes");

    let server_side = accepted.wait().expect("incoming TLS connection accepted");
    let (inbound, outbound) = server_side.into_parts();
    let pending_read = inbound
        .run_with(Sink::head())
        .expect("server read materializes");
    let received = pending_read.wait().expect("server reads request");
    assert_eq!(received, payload);

    // Let go of the client's handle before the server starts writing back.
    drop(client_reply);
    let pending_write = Source::single(received)
        .run_with(outbound)
        .expect("server write materializes");

    // Wait on a helper thread so the test can enforce a deadline on the write outcome.
    let (outcome_tx, outcome_rx) = mpsc::channel();
    thread::spawn(move || {
        let _ = outcome_tx.send(pending_write.wait());
    });

    let outcome = outcome_rx
        .recv_timeout(Duration::from_secs(3))
        .expect("server write observes peer teardown promptly");
    // A clean completion is acceptable; an error must be failure-shaped.
    if let Err(error) = outcome {
        assert_failed(error);
    }
}

/// Checks that a client which withholds demand stops the server-side writer from consuming
/// its whole input.
///
/// The server tries to write `total_chunks` payloads of `CHUNK_SIZE` bytes each
/// (8192 * 8192 bytes = 64 MiB) while the client, driven by a test probe, initially asks for
/// a single element only. The test asserts that the producer starts, that it does not reach
/// `total_chunks` within 250 ms, and that the stream then drains and completes once the
/// client requests the remaining elements one at a time.
#[test]
fn slow_tls_consumer_backpressures_writer() {
    let (server_config, client_config, _untrusted, server_name) = tls_configs();
    // Server side: materialize the TLS listener; `incoming_completion` is waited on below for
    // the accepted connection.
    let (binding_completion, incoming_completion) =
        TokioTls::bind("127.0.0.1:0", server_config, CHUNK_SIZE)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .expect("tls bind source materializes");
    let binding = binding_completion.wait().expect("tls binding succeeds");

    // Client side: an empty source feeds the connection flow, and whatever comes back from
    // the server is observed through a manually driven test probe.
    let (connection_completion, probe) = Source::<Vec<u8>>::empty()
        .via_mat(
            TokioTls::outgoing_connection(
                binding.local_addr(),
                server_name,
                client_config,
                CHUNK_SIZE,
            ),
            Keep::right,
        )
        .to_mat(TestSink::probe(), Keep::both)
        .run()
        .expect("client probe stream materializes");

    // Signal demand for exactly one element before the server starts writing.
    probe.request(1);
    let incoming: TlsIncomingConnection = incoming_completion
        .wait()
        .expect("incoming TLS connection accepted");
    connection_completion
        .wait()
        .expect("client TLS connection completes handshake");

    // The read half is bound to a named variable, so it stays alive until the end of the
    // test instead of being dropped at this statement.
    let (_incoming_source, incoming_sink) = incoming.into_parts();
    // Counts how many chunks the unfold closure below has generated so far.
    let produced = Arc::new(AtomicUsize::new(0));
    let total_chunks = 8192usize;
    let write_completion: StreamCompletion<datum::NotUsed> = Source::unfold(0usize, {
        let produced = Arc::clone(&produced);
        move |index| {
            if index == total_chunks {
                None
            } else {
                produced.fetch_add(1, Ordering::SeqCst);
                Some((index + 1, vec![b'x'; CHUNK_SIZE]))
            }
        }
    })
    .run_with(incoming_sink)
    .expect("server write materializes");

    // The single outstanding request is satisfied by the first chunk.
    let first = probe.expect_next();
    assert_eq!(first.len(), CHUNK_SIZE);
    // Within one second the producer must have generated more than one chunk.
    assert!(
        wait_until(Duration::from_secs(1), || produced.load(Ordering::SeqCst)
            > 1),
        "server writer should begin producing after first downstream demand"
    );
    // Negative check: poll for 250 ms and require that the producer never reaches
    // `total_chunks` while the client is not requesting anything further.
    assert!(
        !wait_until(Duration::from_millis(250), || {
            produced.load(Ordering::SeqCst) == total_chunks
        }),
        "withheld downstream demand should prevent the TLS writer from consuming all chunks"
    );

    // Drain the remaining `total_chunks - 1` elements one request at a time.
    for _ in 1..total_chunks {
        probe.request(1);
        assert_eq!(probe.expect_next().len(), CHUNK_SIZE);
    }
    // One more request, after which the probe must observe completion.
    probe.request(1);
    probe.expect_complete();
    write_completion
        .wait()
        .expect("server write completes after downstream drains");
}