use datum::{Keep, Sink, Source, StreamRefId, StreamRefSettings, StreamRefs};
use datum_net::quic::{
crypto::rustls::{QuicClientConfig, QuicServerConfig},
quinn,
rustls::{
ClientConfig as RustlsClientConfig, ServerConfig as RustlsServerConfig,
pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName, UnixTime},
},
};
use datum_net::{
TokioQuic, serve_sink_ref_over_quic, serve_source_ref_over_quic, sink_ref_over_quic,
source_ref_over_quic,
};
use rcgen::{CertifiedKey, generate_simple_self_signed};
use std::{env, net::SocketAddr, process, sync::Arc, thread, time::Duration};
/// Stream-ref identifier passed by every role (source server/client and sink
/// server/client) when wiring the stream ref over QUIC; both peers use this same value.
const STREAM_REF_ID: StreamRefId = StreamRefId::from_u128(1);
/// Chunk size handed to every `TokioQuic` bind/connect and bidirectional-stream call
/// in this binary. NOTE(review): the unit (presumably bytes) is defined by `datum_net` — confirm.
const STREAM_REF_CHUNK_SIZE: usize = 8192;
/// Role selected on the command line by `parse_args`.
enum Mode {
    /// `--serve <addr>`: bind `addr` and handle the first incoming connection.
    Serve {
        /// Address passed to `TokioQuic::bind`.
        addr: String,
        /// Number of elements (`0..count`) to send, or to expect when `sink` is set.
        count: u64,
        /// `--sink`: receive through a sink ref instead of serving a source ref.
        sink: bool,
    },
    /// `--connect <addr>`: dial `addr` and act as the client side.
    Connect {
        /// Server address; parsed as a `SocketAddr` before connecting.
        addr: String,
        /// Number of elements expected (source mode) or sent (sink mode).
        count: u64,
        /// Per-element delay in milliseconds (`--slow-ms`); only applied in source
        /// mode and ignored together with `--sink`.
        slow_ms: u64,
        /// `--sink`: push `0..count` into the server's sink ref instead of pulling.
        sink: bool,
    },
}
/// Process entry point: runs the node and exits with status 1 after printing
/// `ERROR <message>` to stderr when anything fails.
fn main() {
    match run() {
        Ok(()) => {}
        Err(error) => {
            eprintln!("ERROR {error}");
            process::exit(1);
        }
    }
}
/// Parses the command line and dispatches to the matching server or client routine.
///
/// # Errors
/// Propagates argument-parsing errors and any error from the selected routine.
fn run() -> Result<(), String> {
    let mode = parse_args()?;
    match mode {
        Mode::Serve { addr, count, sink: true } => run_sink_server(addr, count),
        Mode::Serve { addr, count, sink: false } => run_server(addr, count),
        // `slow_ms` is only honoured by the source-ref client.
        Mode::Connect { addr, count, sink: true, .. } => run_sink_client(addr, count),
        Mode::Connect {
            addr,
            count,
            slow_ms,
            sink: false,
        } => run_client(addr, count, slow_ms),
    }
}
/// Serves `0..count` as a source stream ref to the first QUIC client.
///
/// Binds `addr` and reports the bound address on stderr as `SERVING <addr>`. It then
/// waits for the first incoming connection and accepts one bidirectional stream from
/// it. Over that stream it serves a source ref under [`STREAM_REF_ID`] that emits
/// `0..count`. The connection is closed once the serving handle completes.
///
/// NOTE(review): the close happens right after `wait()` returns. Confirm the handle
/// only completes once the peer has received everything, since an abrupt QUIC close
/// can discard data still in flight.
///
/// # Errors
/// Returns an error string if any bind, accept, materialization or serving step fails.
fn run_server(addr: String, count: u64) -> Result<(), String> {
    let (server_config, _unused_client_config) = quic_configs()?;

    // Materialize the endpoint binding together with its first incoming connection.
    let (bound, first_incoming) = TokioQuic::bind(addr, server_config, STREAM_REF_CHUNK_SIZE)
        .to_mat(Sink::head(), Keep::both)
        .run()
        .map_err(display)?;

    let endpoint = bound.wait().map_err(display)?;
    eprintln!("SERVING {}", endpoint.local_addr());

    let peer = first_incoming.wait().map_err(display)?;

    // The client opens the stream in source mode, so the server accepts it.
    let stream_completion = peer
        .accept_bi_available(STREAM_REF_CHUNK_SIZE)
        .run_with(Sink::head())
        .map_err(display)?;
    let stream = stream_completion.wait().map_err(display)?;

    let numbers = Source::from_iter(0_u64..count);
    let local_ref = numbers
        .run_with(StreamRefs::source_ref())
        .map_err(display)?;

    let settings = StreamRefSettings::default();
    let serving = serve_source_ref_over_quic(stream, local_ref, STREAM_REF_ID, settings)
        .map_err(display)?;
    serving.wait().map_err(display)?;

    peer.connection().close(b"streamref complete");
    Ok(())
}
/// Connects to a `--serve` node and consumes its source stream ref.
///
/// Dials `addr` (TLS server name `localhost`) and opens one bidirectional stream.
/// Over that stream it materializes the remote source ref registered under
/// [`STREAM_REF_ID`] and collects every element. When `slow_ms > 0`, each element is
/// delayed by that many milliseconds (`thread::sleep` inside `map`), presumably to
/// exercise backpressure.
///
/// On success it prints `RECEIVED <n> CHECKSUM <sum>` to stdout and closes the
/// connection.
///
/// # Errors
/// Returns an error if `addr` is not a socket address, or if any connect, stream or
/// stream-ref step fails. It also fails if the received elements are not exactly
/// `0..count` in emission order.
fn run_client(addr: String, count: u64, slow_ms: u64) -> Result<(), String> {
    let (_server_config, client_config) = quic_configs()?;
    let addr = addr
        .parse::<SocketAddr>()
        .map_err(|error| format!("invalid --connect addr: {error}"))?;
    let client_connection =
        TokioQuic::connect(addr, "localhost", client_config, STREAM_REF_CHUNK_SIZE)
            .run_with(Sink::head())
            .map_err(display)?
            .wait()
            .map_err(display)?;
    let client_stream = client_connection
        .open_bi_stream_available(STREAM_REF_CHUNK_SIZE)
        .run_with(Sink::head())
        .map_err(display)?
        .wait()
        .map_err(display)?;
    let (mut remote_source, handle) =
        source_ref_over_quic::<u64>(client_stream, STREAM_REF_ID, StreamRefSettings::default());
    if slow_ms > 0 {
        // Slow consumer: delay every element before it reaches the collector.
        remote_source = remote_source.map(move |item| {
            thread::sleep(Duration::from_millis(slow_ms));
            item
        });
    }
    let values = remote_source.run_collect().map_err(display)?;
    handle.wait().map_err(display)?;
    let received = values.len() as u64;
    let checksum = values.iter().copied().sum::<u64>();
    let expected_checksum = (0_u64..count).sum::<u64>();
    if received != count || checksum != expected_checksum {
        return Err(format!(
            "mismatch: received {received}/{count}, checksum {checksum}/{expected_checksum}"
        ));
    }
    // Count and sum are blind to reordering: a permutation of 0..count passes both.
    // Stream refs deliver in emission order, so require value `i` at position `i`.
    if let Some((position, value)) = (0_u64..)
        .zip(values.iter().copied())
        .find(|&(expected, actual)| expected != actual)
    {
        return Err(format!(
            "out of order: position {position} holds {value}, expected {position}"
        ));
    }
    println!("RECEIVED {received} CHECKSUM {checksum}");
    client_connection.close(b"streamref complete");
    Ok(())
}
/// Serves a sink stream ref and verifies that the client pushes exactly `0..count`.
///
/// Binds `addr` and reports the bound address on stderr as `SERVING --sink <addr>`.
/// It then waits for the first incoming connection and opens a bidirectional stream
/// toward the client; in sink mode the server opens the stream and the client
/// accepts it. A sink ref is served under [`STREAM_REF_ID`] and everything pushed
/// into it is collected.
///
/// On success it prints `RECEIVED <n> CHECKSUM <sum>` to stdout and closes the
/// connection.
///
/// # Errors
/// Returns an error if any bind, stream or stream-ref step fails. It also fails if
/// the received elements are not exactly `0..count` in emission order.
fn run_sink_server(addr: String, count: u64) -> Result<(), String> {
    let (server_config, _client_config) = quic_configs()?;
    let (binding_completion, incoming_completion) =
        TokioQuic::bind(addr, server_config, STREAM_REF_CHUNK_SIZE)
            .to_mat(Sink::head(), Keep::both)
            .run()
            .map_err(display)?;
    let binding = binding_completion.wait().map_err(display)?;
    eprintln!("SERVING --sink {}", binding.local_addr());
    let incoming = incoming_completion.wait().map_err(display)?;
    let server_stream = incoming
        .open_bi_stream_available(STREAM_REF_CHUNK_SIZE)
        .run_with(Sink::head())
        .map_err(display)?
        .wait()
        .map_err(display)?;
    let (remote_source, handle) =
        serve_sink_ref_over_quic::<u64>(server_stream, STREAM_REF_ID, StreamRefSettings::default());
    let values = remote_source.run_collect().map_err(display)?;
    handle.wait().map_err(display)?;
    let received = values.len() as u64;
    let checksum = values.iter().copied().sum::<u64>();
    let expected_checksum = (0_u64..count).sum::<u64>();
    if received != count || checksum != expected_checksum {
        return Err(format!(
            "mismatch: received {received}/{count}, checksum {checksum}/{expected_checksum}"
        ));
    }
    // Count and sum are blind to reordering: a permutation of 0..count passes both.
    // Stream refs deliver in emission order, so require value `i` at position `i`.
    if let Some((position, value)) = (0_u64..)
        .zip(values.iter().copied())
        .find(|&(expected, actual)| expected != actual)
    {
        return Err(format!(
            "out of order: position {position} holds {value}, expected {position}"
        ));
    }
    println!("RECEIVED {received} CHECKSUM {checksum}");
    incoming.connection().close(b"streamref complete");
    Ok(())
}
/// Connects to a `--serve --sink` node and pushes `0..count` into its sink ref.
///
/// In sink mode the server opens the bidirectional stream, so this side accepts it.
/// The numbers are then streamed into the remote sink ref registered under
/// [`STREAM_REF_ID`]. Once both the local send and the stream-ref handle complete,
/// it prints `SENT <count>` on stderr and closes the connection.
///
/// # Errors
/// Returns an error if `addr` is not a socket address, or if any connect, accept or
/// send step fails.
fn run_sink_client(addr: String, count: u64) -> Result<(), String> {
    let (_unused_server_config, client_config) = quic_configs()?;
    let server_addr = addr
        .parse::<SocketAddr>()
        .map_err(|error| format!("invalid --connect addr: {error}"))?;

    let connection_completion =
        TokioQuic::connect(server_addr, "localhost", client_config, STREAM_REF_CHUNK_SIZE)
            .run_with(Sink::head())
            .map_err(display)?;
    let connection = connection_completion.wait().map_err(display)?;

    let stream_completion = connection
        .accept_bi_available(STREAM_REF_CHUNK_SIZE)
        .run_with(Sink::head())
        .map_err(display)?;
    let stream = stream_completion.wait().map_err(display)?;

    let settings = StreamRefSettings::default();
    let (remote_sink, ref_handle) = sink_ref_over_quic::<u64>(stream, STREAM_REF_ID, settings);

    let numbers = Source::from_iter(0_u64..count);
    let pushed = numbers.run_with(remote_sink).map_err(display)?;
    pushed.wait().map_err(display)?;
    ref_handle.wait().map_err(display)?;

    eprintln!("SENT {count}");
    connection.close(b"streamref complete");
    Ok(())
}
/// Parses the process command line into a [`Mode`].
///
/// Accepted flags:
/// - `--serve <addr>`
/// - `--connect <addr>`
/// - `--count <N>` (required)
/// - `--slow-ms <d>` (default `0`)
/// - `--sink`
///
/// Exactly one of `--serve` or `--connect` must be supplied. `--help` / `-h` returns
/// the usage text as an error. `--slow-ms` is recorded only for `--connect`; the
/// dispatcher uses it solely for the non-sink client.
///
/// # Errors
/// Returns a message for any of these cases; where it helps, the message is followed
/// by the usage text:
/// - an unknown flag, or a flag missing its value;
/// - a value that fails to parse;
/// - a missing `--count`;
/// - anything other than exactly one of `--serve` / `--connect`.
fn parse_args() -> Result<Mode, String> {
    let mut args = env::args().skip(1);
    let mut serve = None::<String>;
    let mut connect = None::<String>;
    let mut count = None::<u64>;
    let mut slow_ms = 0_u64;
    let mut sink = false;
    // `while let` rather than `for`: `next_arg` also advances `args` inside the loop.
    while let Some(arg) = args.next() {
        match arg.as_str() {
            "--serve" => serve = Some(next_arg(&mut args, "--serve")?),
            "--connect" => connect = Some(next_arg(&mut args, "--connect")?),
            "--count" => {
                count = Some(
                    next_arg(&mut args, "--count")?
                        .parse()
                        .map_err(|error| format!("invalid --count: {error}"))?,
                );
            }
            "--slow-ms" => {
                slow_ms = next_arg(&mut args, "--slow-ms")?
                    .parse()
                    .map_err(|error| format!("invalid --slow-ms: {error}"))?;
            }
            "--sink" => sink = true,
            "--help" | "-h" => return Err(usage()),
            other => return Err(format!("unknown argument {other}\n{}", usage())),
        }
    }
    // Name the missing requirement instead of printing only the usage text.
    let count = count.ok_or_else(|| format!("--count is required\n{}", usage()))?;
    match (serve, connect) {
        (Some(addr), None) => Ok(Mode::Serve { addr, count, sink }),
        (None, Some(addr)) => Ok(Mode::Connect {
            addr,
            count,
            slow_ms,
            sink,
        }),
        _ => Err(format!(
            "exactly one of --serve or --connect is required\n{}",
            usage()
        )),
    }
}
/// Takes the value that must follow `flag` from `args`.
///
/// # Errors
/// Returns `"<flag> requires a value"` followed by the usage text when `args` is
/// exhausted.
fn next_arg(args: &mut impl Iterator<Item = String>, flag: &str) -> Result<String, String> {
    match args.next() {
        Some(value) => Ok(value),
        None => Err(format!("{flag} requires a value\n{}", usage())),
    }
}
/// Returns the two-line command-line usage text (serve form, then connect form).
fn usage() -> String {
    let serve_form = "usage: streamref_quic_node --serve <addr> --count <N> [--sink]";
    let connect_form =
        "  streamref_quic_node --connect <addr> --count <N> [--slow-ms <d>] [--sink]";
    format!("{serve_form}\n{connect_form}")
}
/// Builds a QUIC server/client configuration pair for local testing.
///
/// Every call generates a fresh self-signed certificate for `localhost`. The client
/// side does NOT verify the server certificate (it installs `SkipServerVerification`),
/// so these configs are only suitable for tests and demos, never for real traffic.
///
/// # Errors
/// Returns an error string in any of these cases:
/// - certificate generation fails;
/// - the rustls server config rejects the cert/key;
/// - a rustls config cannot be converted into a QUIC crypto config.
fn quic_configs() -> Result<(quinn::ServerConfig, quinn::ClientConfig), String> {
    let CertifiedKey { cert, key_pair } =
        generate_simple_self_signed(["localhost".to_owned()]).map_err(display)?;
    // `der()` hands back a reference; clone it once into an owned `'static` DER blob.
    let cert_der: CertificateDer<'static> = cert.der().clone();
    let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
    // `cert_der` is not needed afterwards, so move it into the chain (no second clone).
    let server_crypto = RustlsServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(display)?;
    // INSECURE: accepts any server certificate; only acceptable for this test node.
    let client_crypto = RustlsClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(SkipServerVerification::new())
        .with_no_client_auth();
    let server_config = quinn::ServerConfig::with_crypto(Arc::new(
        QuicServerConfig::try_from(server_crypto).map_err(display)?,
    ));
    let client_config = quinn::ClientConfig::new(Arc::new(
        QuicClientConfig::try_from(client_crypto).map_err(display)?,
    ));
    Ok((server_config, client_config))
}
/// Renders any `Display` value as an owned `String`.
///
/// Used as a `map_err` adapter so heterogeneous library errors collapse into the
/// `String` error type used throughout this binary.
fn display(error: impl std::fmt::Display) -> String {
    ToString::to_string(&error)
}
/// Client-side certificate verifier that trusts any server certificate.
///
/// It wraps a crypto provider whose signature-verification algorithms are used to
/// check handshake signatures. INSECURE: intended only for this local test node.
#[derive(Debug)]
struct SkipServerVerification(Arc<datum_net::quic::rustls::crypto::CryptoProvider>);

