datum-net 0.9.0

Network sources and sinks for Datum streams, built on datum-core
Documentation
use datum::{Keep, Sink, Source, StreamRefId, StreamRefSettings, StreamRefs};
use datum_net::quic::{
    crypto::rustls::{QuicClientConfig, QuicServerConfig},
    quinn,
    rustls::{
        ClientConfig as RustlsClientConfig, ServerConfig as RustlsServerConfig,
        pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName, UnixTime},
    },
};
use datum_net::{
    TokioQuic, serve_sink_ref_over_quic, serve_source_ref_over_quic, sink_ref_over_quic,
    source_ref_over_quic,
};
use rcgen::{CertifiedKey, generate_simple_self_signed};
use std::{env, net::SocketAddr, process, sync::Arc, thread, time::Duration};

/// Stream-ref identifier used by both peers for the single stream ref exchanged
/// per connection.
const STREAM_REF_ID: StreamRefId = StreamRefId::from_u128(1);
/// Chunk size handed to every QUIC bind / connect / stream call in this node
/// (8 * 1024, presumably bytes — confirm against datum_net).
const STREAM_REF_CHUNK_SIZE: usize = 8 * 1024;

/// Command-line mode selected by `parse_args`.
///
/// `Serve` corresponds to `--serve <addr>` and `Connect` to `--connect <addr>`.
/// In both roles `sink` (`--sink`) selects the SinkRef variant instead of the
/// SourceRef one. Deriving `Debug`/`PartialEq` lets parsed modes be printed
/// and compared in diagnostics and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Mode {
    /// Bind a QUIC endpoint and serve the first incoming connection.
    Serve {
        /// Address given to `--serve`, passed unparsed to the bind call.
        addr: String,
        /// Number of elements transferred (`0..count`).
        count: u64,
        /// `--sink`: host a SinkRef receiver instead of serving a SourceRef.
        sink: bool,
    },
    /// Connect to a serving node.
    Connect {
        /// Address given to `--connect`; parsed as a `SocketAddr` later.
        addr: String,
        /// Number of elements transferred (`0..count`).
        count: u64,
        /// Per-element delay in milliseconds applied by the SourceRef client;
        /// not used when `sink` is set.
        slow_ms: u64,
        /// `--sink`: send into a remote SinkRef instead of consuming a SourceRef.
        sink: bool,
    },
}

fn main() {
    if let Err(error) = run() {
        eprintln!("ERROR {error}");
        process::exit(1);
    }
}

/// Parses the command line and dispatches to the matching role.
///
/// | mode      | `--sink` | handler           |
/// |-----------|----------|-------------------|
/// | `Serve`   | no       | `run_server`      |
/// | `Serve`   | yes      | `run_sink_server` |
/// | `Connect` | no       | `run_client`      |
/// | `Connect` | yes      | `run_sink_client` |
///
/// `slow_ms` is only forwarded to `run_client`; it is ignored in sink mode.
///
/// # Errors
/// Propagates argument-parsing errors and whatever the selected handler returns.
fn run() -> Result<(), String> {
    let mode = parse_args()?;
    match mode {
        Mode::Serve { addr, count, sink: true } => run_sink_server(addr, count),
        Mode::Serve { addr, count, sink: false } => run_server(addr, count),
        Mode::Connect { addr, count, sink: true, .. } => run_sink_client(addr, count),
        Mode::Connect { addr, count, slow_ms, sink: false } => run_client(addr, count, slow_ms),
    }
}

