#![cfg(feature = "tls")]
use std::net::TcpListener;
use std::thread;
use std::time::Duration;
use nexus_web::tls::TlsConfig;
use nexus_web::ws::{Client, CloseCode, Message};
/// Builds a long certificate chain for a `localhost` server.
///
/// Judging by the name, the purpose is to make the server's handshake large.
/// The chain holds `CHAIN_DEPTH` certificates in total:
///
/// - one self-signed root CA;
/// - `CHAIN_DEPTH - 2` intermediate CAs, each signed by the CA before it;
/// - one leaf for `localhost`, signed by the last intermediate.
///
/// Keys come from `KeyPair::generate()`. The function name implies ECDSA
/// (rcgen's default algorithm); confirm this if the exact algorithm matters.
///
/// Returns the chain ordered leaf first, then the intermediates from nearest
/// to farthest, ending with the root. Also returns the leaf's private key as
/// produced by `KeyPair::serialize_der`.
///
/// # Panics
/// Panics if rcgen fails to generate a key, build parameters, or sign.
fn generate_oversize_ecdsa_chain() -> (Vec<rustls::pki_types::CertificateDer<'static>>, Vec<u8>) {
    use rcgen::{BasicConstraints, CertificateParams, IsCa, KeyPair};
    use rustls::pki_types::CertificateDer;

    const CHAIN_DEPTH: usize = 10;

    // The root CA signs itself and becomes the first issuer.
    let mut issuer_key = KeyPair::generate().expect("root key");
    let mut root_params = CertificateParams::new(Vec::<String>::new()).expect("root params");
    root_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    let mut issuer_cert = root_params.self_signed(&issuer_key).expect("root self-sign");

    // DER bytes of every CA certificate, collected root-first.
    let mut ca_ders: Vec<Vec<u8>> = Vec::with_capacity(CHAIN_DEPTH - 1);
    ca_ders.push(issuer_cert.der().to_vec());

    // Each intermediate is signed by the current issuer and then becomes the issuer.
    for _ in 2..CHAIN_DEPTH {
        let next_key = KeyPair::generate().expect("int key");
        let mut next_params = CertificateParams::new(Vec::<String>::new()).expect("int params");
        next_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        let next_cert = next_params
            .signed_by(&next_key, &issuer_cert, &issuer_key)
            .expect("int signed");
        ca_ders.push(next_cert.der().to_vec());
        issuer_cert = next_cert;
        issuer_key = next_key;
    }

    // The leaf for `localhost` is signed by the deepest intermediate.
    let leaf_key = KeyPair::generate().expect("leaf key");
    let leaf_params = CertificateParams::new(vec!["localhost".to_string()]).expect("leaf params");
    let leaf_cert = leaf_params
        .signed_by(&leaf_key, &issuer_cert, &issuer_key)
        .expect("leaf signed");

    // Leaf first, then walk back up towards the root.
    let mut chain = Vec::with_capacity(CHAIN_DEPTH);
    chain.push(CertificateDer::from(leaf_cert.der().to_vec()));
    chain.extend(ca_ders.into_iter().rev().map(CertificateDer::from));

    (chain, leaf_key.serialize_der())
}
/// Serves one WebSocket-over-TLS connection and echoes its traffic.
///
/// The function:
///
/// - accepts exactly one TCP connection on `listener`;
/// - runs a rustls server session built from `server_config`;
/// - accepts the WebSocket upgrade;
/// - echoes messages until the loop ends.
///
/// Text and binary frames are sent back unchanged. Pings are answered with a
/// pong carrying the same payload, and pongs are ignored. A close frame ends
/// the loop, as does `recv` returning `None`.
///
/// # Panics
/// Panics if the accept, the TLS/WebSocket setup, or any receive/send fails.
// `listener` is taken by value only because callers move it into the server
// thread; this function itself needs no more than `&listener`.
#[allow(clippy::needless_pass_by_value)]
fn run_echo_server(listener: TcpListener, server_config: std::sync::Arc<rustls::ServerConfig>) {
    let (socket, _peer) = listener.accept().expect("accept");

    // Best-effort socket tuning: failures are deliberately ignored. The
    // timeouts cap how long blocking reads/writes wait on a stuck peer.
    let io_timeout = Some(Duration::from_secs(10));
    let _ = socket.set_nodelay(true);
    let _ = socket.set_read_timeout(io_timeout);
    let _ = socket.set_write_timeout(io_timeout);

    let tls = rustls::StreamOwned::new(
        rustls::ServerConnection::new(server_config).expect("server conn"),
        socket,
    );
    let mut ws = Client::accept(tls).expect("server WS accept");

    loop {
        let Some(frame) = ws.recv().expect("server recv") else {
            break;
        };
        // Payloads are copied into owned buffers before replying, so the
        // received message no longer borrows `ws` when it is written to.
        match frame {
            Message::Close(_) => break,
            Message::Pong(_) => {}
            Message::Ping(payload) => {
                let reply = payload.to_vec();
                ws.send_pong(&reply).expect("server pong");
            }
            Message::Text(text) => {
                let reply = text.to_string();
                ws.send_text(&reply).expect("server send text");
            }
            Message::Binary(bytes) => {
                let reply = bytes.to_vec();
                ws.send_binary(&reply).expect("server send binary");
            }
        }
    }
}
/// Runs a baseline WSS echo round-trip using a single self-signed
/// certificate.
///
/// The server presents one `localhost` certificate with no intermediates. The
/// client connects over `wss://` with certificate verification disabled,
/// sends one text message, expects the identical text back, and closes. This
/// establishes that the plain TLS/WebSocket path works on its own.
///
/// # Panics
/// Panics on any setup or I/O failure, on a non-text reply, or on an echo
/// mismatch.
fn smoke_check_simple_cert() {
    let cert_kp =
        rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).expect("simple cert");
    let chain = vec![rustls::pki_types::CertificateDer::from(
        cert_kp.cert.der().to_vec(),
    )];
    let key = rustls::pki_types::PrivateKeyDer::try_from(cert_kp.key_pair.serialize_der())
        .expect("simple key");
    let server_config = std::sync::Arc::new(
        rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(chain, key)
            .expect("smoke server config"),
    );

    // Port 0 lets the OS pick a free port; read it back for the client URL.
    let listener = TcpListener::bind("127.0.0.1:0").expect("smoke bind");
    let port = listener.local_addr().expect("smoke local_addr").port();
    let server_handle = thread::spawn(move || run_echo_server(listener, server_config));

    // Verification is disabled: the certificate is self-signed and untrusted.
    let tls_config = TlsConfig::builder()
        .danger_no_verify()
        .build()
        .expect("smoke tls");
    let mut ws = nexus_web::ws::ClientBuilder::new()
        .tls(&tls_config)
        .connect_timeout(Duration::from_secs(10))
        .connect(&format!("wss://127.0.0.1:{port}/"))
        .expect("smoke WSS connect");

    ws.send_text("smoke").expect("smoke send");
    // `recv` yields `None` if the peer closed before replying. That failure
    // gets a message distinct from the explicit `close()` failure below.
    match ws
        .recv()
        .expect("smoke recv")
        .expect("smoke: close before message")
    {
        Message::Text(s) => assert_eq!(s, "smoke"),
        other => panic!("smoke: expected Text, got {other:?}"),
    }

    ws.close(CloseCode::Normal, "done").expect("smoke close");
    server_handle.join().expect("smoke server join");
}
/// Regression test (the probe text tags it as #200): a WSS client must
/// connect to and exchange messages with a server presenting the long
/// certificate chain from `generate_oversize_ecdsa_chain`.
///
/// It first runs `smoke_check_simple_cert` as a baseline. It then checks:
///
/// - a short text echo;
/// - an 8192-byte text echo, compared by length and by full content;
/// - a clean close and server-thread join.
#[test]
fn local_wss_echo_with_oversize_handshake_burst() {
    // Baseline first, so a failure below points at the large chain rather
    // than at the basic TLS/WebSocket path.
    smoke_check_simple_cert();

    let (chain, key_der) = generate_oversize_ecdsa_chain();
    let key = rustls::pki_types::PrivateKeyDer::try_from(key_der).expect("server key");
    let server_config = std::sync::Arc::new(
        rustls::ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(chain, key)
            .expect("server config"),
    );

    let listener = TcpListener::bind("127.0.0.1:0").expect("bind 127.0.0.1:0");
    let port = listener.local_addr().expect("local_addr").port();
    let server_handle = thread::spawn(move || run_echo_server(listener, server_config));

    // Verification is disabled: the chain is self-generated and untrusted.
    let tls_config = TlsConfig::builder()
        .danger_no_verify()
        .build()
        .expect("client tls config");
    // NOTE(review): the 64 KiB write buffer is presumably part of the #200
    // reproduction setup; confirm intent before changing it.
    let mut ws = nexus_web::ws::ClientBuilder::new()
        .tls(&tls_config)
        .write_buffer_capacity(64 * 1024)
        .connect_timeout(Duration::from_secs(10))
        .connect(&format!("wss://127.0.0.1:{port}/"))
        .expect("client WSS connect + upgrade");

    let probe = "hello-from-#200-regression-test";
    ws.send_text(probe).expect("client send");
    match ws.recv().expect("client recv").expect("close before message") {
        Message::Text(s) => assert_eq!(s, probe, "echo must match"),
        other => panic!("expected Text echo, got {other:?}"),
    }

    let big = "x".repeat(8192);
    ws.send_text(&big).expect("client send big");
    match ws.recv().expect("client recv big").expect("close before message") {
        Message::Text(s) => {
            assert_eq!(s.len(), 8192, "big echo length must match");
            // The length alone would accept a corrupted payload. Compare the
            // content with `assert!` so a failure doesn't print 8 KiB twice.
            assert!(s == big.as_str(), "big echo content must match");
        }
        other => panic!("expected Text echo, got {other:?}"),
    }

    ws.close(CloseCode::Normal, "done").expect("client close");
    server_handle.join().expect("server thread");
}