impl SkipServerVerification {
    /// Creates the verifier backed by the `aws_lc_rs` default provider, already
    /// wrapped in an `Arc` as rustls' custom-verifier API expects.
    fn new() -> Arc<Self> {
        let provider = datum_net::quic::rustls::crypto::aws_lc_rs::default_provider();
        let verifier = Self(Arc::new(provider));
        Arc::new(verifier)
    }
}
/// Accepts every server certificate, but still validates TLS 1.2/1.3 handshake
/// signatures with the wrapped provider's signature-verification algorithms.
impl datum_net::quic::rustls::client::danger::ServerCertVerifier for SkipServerVerification {
    /// Unconditionally approves the presented chain. INSECURE: test use only.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp: &[u8],
        _now: UnixTime,
    ) -> Result<
        datum_net::quic::rustls::client::danger::ServerCertVerified,
        datum_net::quic::rustls::Error,
    > {
        use datum_net::quic::rustls::client::danger::ServerCertVerified;
        let blanket_approval = ServerCertVerified::assertion();
        Ok(blanket_approval)
    }

    /// Delegates TLS 1.2 handshake-signature checking to rustls.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &datum_net::quic::rustls::DigitallySignedStruct,
    ) -> Result<
        datum_net::quic::rustls::client::danger::HandshakeSignatureValid,
        datum_net::quic::rustls::Error,
    > {
        use datum_net::quic::rustls::crypto;
        let Self(provider) = self;
        let algorithms = &provider.signature_verification_algorithms;
        crypto::verify_tls12_signature(message, cert, dss, algorithms)
    }

    /// Delegates TLS 1.3 handshake-signature checking to rustls.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &datum_net::quic::rustls::DigitallySignedStruct,
    ) -> Result<
        datum_net::quic::rustls::client::danger::HandshakeSignatureValid,
        datum_net::quic::rustls::Error,
    > {
        use datum_net::quic::rustls::crypto;
        let Self(provider) = self;
        let algorithms = &provider.signature_verification_algorithms;
        crypto::verify_tls13_signature(message, cert, dss, algorithms)
    }

    /// Advertises exactly the signature schemes the wrapped provider can verify.
    fn supported_verify_schemes(&self) -> Vec<datum_net::quic::rustls::SignatureScheme> {
        let Self(provider) = self;
        provider.signature_verification_algorithms.supported_schemes()
    }
}