/// SourceRef server: binds a QUIC endpoint on `addr` and takes the first
/// incoming connection. It accepts one bidirectional stream on that connection
/// and serves a SourceRef emitting `0..count` over it.
///
/// Prints `SERVING <local addr>` to stderr once the endpoint is bound. After
/// the transfer handle completes, it closes the connection with reason
/// `streamref complete`. The client opens the bidi stream (see `run_client`),
/// so this side accepts it.
///
/// # Errors
/// Any failure is returned as its `Display` string. That covers building the
/// QUIC configs, binding, accepting, materializing the source, and serving the
/// stream ref.
fn run_server(addr: String, count: u64) -> Result<(), String> {
    // Only the server half is used; the generated client config is discarded.
    let (server_config, _client_config) = quic_configs()?;
    // Keep both materialized values: the bind completion (gives the local
    // address) and the `Sink::head` completion (gives the first incoming connection).
    let (binding_completion, incoming_completion) =
        TokioQuic::bind(addr, server_config, STREAM_REF_CHUNK_SIZE)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .map_err(display)?;
    let binding = binding_completion.wait().map_err(display)?;
    eprintln!("SERVING {}", binding.local_addr());

    // Wait for the first incoming connection; only this one is handled.
    let incoming = incoming_completion.wait().map_err(display)?;
    // Accept the stream the client opens (run_client calls open_bi_stream_available).
    let server_stream = incoming
        .accept_bi_available(STREAM_REF_CHUNK_SIZE)
        .run_with(Sink::head())
        .map_err(display)?
        .wait()
        .map_err(display)?;
    // Materialize the local 0..count source into a SourceRef before handing it
    // to the QUIC transport.
    let source_ref = Source::from_iter(0_u64..count)
        .run_with(StreamRefs::source_ref())
        .map_err(display)?;
    let handle = serve_source_ref_over_quic(
        server_stream,
        source_ref,
        STREAM_REF_ID,
        StreamRefSettings::default(),
    )
    .map_err(display)?;
    // Wait for the transfer to finish before closing the connection.
    handle.wait().map_err(display)?;
    incoming.connection().close(b"streamref complete");
    Ok(())
}

/// SourceRef client: connects to `addr` using server name `localhost` and opens
/// one bidirectional stream. It then consumes the remote SourceRef as `u64`s.
///
/// The result is checked against `count` elements with checksum `sum(0..count)`.
/// Because the checksum is an order-insensitive sum, reordering of elements
/// would not be detected.
///
/// When `slow_ms > 0`, each element is delayed by `slow_ms` milliseconds via
/// `thread::sleep` before collection. This is presumably meant to exercise
/// demand/backpressure on the server side (TODO confirm intent).
///
/// On success it prints `RECEIVED <n> CHECKSUM <sum>` to stdout and closes the
/// connection with reason `streamref complete`.
///
/// # Errors
/// Returns an error in any of these cases:
/// - `addr` is not a valid `SocketAddr`.
/// - Any connect, stream, or transfer step fails.
/// - The received count or checksum does not match. In that case the connection
///   is not explicitly closed before returning.
fn run_client(addr: String, count: u64, slow_ms: u64) -> Result<(), String> {
    // Only the client half is used; the generated server config is discarded.
    let (_server_config, client_config) = quic_configs()?;
    let addr = addr
        .parse::<SocketAddr>()
        .map_err(|error| format!("invalid --connect addr: {error}"))?;
    // "localhost" matches the SAN of the server's self-signed certificate
    // (quic_configs); certificate verification itself is skipped.
    let client_connection =
        TokioQuic::connect(addr, "localhost", client_config, STREAM_REF_CHUNK_SIZE)
            .run_with(Sink::head())
            .map_err(display)?
            .wait()
            .map_err(display)?;
    // The client opens the stream; run_server accepts it.
    let client_stream = client_connection
        .open_bi_stream_available(STREAM_REF_CHUNK_SIZE)
        .run_with(Sink::head())
        .map_err(display)?
        .wait()
        .map_err(display)?;
    let (mut remote_source, handle) =
        source_ref_over_quic::<u64>(client_stream, STREAM_REF_ID, StreamRefSettings::default());
    if slow_ms > 0 {
        // Artificially slow consumer: sleeps before passing each element on.
        remote_source = remote_source.map(move |item| {
            thread::sleep(Duration::from_millis(slow_ms));
            item
        });
    }
    // Collects every element in memory, then waits for the transport handle.
    let values = remote_source.run_collect().map_err(display)?;
    handle.wait().map_err(display)?;

    let received = values.len() as u64;
    let checksum = values.iter().copied().sum::<u64>();
    let expected_checksum = (0_u64..count).sum::<u64>();
    if received != count || checksum != expected_checksum {
        return Err(format!(
            "mismatch: received {received}/{count}, checksum {checksum}/{expected_checksum}"
        ));
    }
    println!("RECEIVED {received} CHECKSUM {checksum}");
    client_connection.close(b"streamref complete");
    Ok(())
}

