mod support;
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::thread;
use wolfssl::{
Certificate, PrivateKey, RootCertStore, TlsAcceptor, TlsClient, TlsClientConfig,
TlsServerConfig,
};
use support::{
client_config, server_config, start_echo_server, CA_CERT_PEM, SERVER_CERT_PEM, SERVER_KEY_PEM,
};
#[test]
fn full_client_server_round_trip() {
let (port, server_handle) = start_echo_server(server_config(false), 1);
let cfg = client_config(false);
let stream = TcpStream::connect(format!("127.0.0.1:{port}")).unwrap();
let mut tls = TlsClient::new(cfg, "localhost", stream).expect("handshake failed");
let msg = b"round-trip test message 9876543210";
tls.write_all(msg).unwrap();
let mut buf = vec![0u8; msg.len()];
tls.read_exact(&mut buf).unwrap();
assert_eq!(&buf, msg);
let msg2 = b"second message ABCDEF";
tls.write_all(msg2).unwrap();
let mut buf2 = vec![0u8; msg2.len()];
tls.read_exact(&mut buf2).unwrap();
assert_eq!(&buf2, msg2);
drop(tls);
server_handle.join().unwrap();
}
#[test]
fn mtls_both_sides_authenticated() {
let (port, server_handle) = start_echo_server(server_config(true), 1);
let cfg = client_config(true);
let stream = TcpStream::connect(format!("127.0.0.1:{port}")).unwrap();
let mut tls = TlsClient::new(cfg, "localhost", stream).expect("mTLS handshake should succeed");
let msg = b"mutual auth verified data 0xDEADBEEF";
tls.write_all(msg).unwrap();
let mut buf = vec![0u8; msg.len()];
tls.read_exact(&mut buf).unwrap();
assert_eq!(&buf, msg);
drop(tls);
server_handle.join().unwrap();
}
#[test]
fn mtls_rejection_client_without_cert() {
let cert = Certificate::from_pem(SERVER_CERT_PEM);
let key = PrivateKey::from_pem(SERVER_KEY_PEM);
let mut ca_store = RootCertStore::new();
ca_store.add_pem(CA_CERT_PEM);
let srv_config = TlsServerConfig::builder()
.with_protocol_versions(&[wolfssl::ProtocolVersion::Tls12])
.with_certificate_chain(cert, key)
.with_client_auth(ca_store)
.build()
.unwrap();
let (port, server_handle) = start_echo_server(srv_config, 1);
let mut root_store = RootCertStore::new();
root_store.add_pem(CA_CERT_PEM);
let cfg = TlsClientConfig::builder()
.with_protocol_versions(&[wolfssl::ProtocolVersion::Tls12])
.with_root_certificates(root_store)
.with_no_client_auth()
.build()
.unwrap();
let stream = TcpStream::connect(format!("127.0.0.1:{port}")).unwrap();
let result = TlsClient::new(cfg, "localhost", stream);
match result {
Err(_) => {} Ok(mut tls) => {
let write_result = tls.write_all(b"test");
let mut buf = [0u8; 4];
let read_result = tls.read_exact(&mut buf);
assert!(
write_result.is_err() || read_result.is_err(),
"I/O must fail when server rejected client without cert"
);
}
}
let _ = server_handle.join();
}
#[test]
fn multiple_sequential_connections() {
let (port, server_handle) = start_echo_server(server_config(false), 3);
for i in 0..3 {
let cfg = client_config(false);
let stream = TcpStream::connect(format!("127.0.0.1:{port}")).unwrap();
let mut tls = TlsClient::new(cfg, "localhost", stream).expect("handshake failed");
let msg = format!("sequential connection {i} unique data");
tls.write_all(msg.as_bytes()).unwrap();
let mut buf = vec![0u8; msg.len()];
tls.read_exact(&mut buf).unwrap();
assert_eq!(buf, msg.as_bytes(), "data mismatch on connection {i}");
drop(tls);
}
server_handle.join().unwrap();
}
#[test]
fn concurrent_connections() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
let config = server_config(false);
let server_handle = thread::spawn(move || {
let acceptor = TlsAcceptor::new(config);
let mut handlers = Vec::new();
for _ in 0..3 {
let (stream, _) = listener.accept().unwrap();
match acceptor.accept(stream) {
Ok(mut tls) => {
let h = thread::spawn(move || {
let mut buf = [0u8; 256];
match tls.read(&mut buf) {
Ok(n) if n > 0 => {
tls.write_all(&buf[..n]).ok();
}
_ => {}
}
});
handlers.push(h);
}
Err(e) => eprintln!("server accept error: {e}"),
}
}
for h in handlers {
h.join().unwrap();
}
});
let mut client_handles = Vec::new();
for i in 0..3 {
let h = thread::spawn(move || {
let cfg = client_config(false);
let stream = TcpStream::connect(format!("127.0.0.1:{port}")).unwrap();
let mut tls = TlsClient::new(cfg, "localhost", stream).expect("handshake failed");
let msg = format!("concurrent client {i} unique payload");
tls.write_all(msg.as_bytes()).unwrap();
let mut buf = vec![0u8; msg.len()];
tls.read_exact(&mut buf).unwrap();
assert_eq!(
buf,
msg.as_bytes(),
"data mismatch for concurrent client {i}"
);
});
client_handles.push(h);
}
for h in client_handles {
h.join().unwrap();
}
server_handle.join().unwrap();
}