/// SinkRef server: binds a QUIC endpoint on `addr` and takes the first incoming
/// connection. It then hosts a SinkRef receiver on it and collects every
/// inbound `u64`.
///
/// The result is checked against `count` elements with checksum `sum(0..count)`
/// (an order-insensitive sum). On success it prints
/// `RECEIVED <n> CHECKSUM <sum>` to stdout.
///
/// Unlike `run_server`, here the receiver (consumer) opens the bidi stream. Its
/// handshake and demand establish the QUIC stream before the sender
/// (`run_sink_client`) accepts it.
///
/// # Errors
/// Returns any config, bind, stream, or transfer failure as its `Display`
/// string. Also returns an error if the received count or checksum does not
/// match; in that case the connection is not explicitly closed.
fn run_sink_server(addr: String, count: u64) -> Result<(), String> {
    // Only the server half is used; the generated client config is discarded.
    let (server_config, _client_config) = quic_configs()?;
    let (binding_completion, incoming_completion) =
        TokioQuic::bind(addr, server_config, STREAM_REF_CHUNK_SIZE)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .map_err(display)?;
    let binding = binding_completion.wait().map_err(display)?;
    eprintln!("SERVING --sink {}", binding.local_addr());

    // Wait for the first incoming connection; only this one is handled.
    let incoming = incoming_completion.wait().map_err(display)?;
    // Receiver opens the bidi stream (consumer sends handshake+demand first).
    let server_stream = incoming
        .open_bi_stream_available(STREAM_REF_CHUNK_SIZE)
        .run_with(Sink::head())
        .map_err(display)?
        .wait()
        .map_err(display)?;
    // `remote_source` is the local Source that yields what the remote sender
    // pushes into the SinkRef.
    let (remote_source, handle) =
        serve_sink_ref_over_quic::<u64>(server_stream, STREAM_REF_ID, StreamRefSettings::default());
    let values = remote_source.run_collect().map_err(display)?;
    handle.wait().map_err(display)?;

    let received = values.len() as u64;
    let checksum = values.iter().copied().sum::<u64>();
    let expected_checksum = (0_u64..count).sum::<u64>();
    if received != count || checksum != expected_checksum {
        return Err(format!(
            "mismatch: received {received}/{count}, checksum {checksum}/{expected_checksum}"
        ));
    }
    println!("RECEIVED {received} CHECKSUM {checksum}");
    incoming.connection().close(b"streamref complete");
    Ok(())
}

/// SinkRef client: connects to `addr` using server name `localhost` and accepts
/// the bidi stream the server opens. It then sends `0..count` into the remote
/// SinkRef.
///
/// It waits for both the local send completion and the transport handle. Then
/// it prints `SENT <count>` to stderr and closes the connection with reason
/// `streamref complete`.
///
/// # Errors
/// Returns an error if `addr` is not a valid `SocketAddr`. Any connect, stream,
/// or transfer failure is returned as its `Display` string.
fn run_sink_client(addr: String, count: u64) -> Result<(), String> {
    // Only the client half is used; the generated server config is discarded.
    let (_server_config, client_config) = quic_configs()?;
    let addr = addr
        .parse::<SocketAddr>()
        .map_err(|error| format!("invalid --connect addr: {error}"))?;
    let client_connection =
        TokioQuic::connect(addr, "localhost", client_config, STREAM_REF_CHUNK_SIZE)
            .run_with(Sink::head())
            .map_err(display)?
            .wait()
            .map_err(display)?;
    // Sender accepts the server-opened bidi stream (producer waits for demand).
    let client_stream = client_connection
        .accept_bi_available(STREAM_REF_CHUNK_SIZE)
        .run_with(Sink::head())
        .map_err(display)?
        .wait()
        .map_err(display)?;
    let (send_sink, handle) =
        sink_ref_over_quic::<u64>(client_stream, STREAM_REF_ID, StreamRefSettings::default());
    let send_completion = Source::from_iter(0_u64..count)
        .run_with(send_sink)
        .map_err(display)?;
    // Local stream finished first, then the transport handle; only then close.
    send_completion.wait().map_err(display)?;
    handle.wait().map_err(display)?;
    eprintln!("SENT {count}");
    client_connection.close(b"streamref complete");
    Ok(())
}

/// Parses the process arguments into a [`Mode`].
///
/// Recognised flags:
/// - `--serve <addr>` and `--connect <addr>`: exactly one is required.
/// - `--count <N>`: required.
/// - `--slow-ms <d>`: defaults to 0.
/// - `--sink`.
/// - `-h` / `--help`.
///
/// # Errors
/// Returns the usage text (possibly prefixed with a specific message) when:
/// - help is requested;
/// - a flag is unknown;
/// - a value is missing or unparsable;
/// - `--count` is absent;
/// - neither or both of `--serve` / `--connect` were given.
fn parse_args() -> Result<Mode, String> {
    // Parses the raw value of `flag`, reporting failures as "invalid <flag>: <reason>".
    fn parse_flag_value<T>(raw: String, flag: &str) -> Result<T, String>
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Display,
    {
        raw.parse::<T>()
            .map_err(|reason| format!("invalid {flag}: {reason}"))
    }

    let mut argv = env::args().skip(1);
    let mut serve_addr: Option<String> = None;
    let mut connect_addr: Option<String> = None;
    let mut count: Option<u64> = None;
    let mut slow_ms: u64 = 0;
    let mut sink = false;

    while let Some(flag) = argv.next() {
        match flag.as_str() {
            "--serve" => {
                serve_addr = Some(next_arg(&mut argv, "--serve")?);
            }
            "--connect" => {
                connect_addr = Some(next_arg(&mut argv, "--connect")?);
            }
            "--count" => {
                let raw = next_arg(&mut argv, "--count")?;
                count = Some(parse_flag_value(raw, "--count")?);
            }
            "--slow-ms" => {
                let raw = next_arg(&mut argv, "--slow-ms")?;
                slow_ms = parse_flag_value(raw, "--slow-ms")?;
            }
            "--sink" => sink = true,
            "-h" | "--help" => return Err(usage()),
            unknown => return Err(format!("unknown argument {unknown}\n{}", usage())),
        }
    }

    let Some(count) = count else {
        return Err(usage());
    };
    match (serve_addr, connect_addr) {
        (Some(addr), None) => Ok(Mode::Serve { addr, count, sink }),
        (None, Some(addr)) => Ok(Mode::Connect {
            addr,
            count,
            slow_ms,
            sink,
        }),
        (Some(_), Some(_)) | (None, None) => Err(usage()),
    }
}

/// Takes the value that must follow `flag` from `args`.
///
/// # Errors
/// When `args` is exhausted, returns `"<flag> requires a value"` followed by a
/// newline and the usage text.
fn next_arg(args: &mut impl Iterator<Item = String>, flag: &str) -> Result<String, String> {
    match args.next() {
        Some(value) => Ok(value),
        None => Err(format!("{flag} requires a value\n{}", usage())),
    }
}

/// Returns the two-line usage text.
///
/// The second line is indented by `"usage: ".len()` spaces so that both
/// invocations line up in the printed output.
fn usage() -> String {
    let lines = [
        "usage: streamref_quic_node --serve <addr> --count <N> [--sink]",
        "       streamref_quic_node --connect <addr> --count <N> [--slow-ms <d>] [--sink]",
    ];
    lines.join("\n")
}

/// Builds a QUIC server configuration and a QUIC client configuration.
///
/// Every call generates a fresh self-signed certificate for `localhost` and
/// uses it as the server identity. The client is configured with
/// [`SkipServerVerification`], which accepts any server certificate.
///
/// Server and client processes each call this function independently (they
/// get different certificates), so the client has no way to pin the server's
/// certificate. This setup provides no server authentication and is only
/// suitable for local testing.
///
/// # Errors
/// Returns the `Display` string of any certificate-generation failure, rustls
/// configuration failure, or QUIC crypto-conversion failure.
fn quic_configs() -> Result<(quinn::ServerConfig, quinn::ClientConfig), String> {
    let CertifiedKey { cert, key_pair } =
        generate_simple_self_signed(["localhost".to_owned()]).map_err(display)?;
    let cert_der: CertificateDer<'static> = cert.der().clone();
    let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
    // `cert_der` is not used after this point, so it is moved rather than cloned.
    let server_crypto = RustlsServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(display)?;

    // Certificate chain/name checks are disabled; handshake signatures are
    // still verified by SkipServerVerification.
    let client_crypto = RustlsClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(SkipServerVerification::new())
        .with_no_client_auth();

    Ok((
        quinn::ServerConfig::with_crypto(Arc::new(
            QuicServerConfig::try_from(server_crypto).map_err(display)?,
        )),
        quinn::ClientConfig::new(Arc::new(
            QuicClientConfig::try_from(client_crypto).map_err(display)?,
        )),
    ))
}

/// Renders any `Display` value (typically an error) as its message string.
///
/// Used with `map_err` to fold heterogeneous error types into `String`.
fn display(error: impl std::fmt::Display) -> String {
    format!("{error}")
}

/// rustls certificate verifier that skips server-certificate validation.
///
/// Wraps a crypto provider whose signature-verification algorithms are used
/// to check TLS handshake signatures. The certificate itself is never
/// validated, so this provides no server authentication and is only meant for
/// this local test node.
#[derive(Debug)]
struct SkipServerVerification(Arc<datum_net::quic::rustls::crypto::CryptoProvider>);

impl SkipServerVerification {
    /// Creates a shared verifier backed by the `aws_lc_rs` default provider.
    fn new() -> Arc<Self> {
        let provider = datum_net::quic::rustls::crypto::aws_lc_rs::default_provider();
        let verifier = Self(Arc::new(provider));
        Arc::new(verifier)
    }
}

/// Trust-everything verifier.
///
/// Any server certificate chain is accepted. TLS 1.2 and TLS 1.3 handshake
/// signatures are still checked with the wrapped provider's
/// signature-verification algorithms.
impl datum_net::quic::rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    /// Accepts every certificate without inspecting the chain, name, OCSP data or time.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp: &[u8],
        _now: UnixTime,
    ) -> Result<
        datum_net::quic::rustls::client::danger::ServerCertVerified,
        datum_net::quic::rustls::Error,
    > {
        let accepted = datum_net::quic::rustls::client::danger::ServerCertVerified::assertion();
        Ok(accepted)
    }

    /// Verifies a TLS 1.2 handshake signature made with `cert`'s key.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &datum_net::quic::rustls::DigitallySignedStruct,
    ) -> Result<
        datum_net::quic::rustls::client::danger::HandshakeSignatureValid,
        datum_net::quic::rustls::Error,
    > {
        let algorithms = &self.0.signature_verification_algorithms;
        datum_net::quic::rustls::crypto::verify_tls12_signature(
            message, cert, dss, algorithms,
        )
    }

    /// Verifies a TLS 1.3 handshake signature made with `cert`'s key.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &datum_net::quic::rustls::DigitallySignedStruct,
    ) -> Result<
        datum_net::quic::rustls::client::danger::HandshakeSignatureValid,
        datum_net::quic::rustls::Error,
    > {
        let algorithms = &self.0.signature_verification_algorithms;
        datum_net::quic::rustls::crypto::verify_tls13_signature(
            message, cert, dss, algorithms,
        )
    }

    /// Advertises the signature schemes supported by the wrapped provider.
    fn supported_verify_schemes(&self) -> Vec<datum_net::quic::rustls::SignatureScheme> {
        let algorithms = &self.0.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}